Pytorch---使用Pytorch实现U-Net进行语义分割

2023-05-16

一、代码中的数据集可以通过以下链接获取

百度网盘提取码:f1j7

二、代码运行环境

Pytorch-gpu==1.10.1
Python==3.8

三、数据集处理代码如下所示

import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks


class MaskDataset(data.Dataset):
    def __init__(self, image_paths, mask_paths, transform):
        super(MaskDataset, self).__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        label_path = self.mask_paths[index]

        pil_img = Image.open(image_path)
        pil_img = pil_img.convert('RGB')
        img_tensor = self.transform(pil_img)

        pil_label = Image.open(label_path)
        label_tensor = self.transform(pil_label)
        label_tensor[label_tensor > 0] = 1
        label_tensor = torch.squeeze(input=label_tensor).type(torch.long)

        return img_tensor, label_tensor

    def __len__(self):
        return len(self.mask_paths)


def load_data():
    # DATASET_PATH = r'/home/akita/hk'
    DATASET_PATH = r'/Users/leeakita/Desktop/hk'
    TRAIN_DATASET_PATH = os.path.join(DATASET_PATH, 'training')
    TEST_DATASET_PATH = os.path.join(DATASET_PATH, 'testing')

    train_file_names = os.listdir(TRAIN_DATASET_PATH)
    test_file_names = os.listdir(TEST_DATASET_PATH)

    train_image_names = [name for name in train_file_names if
                         'matte' in name and name.split('_')[0] + '.png' in train_file_names]
    train_image_paths = [os.path.join(TRAIN_DATASET_PATH, name.split('_')[0] + '.png') for name in
                         train_image_names]
    train_label_paths = [os.path.join(TRAIN_DATASET_PATH, name) for name in train_image_names]

    test_image_names = [name for name in test_file_names if
                        'matte' in name and name.split('_')[0] + '.png' in test_file_names]
    test_image_paths = [os.path.join(TEST_DATASET_PATH, name.split('_')[0] + '.png') for name in test_image_names]
    test_label_paths = [os.path.join(TEST_DATASET_PATH, name) for name in test_image_names]

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    BATCH_SIZE = 8

    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)

    train_dl = data.DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)

    return train_dl, test_dl


if __name__ == '__main__':
    train_my, test_my = load_data()
    images, labels = next(iter(train_my))
    index = 5
    images = images[index]
    labels = labels[index]
    labels = torch.unsqueeze(input=labels, dim=0)

    result = draw_segmentation_masks(image=torch.as_tensor(data=images * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=labels, dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()

四、模型的构建代码如下所示

from torch import nn
import torch


class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        if is_pool:
            x = self.pool(x)
        x = self.conv_relu(x)
        return x


class UpSample(nn.Module):
    def __init__(self, channels):
        super(UpSample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(in_channels=2 * channels, out_channels=channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.up_conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=channels, out_channels=channels // 2, kernel_size=3, stride=2,
                               output_padding=1, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv_relu(x)
        x = self.up_conv(x)
        return x


class UnetModel(nn.Module):
    def __init__(self):
        super(UnetModel, self).__init__()
        self.down_1 = DownSample(in_channels=3, out_channels=64)
        self.down_2 = DownSample(in_channels=64, out_channels=128)
        self.down_3 = DownSample(in_channels=128, out_channels=256)
        self.down_4 = DownSample(in_channels=256, out_channels=512)
        self.down_5 = DownSample(in_channels=512, out_channels=1024)

        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, output_padding=1,
                               padding=1),
            nn.ReLU()
        )
        self.up_1 = UpSample(channels=512)
        self.up_2 = UpSample(channels=256)
        self.up_3 = UpSample(channels=128)

        self.conv_2 = DownSample(in_channels=128, out_channels=64)
        self.last = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

    def forward(self, x):
        down_1 = self.down_1(x, is_pool=False)
        down_2 = self.down_2(down_1)
        down_3 = self.down_3(down_2)
        down_4 = self.down_4(down_3)
        down_5 = self.down_5(down_4)

        down_5 = self.up(down_5)

        down_5 = torch.cat([down_4, down_5], dim=1)
        down_5 = self.up_1(down_5)

        down_5 = torch.cat([down_3, down_5], dim=1)
        down_5 = self.up_2(down_5)

        down_5 = torch.cat([down_2, down_5], dim=1)
        down_5 = self.up_3(down_5)

        down_5 = torch.cat([down_1, down_5], dim=1)

        down_5 = self.conv_2(down_5, is_pool=False)

        down_5 = self.last(down_5)

        return down_5

五、模型的训练代码如下所示

import torch
from data_loader import load_data
from model_loader import UnetModel
from torch import nn
from torch import optim
import tqdm
import os

# 环境变量的配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载数据
train_dl, test_dl = load_data()

# 加载模型
model = UnetModel()
model = model.to(device=device)

# 训练的相关配置
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)

# 开始进行训练
for epoch in range(100):
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    train_loss_sum = torch.tensor(data=[], dtype=torch.float, device=device)
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            train_loss_sum = torch.cat([train_loss_sum, torch.unsqueeze(input=loss, dim=-1)], dim=-1)
            train_tqdm.set_postfix({'train loss': train_loss_sum.mean().item()})
    train_tqdm.close()

    lr_scheduler.step()

    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        test_loss_sum = torch.tensor(data=[], dtype=torch.float, device=device)
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            test_loss = loss_fn(test_pred.softmax(dim=1), test_labels)
            test_loss_sum = torch.cat([test_loss_sum, torch.unsqueeze(input=test_loss, dim=-1)], dim=-1)
            test_tqdm.set_postfix({'test loss': test_loss_sum.mean().item()})
        test_tqdm.close()

# 模型的保存
if not os.path.exists(os.path.join('model_data')):
    os.mkdir(os.path.join('model_data'))
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))

六、模型的预测代码如下所示

import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import UnetModel

# 数据的加载
train_dl, test_dl = load_data()

# 模型的加载
model = UnetModel()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)

# 开始进行预测
images, labels = next(iter(test_dl))
index = 1
with torch.no_grad():
    pred = model(images)
    pred = torch.argmax(input=pred, dim=1)
    result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.figure(figsize=(8, 8), dpi=500)
    plt.axis('off')
    plt.imshow(result.permute(1, 2, 0))
    plt.savefig('result.png')
    plt.show()

七、代码的运行结果如下所示

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch---使用Pytorch实现U-Net进行语义分割 的相关文章

  • 有没有办法设置 log4net 内存附加程序可以包含的最大错误消息数?

    我想向根记录器添加一个内存附加程序 以便我可以连接到应用程序并获取最后 10 个事件 我只想保留最后 10 个 我担心这个附加程序会消耗太多内存 该应用程序设计为 24 7 运行 或者还有别的办法吗 您需要创建一个自定义附加程序来存储有限数
  • 如何将 CSV 文件读入 .NET 数据表

    如何将 CSV 文件加载到System Data DataTable 根据CSV文件创建数据表 常规 ADO net 功能是否允许这样做 我一直在使用OleDb提供者 但是 如果您正在读取具有数值的行 但希望将它们视为文本 则会出现问题 但
  • 元素属性语法和属性属性语法之间有语义差异吗?

    我认为元素属性语法和属性属性语法在语义上没有太大区别 但是 我发现一定有什么不同 例如 下面的例子只是演示了一个简单的触发器
  • 设置 runat=server 时输入名称和 id 发生变化

    在我的表单中 我需要插入 文本 类型的不同输入 输入必须是带有名称和 ID 的 html 控件 因为我将此表单发送到外部网址 对于验证 我在所有输入中执行 runat server 然后我可以使用 requiredfieldvalidato
  • 强制 Mpeg2 解复用器使用 ffdshow 渲染 H 264 数字电视视频

    不幸的是 我花了很多时间尝试使 DirectShow 的 DTVViewer 示例正常工作 但没有成功 DVBT网络的视频格式是H264 我发现IntelliConnect行为IFilterGraph更喜欢使用 Mpeg2 视频格式 对于那
  • ASP.NET MVC ActionFilterAttribute 在模型绑定之前注入值

    我想创建一个自定义操作过滤器属性 该属性在模型绑定期间可访问的 HttpContext 项中添加一个值 我尝试将其添加到 OnActionExecuting 中 但似乎模型绑定是在过滤器之前执行的 你知道我该怎么做吗 也许模型绑定器中有一个
  • C#:询问用户密码,然后将其存储在 SecureString 中

    在我目前为客户开发的小型应用程序中 我需要询问用户他的 Windows 登录用户名 密码和域 然后使用这些信息系统 诊断 进程 启动启动一个应用程序 我有一个带有 UseSystemPasswordChar 的文本框来屏蔽输入的密码 我需要
  • string.Empty 与 null。您使用哪一个?

    最近工作的同事告诉我不要使用string Empty设置字符串变量时但使用null因为它污染了堆栈 他说不做 string myString string Empty but do string mystring null 真的有关系吗 我
  • 如何获取可用系统内存的大小?

    C NET 中是否可以获取系统可用内存的大小 如果是的话怎么办 Use Microsoft VisualBasic Devices ComputerInfo TotalPhysicalMemory http msdn microsoft c
  • 如何在 .NET Framework 2.0 中模拟“Func<(Of <(TResult>)>) 委托”?

    我尝试使用这个类代码项目文章 http www codeproject com KB threads AsyncVar aspx在 VB NET 和 NET Framework 2 0 中 除了这一行之外 所有内容似乎都可以编译Privat
  • Directory.Delete 之后 Directory.Exists 有时返回 true ?

    我有非常奇怪的行为 我有 Directory Delete tempFolder true if Directory Exists tempFolder 有时 Directory Exists 返回 true 为什么 可能是资源管理器打开了
  • 获取两个工作日之间的天数差异

    这听起来很简单 但我不明白其中的意义 那么获取两次之间的天数的最简单方法是什么DayOfWeeks当第一个是起点时 如果下一个工作日较早 则应考虑在下周 The DayOfWeek 枚举 http 20 20 5B1 5D 3a 20htt
  • 从 Excel 应用程序对象中查找位数(32 位/64 位)?

    是否可以从 Microsoft Office Interop Excel ApplicationClass 确定 Excel 是以 32 位还是 64 位运行 Edit该解决方案应该适用于 Excel 2010 和 Excel 2007 此
  • 如何使用命令行压缩指定文件夹

    你们能告诉我如何将指定的文件压缩到同一个 Zip 文件中吗 让我告诉我我的文件夹是如何填充的 任务调度程序有我的数据库的备份 并每天将它们保存到文件中 它每天创建 4 个数据库备份 这意味着每天会多出 4 个文件 因此 我需要将新创建的备份
  • 如何将位写入文件?

    如何使用 c net 将位 而不是字节 写入文件 我很坚持它 Edit 我正在寻找一种不同的方法 将每 8 位写为一个字节 一次可以写入的最小数据量是一个字节 如果您需要写入单独的位值 例如 二进制格式需要 1 位标志 3 位整数和 4 位
  • 类库的 app.config 中的绑定重定向有什么作用吗?

    我经常使用的 VS 解决方案包括单个可执行项目 控制台应用程序 网络应用程序 和许多类库项目这些都被可执行文件引用 使用 NuGet 并安装包时 经常会出现app config为每个项目创建的文件 通常只包含合并引用程序集版本的绑定重定向列
  • 如何在 .NET 中使 ComboBox 不可编辑?

    我想要一个 仅选择 ComboBox它提供了一个项目列表供用户选择 应在文本部分禁用打字ComboBox控制 我最初对此进行谷歌搜索 发现了一个过于复杂 误导性的建议来捕捉KeyPress event 要使 ComboBox 的文本部分不可
  • 避免使用一本字典的更好代码 - 区分大小写问题

    我有以下方法用数据读取器的值填充字典 数据读取器字段和传递给方法的属性之间可能存在大小写不匹配的情况 在下面的方法中 我首先将属性转换为小写以解决此问题 这会导致两个字典 有没有更好的方法用一本字典来实现这一目标 private Dicti
  • 为什么我的程序集在安装到 GAC 后在“添加引用 > .Net”中不可见?

    我想问一个关于 GAC 的简单问题我创建了一个程序集 Awesome DLL 对其进行签名 然后将其安装到 GAC 中 C MyApps Awesome Awesome Awesome bin Release gt sn k Awesome
  • Web 和 winforms 的 .Net 身份验证

    我有一个为客户端构建的 ASP NET Web 应用程序 它使用默认的 ASP NET 表单身份验证 他们现在请求一个能够 与 Web 应用程序一起工作的桌面 WinForms 应用程序 我已经创建了 Web 服务来访问他们想要从 Web

随机推荐

  • springboot启动错误Could not resolve placeholder ‘XXX‘ in value “${XXX}“

    百度了很多方法 xff0c 都没有解决 xff0c 记录一下 后来发现是因为多模块项目中 xff0c 必须要在有XXXXApplication java主启动类的项目下的application yml中配置的参数才可以读取到 xff0c 之
  • Map中key和value值是否可以为null或空字符串?

    答案 xff1a HashMap既支持分别为空 null xff0c 也支持key和value同时为空 null Hashtable不支持key和value存储null xff0c 但支持存空字符串 HashMap HashMap是中支持空
  • CentOS 7+vim+ycm(clang)

    原本想在CentOS6 6下搞 xff0c 中间各种问题 xff0c 要升级Python xff0c 要升级gcc xff08 还因为之前系统没划分swap分区 xff0c 高版本的gcc编译不出来 xff09 xff0c 后来索性放弃 x
  • 实现map按输入顺序输出或按key排序

    HashMap输出时是无序的 想要顺序输出就要借助其他map HashMap缺点 xff1a HashMap是非线程安全的 xff0c 多个线程同时写入可能导致数据不一致 解决办法详见 xff1a HashMap是非线程安全的解决办法 li
  • hashMap是线程安全的吗?若不是,有什么线程安全的解决方法?

    HashMap是非线程安全的 多个线程同时写入可能导致数据不一致 xff0c 从而出现各种脏数据 想要实现线程安全的解决方法 HashTable 是线程安全的 HashTable 容器使用 synchronized 来保证线程安全 xff0
  • Codeblocks快捷键合集

    索引 CodeBlocks常用操作快捷键编辑部分 xff1a 编译与运行部分 xff1a 调试部分 xff1a 界面部分 xff1a CodeBlocks常用操作快捷键 编辑部分 xff1a Ctrl 43 A xff1a 全选bai Ct
  • Ubuntu 18.04 Linux内核升级(因为在系统中安装会出现各种驱动不兼容的问题,所以去官网下)

    前言 原本使用Ubuntu 18 04 2 LTS来换到5 4 45的内核版本 xff0c 来尝试在系统中直接用sudo apt get install linux image 命令更新一下Linux的内核 xff0c 但是碰到各种驱动不兼
  • ubuntu xubuntu 安装xrdp 键盘鼠标无法输入问题 采用命令行解决办法

    前言 原本打算安装xrdp实现windows控制ubuntu的 xff0c 结果安装完成后 xff0c 系统一重启突然发现键盘鼠标不能用了 xff0c 去网络上搜索了很多解决办法发现都不行 xff0c 后来经过不断地尝试 xff0c 终于找
  • 云服务器ECS入门

    1 什么云服务器ECS 云服务器ECS xff08 Elastic Compute Service xff09 是阿里云提供的性能卓越 稳定可靠 弹性扩展的IaaS xff08 Infrastructure as a Service xff
  • 启用微软e5子账户的outlook邮箱,解决 qyi 续订程序无法刷新令牌问题

    使用 qyi io 提供的 e5 子账户续订服务 xff0c 如果子账户的outlook未启用 xff0c 则会报错 xff1a 无法刷新令牌 code 2 错误消息 The mailbox is either inactive soft
  • Python基础语句

    一 判断语句 在程序中如果某些条件满足 xff0c 才能做某件事情 xff0c 而不满足时不允许做 xff0c 这就是所谓的判断 1 if语句的使用格式 if 条件 条件成立时 要做的事 案例 判断年纪 xff0c 如果age大于18 xf
  • Keil5程序编译下载不能正常运行,在线调试却正常工作

    Keil5程序编译下载不能正常运行 xff0c 在线调试正常工作 本人水平有限以下个人经验仅供参考 xff0c 很可能有各种错误和不准确之处 在Keil5的使用中 xff0c 本人之前弄出了一个位置错误导致了似乎是栈溢出问题 xff0c 程
  • 家用动态IP配置DDNS

    文章目录 动态公网IP配置DDNS申请域名API访问密钥编写Java代码 xff0c 定时更新域名IPDDNS程序打包上传服务器 xff0c 配置为开机自启动开源代码的使用说明 动态公网IP配置DDNS 家里有台老旧的笔记本电脑闲置着 xf
  • CentOS 7 安装 netmap

    0 环境准备 操作系统 xff1a CentOS 7 3 1611 yum y install rpm build redhat rpm config asciidoc bison hmaccalc patchutils perl ExtU
  • Linux下Nvidia驱动的安装

    1 查看Linux系统是否已经安装了Nvidia驱动 命令行输入 xff1a nvidia smi 进行查看 xff1a nvidia smi 如果输出以下信息 xff0c 则Linux系统中已经安装了Nvidia驱动 如果没有以上的输出信
  • Linux下安装cuda和对应版本的cudnn

    1 首先在安装cuda与cudnn之前 xff0c 系统需要成功安装Nvidia驱动 xff0c 安装教程请参照以下教程 xff1a Nvidia驱动安装教程 2 验证系统内部是否已经安装了cuda 打开命令行 xff0c 输入以下命令 x
  • Windows安装Matlab的具体步骤

    一 进行相关文件的下载 安装所需的文件可以从以下链接获取 xff0c 提取码 xff1a 5417 点击此处进行文件下载 二 进行Matlab的安装 1 文件内容解释 下载的文件一共有三个 xff0c 文件名称分别为 xff1a Matla
  • Tensorflow---Tensorflow的五种保存模型的方式介绍

    一 保存模型的全部配置信息 使用model save 函数搭配tf keras models load model 对模型的架构 xff0c 权重以及配置进行保存与恢复 模型的保存代码如下 xff1a span class token ke
  • ISPRS遥感数据集--Potsdam数据集,Vaihingen数据集,Toronto数据集

    一 数据的获取 Potsdam数据集下载链接 百度网盘提取码 xff1a lala Vaihingen数据集下载链接 百度网盘提取码 xff1a lala Toronto数据集下载链接 百度网盘提取码 xff1a lala 二 数据集的介绍
  • Pytorch---使用Pytorch实现U-Net进行语义分割

    一 代码中的数据集可以通过以下链接获取 百度网盘提取码 xff1a f1j7 二 代码运行环境 Pytorch gpu 61 61 1 10 1 Python 61 61 3 8 三 数据集处理代码如下所示 span class token