1.概述

最近有时间,跑了一下UNet模型,因为自己的深度学习基础不扎实,导致用了一些时间。目前只停留在使用和理解别人模型的基础上,对于优化模型的相关方法还有待学习。
众所周知,UNent是进行语义分割的知名模型,它的U形结构很多人也都见过,但是如果自己没有亲自试过的话,也就只知道它的U形结构,其实里面还是有很多学问的,下面就把自己学习时候的一些理解写一下。
最后会拿个完整代码作为例子(实际上自己练习了两个比较成功的例子)

2.UNet模型理解

先放UNet模型的图,然后介绍再Pytorch相关实现的函数。

一般看到这个图,都会看到它从左边逐渐编码,到最底端,之后从底端不断解码,恢复为一张图像。但是很多人会忽略中间的从左往右的三条灰色的直线,这是它把一个图像处理为目标图像的一个关键。
从图中理解,Unet模型可以分为几个关键的部分
①ConvBlock():也就是U形结构左边的这一部分,当然也包括下面和它很相似的三个,他们都是(卷积-激活-卷积-激活),这也就是图像从572×572变成568×568的原因,实际上就是卷积过程中边界设定的问题。
②MaxPool2d():然后紧接者就是红色箭头了,它就是池化,把卷积后的特征筛选出来,这是图像尺寸急剧下降的原因。
接着①和②步骤重复了4次。得到了一个512维,32×32的图片。当然有的模型为了简化,这个步骤只进行了3次,实际上也能取得不错的效果。
最后它又进行了一次①,成功把图像变成1024维度,30×30,图像特征被压缩在了这样的一个信息里面了。
下面就是U形结构的右侧上升部分了,总之就是把刚才被压缩的信息展开,结合一步步压缩过程的中间图像(灰色箭头),把图像还原成想要的样子。
这里主要用到了下面的函数
③ConvBlock():和①一模一样,只不过,①是把维度逐渐变大也就是从3变成1024,而在经过4次③之后,1024变成了1
ConvTranspose2d():与池化对应,这个是上采样(不确定叫法是否得当)过程,它拓展了数据的尺寸,并减少了维度。
⑤copyAndCrop和cat,这个就是灰色箭头的实现过程,首先把两个输入数据(也就是原始数据及编码过程数据和上采样结果裁剪为一致的尺寸,之后进行连接)
在最后一层,输出的维度为一,也就是灰度图像,不过也可以定义为其他维度,例如输出彩色,这跟自己实际的需求有关。

3.数据集加载

为了方便下面展示代码,先导入必要的模块

import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.utils as vutils
from torchsummary import summary
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, StepLR, MultiStepLR, CyclicLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T, datasets as dset
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from zipfile import ZipFile
from tqdm import tqdm
from glob import glob
from PIL import Image
import cv2
from torch.utils.tensorboard import SummaryWriter

数据加载过程跟普通的卷积神经网络没什么区别,无非就是构建Dataset从文件夹读取数据,之后构建dataset_loader,用来为训练数据做准备。下面直接放一段代码。
不过还是要介绍一下原数据的目录结构,下载数据请点这里Carvana Image Masking Challenge.,下载下面这两个压缩包解压就可以

解压后这样就行:

在下面的代码中,数据被分为了两组,0.7的数据被作为训练组,0.3的数据被用来验证。
多说两句,这个数据集的构建继承了Dataset类,实现了__getitem__(self, index: int)和__len__(self)两个函数。

class MyDataset(Dataset):def __init__(self, root_dir: str, train=True, transforms=None):super(MyDataset, self).__init__()self.train = trainself.transforms = transformsfile_path = root_dir + 'imgs/*.jpg'file_mask_path = root_dir + 'masks/*.gif'self.images = sorted(glob(file_path))self.image_mask = sorted(glob(file_mask_path))# manually split the train/valid datasplit_ratio = int(len(self.images) * 0.7)if train:self.images = self.images[:split_ratio]self.image_mask = self.image_mask[:split_ratio]else:self.images = self.images[split_ratio:]self.image_mask = self.image_mask[split_ratio:]def __getitem__(self, index: int):image = Image.open(self.images[index]).convert('RGB')image_mask = Image.open(self.image_mask[index]).convert('L')if self.transforms:image = self.transforms(image)image_mask = self.transforms(image_mask)return {'img': image, 'mask': image_mask}def __len__(self):return len(self.images)

4.构建模型

下面模型的代码跟**2.**的介绍是对应的,建议对照看一下,就会有所理解

class ConvBlock(nn.Module):def __init__(self, in_channels: int, out_channels: int):super(ConvBlock, self).__init__()self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x: torch.Tensor):return self.block(x)class CopyAndCrop(nn.Module):def forward(self, x: torch.Tensor, encoded: torch.Tensor):_, _, h, w = encoded.shapecrop = T.CenterCrop((h, w))(x)output = torch.cat((x, crop), 1)return outputclass UNet(nn.Module):def __init__(self, in_channels: int, out_channels: int):super(UNet, self).__init__()self.encoders = nn.ModuleList([ConvBlock(in_channels, 64),ConvBlock(64, 128),ConvBlock(128, 256),ConvBlock(256, 512),])self.down_sample = nn.MaxPool2d(2)self.copyAndCrop = CopyAndCrop()self.decoders = nn.ModuleList([ConvBlock(1024, 512),ConvBlock(512, 256),ConvBlock(256, 128),ConvBlock(128, 64),])# PixelShuffle, UpSample will modify the output channel (you can add extra operation to update the channel, e.g.conv2d)# preffer use convTranspose2d, it won't modify the output channelself.up_samples = nn.ModuleList([nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)])self.bottleneck = ConvBlock(512, 1024)self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1, stride=1)def forward(self, x: torch.Tensor):# encodencoded_features = []for enc in self.encoders:x = enc(x)encoded_features.append(x)x = self.down_sample(x)x = self.bottleneck(x)# decodefor idx, denc in enumerate(self.decoders):x = self.up_samples[idx](x)encoded = encoded_features.pop()x = self.copyAndCrop(x, encoded)x = denc(x)output = self.final_conv(x)return output

5.模型训练

模型训练大约包含下面几个步骤,首先定义了几个必要的参数,例如图像大小,batch_size,device 等等。流程如下。没有介绍优化器和损失函数之类的,因为笔者自己理解还不够,但是代码里面是有的。代码里面有些绘图的内容,方便了可视化,感觉麻烦可以删掉。

Created with Raphaël 2.3.0 定义参数 加载数据(MyDataset) 创建dataset_loader 开始训练 训练集训练 验证 训练集训练 保存模型 达到epochs数量? 结束 yes no
batch_size = 4
n_iters = 10000
epochs = 10
learning_rate = 0.0002
n_workers = 2
width = 256
height = 256
channels = 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = 44
random.seed(seed)
torch.manual_seed(seed)if __name__ == '__main__':transforms = T.Compose([T.Resize((width, height)),T.ToTensor(),#     T.Normalize(mean=[0.485, 0.456, 0.406],#                 std=[0.229, 0.224, 0.225]),#     T.RandomHorizontalFlip()])train_dataset = MyDataset(root_dir='./data/',train=True,transforms=transforms)val_dataset = MyDataset(root_dir='./data/',train=False,transforms=transforms)train_dataset_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=n_workers)val_dataset_loader = DataLoader(dataset=val_dataset,batch_size=batch_size,shuffle=True,num_workers=n_workers)samples = next(iter(train_dataset_loader))fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 4))fig.tight_layout()ax1.axis('off')ax1.set_title('input image')ax1.imshow(np.transpose(vutils.make_grid(samples['img'], padding=2).numpy(),(1, 2, 0)))ax2.axis('off')ax2.set_title('input mask')ax2.imshow(np.transpose(vutils.make_grid(samples['mask'], padding=2).numpy(),(1, 2, 0)), cmap='gray')plt.show()def dice_score(pred: torch.Tensor, mask: torch.Tensor):dice = (2 * (pred * mask).sum()) / (pred + mask).sum()return np.mean(dice.cpu().numpy())def iou_score(pred: torch.Tensor, mask: torch.Tensor):passdef plot_pred_img(samples, pred):fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, figsize=(12, 6))fig.tight_layout()ax1.axis('off')ax1.set_title('input image')ax1.imshow(np.transpose(vutils.make_grid(samples['img'], padding=2).numpy(),(1, 2, 0)))ax2.axis('off')ax2.set_title('input mask')ax2.imshow(np.transpose(vutils.make_grid(samples['mask'], padding=2).numpy(),(1, 2, 0)), cmap='gray')ax3.axis('off')ax3.set_title('predicted mask')ax3.imshow(np.transpose(vutils.make_grid(pred, padding=2).cpu().numpy(),(1, 2, 0)), cmap='gray')plt.show()def plot_train_progress(model):#     model.eval()#     with torch.no_grad():samples = next(iter(val_dataset_loader))val_img = samples['img'].to(device)val_mask = samples['mask'].to(device)pred = model(val_img)plot_pred_img(samples, pred.detach())def train(model, optimizer, criteration, scheduler=None):train_losses = []val_lossess = []lr_rates = []# calculate train epochsepochs = int(n_iters / (len(train_dataset) / batch_size))for epoch in range(epochs):model.train()train_total_loss = 0train_iterations = 0for idx, data in enumerate(tqdm(train_dataset_loader)):train_iterations += 1train_img = data['img'].to(device)train_mask = data['mask'].to(device)optimizer.zero_grad()# speed up the trainingwith torch.set_grad_enabled(True):train_output_mask = model(train_img)train_loss = criterion(train_output_mask, train_mask)train_total_loss += train_loss.item()train_loss.backward()optimizer.step()train_epoch_loss = train_total_loss / train_iterationstrain_losses.append(train_epoch_loss)# evaluate modemodel.eval()with torch.no_grad():val_total_loss = 0val_iterations = 0scores = 0for vidx, val_data in enumerate(tqdm(val_dataset_loader)):val_iterations += 1val_img = val_data['img'].to(device)val_mask = val_data['mask'].to(device)with torch.set_grad_enabled(False):pred = model(val_img)val_loss = criterion(pred, val_mask)val_total_loss += val_loss.item()scores += dice_score(pred, val_mask)val_epoch_loss = val_total_loss / val_iterationsdice_coef_scroe = scores / val_iterationsval_lossess.append(val_epoch_loss)plot_train_progress(model)print('epochs - {}/{} [{}/{}], dice score: {}, train loss: {}, val loss: {}'.format(epoch + 1, epochs,idx + 1, len(train_dataset_loader),dice_coef_scroe, train_epoch_loss, val_epoch_loss))torch.save(model, 'modelCar' + str(epoch) + '.pkl')lr_rates.append(optimizer.param_groups[0]['lr'])if scheduler:scheduler.step()  # decay learning rateprint('LR rate:', scheduler.get_last_lr())return {'lr': lr_rates,'train_loss': train_losses,'valid_loss': val_lossess}model = UNet(in_channels=3, out_channels=1).to(device)criterion = nn.BCEWithLogitsLoss()# criterion = smp.losses.DiceLoss(mode='binary')optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)history = train(model, optimizer, criterion)

6.训练结果

经过一段时间的训练,在这个数据集中的效果还不错,下面放几张图来看一下,展示3个epochs的吧。(dic score的计算可能并不准确)
100%|██████████| 891/891 [07:31<00:00, 1.97it/s]
100%|██████████| 382/382 [01:20<00:00, 4.76it/s]
Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).
epochs - 1/11 [891/891], dice score: -1.732437749183615, train loss: 0.10204851534531173, val loss: 0.037313264545969935

100%|██████████| 891/891 [07:29<00:00, 1.98it/s]
100%|██████████| 382/382 [01:15<00:00, 5.07it/s]
Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).
epochs - 2/11 [891/891], dice score: -1.2491931439382244, train loss: 0.027105745536460217, val loss: 0.021171443983522387

100%|██████████| 891/891 [07:30<00:00, 1.98it/s]
100%|██████████| 382/382 [01:15<00:00, 5.08it/s]
Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).
epochs - 3/11 [891/891], dice score: -1.2030698774060653, train loss: 0.01996645850143014, val loss: 0.02303329437815082

UNet语义分割模型的使用-Pytorch相关推荐

  1. 语义分割重制版1——Pytorch 搭建自己的Unet语义分割平台

    转载:https://blog.csdn.net/weixin_44791964/article/details/108866828?spm=1001.2014.3001.5501 对应b站视频:ht ...

  2. 憨批的语义分割重制版6——Pytorch 搭建自己的Unet语义分割平台

    憨批的语义分割重制版6--Pytorch 搭建自己的Unet语义分割平台 注意事项 学习前言 什么是Unet模型 代码下载 Unet实现思路 一.预测部分 1.主干网络介绍 2.加强特征提取结构 3. ...

  3. 深度学习-Tensorflow2.2-图像处理{10}-UNET图像语义分割模型-24

    UNET图像语义分割模型简介 代码 import tensorflow as tf import matplotlib.pyplot as plt %matplotlib inline import ...

  4. 语义分割系列6-Unet++(pytorch实现)

    目录 Unet++网络 Dense connection deep supervision 模型复现 Unet++ 数据集准备 模型训练 训练结果 Unet++:<UNet++: A Neste ...

  5. 图像语义分割模型综述

    文章目录 一.语义分割介绍 二.语义分割的思路 空洞卷积 条件随机场 三.经典语义分割算法介绍 1.FCN 2.UNet Family (1)UNet (2)Attention U-Net (3)UN ...

  6. FCN与U-Net语义分割算法

    FCN与U-Net语义分割算法 图像语义分割(Semantic Segmentation)是图像处理和是机器视觉技术中关于图像理解的重要一环,也是 AI 领域中一个重要的分支.语义分割即是对图像中每一 ...

  7. 人人必须要知道的语义分割模型:DeepLabv3+

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 前言 图像分割是计算机视觉中除了分类和检测外的另一项基本任务,它意 ...

  8. 【深度学习】SETR:基于视觉 Transformer 的语义分割模型

    Visual Transformer Author:louwill Machine Learning Lab 自从Transformer在视觉领域大火之后,一系列下游视觉任务应用研究也随之多了起来.基 ...

  9. u-net语义分割_使用U-Net的语义分割

    u-net语义分割 Picture By Martei Macru On Unsplash 图片由Martei Macru On Unsplash拍摄 Semantic segmentation is ...

最新文章

  1. 【深度学习】深入浅出神经网络框架的模型元件(池化、正则化和反卷积层)
  2. Python开发【第七篇】: 面向对象和模块补充
  3. 到底什么样的ABAP系统能运行Fiori应用
  4. Spring @Value取值为null或@Autowired注入失败
  5. 域控制器的强制卸载,Active Directory系列之十四
  6. 完全清除一个带包的项目文件的方法
  7. 小程序 怎么选云服务器,小程序如何选择云服务器
  8. paip.android 手机输入法制造大法
  9. 乐队设备--反馈抑制器学习笔记
  10. Windows 远程桌面无法复制粘贴问题
  11. [蓝桥杯2018决赛]阅兵方阵
  12. ​从ASML年报看半导体产业的未来
  13. 多层板的板层布局和线宽的设置(记录)
  14. 【HCIE-BigData-Data Mining课程笔记(三)】预备知识-Python基础
  15. 计算机组装手机app,智能手机安装软件的
  16. linux基础命令学习笔记
  17. 网络安全——DDOS攻击
  18. 架构设计参考项目系列主题:智能风控决策引擎系统可落地实现方案:风控监控大盘实现
  19. 淘宝主图SKU图采集下载
  20. 试卷模板 html,A4纸试卷模板.doc

热门文章

  1. Nova的安装及其配置
  2. 电脑Tab键有什么功能?分享Tab键的6个妙用
  3. 什么叫少儿机器人编程
  4. 编写虚拟 AI 女友
  5. 关于AD20的PCB电路图打印设置
  6. 全国青少年编程等级考试python一级真题2020年12月(含题库答题软件账号)
  7. 疫情可视化part1
  8. 有创意的思维导图怎么画
  9. 机器学习笔记~HDF5 library version mismatched error与ImportError: 'save_model' requires h5py问题解决
  10. 鸡兔同笼问题。上有头30个,下有脚90只,问鸡兔各有多少只。