@Author:Runsen

在图像领域,除了分类,CNN 今天还用于更高级的问题,如图像分割、对象检测等。图像分割是计算机视觉中的一个过程,其中图像被分割成代表图像中每个不同类别的不同段。

上面图片一段代表猫,另一段代表背景。

从自动驾驶汽车到卫星,图像分割在许多领域都很有用。其中最重要的是医学成像。

UNet 是一种卷积神经网络架构,在 CNN 架构几乎没有变化的情况下进行了扩展。它的发明是为了处理生物医学图像,其目标不仅是对是否存在感染进行分类,而且还要识别感染区域。

UNet

论文:https://arxiv.org/abs/1505.04597

UNet结构看起来像一个“U”,该架构由三部分组成:收缩部分、瓶颈部分和扩展部分。收缩段由许多收缩块组成。每个块接受一个输入,应用两个 3X3 卷积层,然后是 2X2 最大池化。每个块之后的内核或特征图的数量加倍,以便架构可以有效地学习复杂的结构。最底层介于收缩层和膨胀层之间。它使用两个 3X3 CNN 层,然后是 2X2 上卷积层。

每个块将输入传递给两个 3X3 CNN 层,然后是一个 2X2 上采样层。同样在每个块之后,卷积层使用的特征图数量减半以保持对称性。然而,每次输入也会附加相应收缩层的特征图。此操作将确保在收缩图像时学习的特征将用于重建它。扩展块的数量与收缩块的数量相同。之后,生成的映射通过另一个 3X3 CNN 层,特征映射的数量等于所需的片段数量。

torch实现

使用的数据集是:https://www.kaggle.com/paultimothymooney/chiu-2015

这个数据集用于分割糖尿病性黄斑水肿的光学相干断层扫描图像的图像。

对于mat的数据,使用scipy.io.loadmat进行加载

下面使用 Pytorch 框架实现了 UNet 模型,代码来源下面的Github:https://github.com/Hsankesara/DeepResearch

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optimclass UNet(nn.Module):def contracting_block(self, in_channels, out_channels, kernel_size=3):block = torch.nn.Sequential(torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),torch.nn.ReLU(),torch.nn.BatchNorm2d(out_channels),torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),torch.nn.ReLU(),torch.nn.BatchNorm2d(out_channels),)return blockdef expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):block = torch.nn.Sequential(torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),torch.nn.ReLU(),torch.nn.BatchNorm2d(mid_channel),torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),torch.nn.ReLU(),torch.nn.BatchNorm2d(mid_channel),torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))return  blockdef final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):block = torch.nn.Sequential(torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),torch.nn.ReLU(),torch.nn.BatchNorm2d(mid_channel),torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),torch.nn.ReLU(),torch.nn.BatchNorm2d(mid_channel),torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),torch.nn.ReLU(),torch.nn.BatchNorm2d(out_channels),)return  blockdef __init__(self, in_channel, out_channel):super(UNet, self).__init__()#Encodeself.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)self.conv_encode2 = self.contracting_block(64, 128)self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)self.conv_encode3 = self.contracting_block(128, 256)self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)# Bottleneckself.bottleneck = torch.nn.Sequential(torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),torch.nn.ReLU(),torch.nn.BatchNorm2d(512),torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),torch.nn.ReLU(),torch.nn.BatchNorm2d(512),torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1))# Decodeself.conv_decode3 = self.expansive_block(512, 256, 128)self.conv_decode2 = self.expansive_block(256, 128, 64)self.final_layer = self.final_block(128, 64, out_channel)def crop_and_concat(self, upsampled, bypass, crop=False):if crop:c = (bypass.size()[2] - upsampled.size()[2]) // 2bypass = F.pad(bypass, (-c, -c, -c, -c))return torch.cat((upsampled, bypass), 1)def forward(self, x):# Encodeencode_block1 = self.conv_encode1(x)encode_pool1 = self.conv_maxpool1(encode_block1)encode_block2 = self.conv_encode2(encode_pool1)encode_pool2 = self.conv_maxpool2(encode_block2)encode_block3 = self.conv_encode3(encode_pool2)encode_pool3 = self.conv_maxpool3(encode_block3)# Bottleneckbottleneck1 = self.bottleneck(encode_pool3)# Decodedecode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)cat_layer2 = self.conv_decode3(decode_block3)decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)cat_layer1 = self.conv_decode2(decode_block2)decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)final_layer = self.final_layer(decode_block1)return  final_layer

上面代码中的 UNet 模块代表了 UNet 的整个架构。contraction_block和expansive_block分别用于创建收缩段和膨胀段。该函数crop_and_concat将收缩层的输出与新的扩展层输入相加。

unet = Unet(in_channel=1,out_channel=2)
#out_channel represents number of segments desired
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(unet.parameters(), lr = 0.01, momentum=0.99)
optimizer.zero_grad()
outputs = unet(inputs)
# permute such that number of desired segments would be on 4th dimension
outputs = outputs.permute(0, 2, 3, 1)
m = outputs.shape[0]
# Resizing the outputs and label to caculate pixel wise softmax loss
outputs = outputs.resize(m*width_out*height_out, 2)
labels = labels.resize(m*width_out*height_out)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

对于该数据集解决标准教程代码:https://www.kaggle.com/hsankesara/unet-image-segmentation

【小白学习PyTorch教程】十九、 基于torch实现UNet 图像分割模型相关推荐

  1. 【小白学习PyTorch教程】九、基于Pytorch训练第一个RNN模型

    「@Author:Runsen」 当阅读一篇课文时,我们可以根据前面的单词来理解每个单词的,而不是从零开始理解每个单词.这可以称为记忆.卷积神经网络模型(CNN)不能实现这种记忆,因此引入了递归神经网 ...

  2. 【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

    「@Author:Runsen」 GAN 是使用两个神经网络模型训练的生成模型.一种模型称为生成网络模型,它学习生成新的似是而非的样本.另一个模型被称为判别网络,它学习区分生成的例子和真实的例子. 生 ...

  3. 【小白学习PyTorch教程】十、基于大型电影评论数据集训练第一个LSTM模型

    「@Author:Runsen」 本博客对原始IMDB数据集进行预处理,建立一个简单的深层神经网络模型,对给定数据进行情感分析. 数据集下载 here. 原始数据集,没有进行处理here. impor ...

  4. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型...

    「@Author:Runsen」 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先 ...

  5. 【小白学习PyTorch教程】十六、在多标签分类任务上 微调BERT模型

    @Author:Runsen BERT模型在NLP各项任务中大杀四方,那么我们如何使用这一利器来为我们日常的NLP任务来服务呢?首先介绍使用BERT做文本多标签分类任务. 文本多标签分类是常见的NLP ...

  6. 【小白学习PyTorch教程】四、基于nn.Module类实现线性回归模型

    「@Author:Runsen」 上次介绍了顺序模型,但是在大多数情况下,我们基本都是以类的形式实现神经网络. 大多数情况下创建一个继承自 Pytorch 中的 nn.Module 的类,这样可以使用 ...

  7. 【小白学习PyTorch教程】十四、迁移学习:微调ResNet实现男人和女人图像分类

    「@Author:Runsen」 上次微调了Alexnet,这次微调ResNet实现男人和女人图像分类. ResNet是 Residual Networks 的缩写,是一种经典的神经网络,用作许多计算 ...

  8. 【小白学习PyTorch教程】七、基于乳腺癌数据集​​构建Logistic 二分类模型

    「@Author:Runsen」 在逻辑回归中预测的目标变量不是连续的,而是离散的.可以应用逻辑回归的一个示例是电子邮件分类:标识为垃圾邮件或非垃圾邮件.图片分类.文字分类都属于这一类. 在这篇博客中 ...

  9. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型

    @Author:Runsen 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先了解 ...

最新文章

  1. 【Python】str转datetime与datetime计算
  2. 期末复习、化学反应工程科目(第一章)
  3. 广西大学计算机类开设课程,操作系统教学大纲-广西大学计算机与电子信息学院.DOC...
  4. 带有AngularJS资源的Spring Rest Controller
  5. 奥鹏19春计算机应用基础,19春西南交《计算机应用基础》在线作业二(答案)-...
  6. CCF201903-4 消息传递接口(100分)【模拟】
  7. 如何将Safari中保存的密码导入Chrome ?
  8. VS编译时自动引用Debug|Release版本的dll
  9. Java后端开发技术选型
  10. 园林计算机制图在计算机上的应用,园林计算机制图
  11. 漫威超级英雄大全(二)
  12. python可以手眼定标吗_机器人无标定手眼协调
  13. 网络服务器未运行是什么原因是,Win7系统网络诊断提示诊断策略服务未运行怎么办?...
  14. 专精特新小巨人企业是什么
  15. librosa重采样和声道设置
  16. 史上最全的CTF保姆教程 从入门到入狱【带工具】
  17. [Python图像处理] 二十九.MoviePy视频编辑库实现抖音短视频剪切合并操作
  18. 神经网络基础学习小记
  19. Python基础----Socket编程规范及底层原理(三)---socketserver实现并发及底层原理
  20. 京津冀计算机考研院校2021与2022招生人数对比

热门文章

  1. git / 通过 ssh 与仓库通信
  2. java找到项目下的某个文件夹_servlet 得到 JavaWeb项目下某文件夹的路径
  3. php cros跨域处理,php_CORS 跨域
  4. list中抽出某一个字段的值_Java的stream代替List解决单线程等问题
  5. “在解决方案中的一个或多个项目由于以下原因未能加载 项目文件或网站已移动或重新命名,或者不在您的计算机上” 的解决办法...
  6. 微信支付宝 支付单文件操作
  7. chromedriver与chrome版本映射表
  8. Groovy 设计模式 -- 保镖模式
  9. BZOJ 4448 主席树+树链剖分(在线)
  10. php随机生成验证码代码