1 UNet网络架构


UNet网络由左编码部分,右解码部分和下两个卷积+激活层组成

  1. 编码部分

    • 从图中可知:架构中是由4个重复结构组成:2个3x3卷积层,非线形ReLU层和一个stride为2的2x2 max pooling层(图中的蓝箭头,红箭头)
    • 每一次下采样特征通道的数量加倍
  2. 解码部分

    • 和编码层类似,反卷积也有4个重复结构组成
    • 每个重复结构前先使用反卷积,每次反卷积后特征通道数量减半,特征图大小加倍(绿箭头)
    • 反卷积之后,反卷积的结果和编码部分对应步骤的特征图拼接起来(白/蓝块)
    • 如果编码部分的特征图尺寸较大,需要进行裁剪后再拼接(左边深蓝色的虚线)
    • 拼接后的特征图再进行2次3x3的卷积(右侧蓝箭头)
    • 最后一层的卷积核为1x1 的卷积核,将64通道的特征图转化为特定类别数量(分类数量)的结果(青色箭头)

2 模型构建

2.1 数据集获取

from PIL import ImageOps
from tensorflow import keras
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D, Cropping2D, Concatenate
from tensorflow.keras.layers import Lambda, Activation, BatchNormalization, Dropout
from tensorflow.keras.models import Model
import random
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'# 图像位置
input_dir = 'segdata/images/'
# 图像路径
input_img_path = sorted([os.path.join(input_dir, fname)for fname in os.listdir(input_dir) if fname.endswith('.jpg')])
# 标注信息
target_dir = 'segdata/annotations/trimaps/'
# 标注路径
target_img_path = sorted(os.path.join(target_dir, fname) for fname in os.listdir(target_dir) if fname.endswith('.png') and not fname.startswith('.'))img_size = (160, 160)
batch_size = 32
num_classes = 4

  使用的数据集是Oxford-IIIT Pet Dataset宠物图像分割数据集,包含37种宠物类别,其中有12种猫的类别和25种狗的类别,每个类别大约有200张图片,所有图像都具有品种,头部ROI和像素级分割的标注。

2.2 构建数据集生成器

# 1、创建数据集生成器
class OxfordPets(keras.utils.Sequence):# 初始化def __init__(self, batch_size, img_size, input_img_paths, target_img_paths):# 批次大小self.batch_size = batch_size# 图像尺寸self.img_size = img_size# 输入图像的路径self.input_img_paths = input_img_paths# 目标值路径self.target_img_paths = target_img_paths# 迭代次数def __len__(self):return len(self.target_img_paths) // self.batch_size# 获取batch数据def __getitem__(self, idx):# 获取该批次对应的样本索引i = idx * self.batch_size# 获取该批次数据batch_input_img_paths = self.input_img_paths[i:i + self.batch_size]batch_target_img_paths = self.target_img_paths[i:i + self.batch_size]# 构建特征值x = np.zeros((batch_size,) + self.img_size + (3,), dtype='float32')for j, path in enumerate(batch_input_img_paths):img = load_img(path, target_size=self.img_size)x[j] = img# 构建目标值y = np.zeros((batch_size,) + self.img_size + (1,), dtype='uint8')for j, path in enumerate(batch_target_img_paths):img = load_img(path, target_size=self.img_size,color_mode='grayscale')y[j] = np.expand_dims(img, 2)return x, y

2.3 编码部分

编码部分的特点是:

  • 架构中是由4个重复结构组成:2个3x3卷积层,非线形ReLU层和一个stride为2的2x2 max pooling层

  • 每一次下采样后我们都把特征通道的数量加倍

  • 每次重复都有两个输出:一个用于编码部分进行特征提取,一个用于解码部分的特征融合

# 2、编码
def downsampling_block(input_tensor, filters):# 输入:input_tensor,通道数:filters# 卷积x = Conv2D(filters, kernel_size=3, padding='same')(input_tensor)# BN层x = BatchNormalization()(x)# 激活函数x = Activation('relu')(x)# 卷积层x = Conv2D(filters, kernel_size=3, padding='same')(x)# BN层x = BatchNormalization()(x)# 激活x = Activation('relu')(x)# 返回:池化后的值以及激活未池化的值,激活未池化的值用于解码部分特征合并return MaxPooling2D(pool_size=2)(x), x

2.4 解码部分

  • 和编码层类似,反卷积也有4个重复结构组成
  • 每个重复结构前先使用反卷积,每次反卷积后特征通道数量减半,特征图大小加倍(绿箭头)
  • 反卷积之后,反卷积的结果和编码部分对应步骤的特征图拼接起来(白/蓝块)
  • 如果编码部分的特征图尺寸较大,需要进行裁剪后再拼接(左边深蓝色的虚线)
  • 拼接后的特征图再进行2次3x3的卷积(右侧蓝箭头)
  • 最后一层的卷积核为1x1 的卷积核,将64通道的特征图转化为特定类别数量(分类数量)的结果(青色箭头)
# 3、解码
def upsampling_block(input_tensor, skip_tensor, filters):# input_tensor:输入特征层,skip_tensor:编码部分的特征图,filters:通道数# 反卷积x = Conv2DTranspose(filters, kernel_size=2, strides=2,padding='same')(input_tensor)# 获取反卷积后特征图尺寸_, x_height, x_width, _ = x.shape# 获取编码部分激活未池化特征图尺寸_, s_height, s_width, _ = skip_tensor.shape# 计算差异h_crop = s_height - x_heightw_crop = s_width - x_width# 判断是否进行裁剪if h_crop == 0 and w_crop == 0:y = skip_tensorelse:# 获取裁剪的大小cropping = ((h_crop // 2, h_crop - h_crop // 2),(w_crop // 2, w_crop - w_crop // 2))y = Cropping2D(cropping=cropping)(skip_tensor)# 特征融合x = Concatenate()([x, y])# 卷积x = Conv2D(filters, kernel_size=3, padding='same')(x)# BNx = BatchNormalization()(x)# 激活层x = Activation('relu')(x)# 卷积x = Conv2D(filters, kernel_size=2, padding='same')(x)# BNx = BatchNormalization()(x)# 激活层x = Activation('relu')(x)return x

2.5 模型构建


将编码部分和解码部分组合一起,就可构建UNet网络,在这里UNet网络的深度通过depth进行设置,并设置第一个编码模块的卷积核个数通过filter进行设置,通过以下模块将编码和解码部分进行组合:

# 4、unet网络
def unet(imagesize, classes, fetures=64, depth=3):# 定义输入inputs = keras.Input(shape=(imagesize + (3,)))x = inputs# 编码部分skips = []for i in range(depth):x, x0 = downsampling_block(x, fetures)skips.append(x0)fetures *= 2# 卷积x = Conv2D(filters=fetures, kernel_size=3, padding='same')(x)# BNx = BatchNormalization()(x)# 激活x = Activation('relu')(x)# 卷积x = Conv2D(filters=fetures, kernel_size=3, padding='same')(x)# 激活x = Activation('relu')(x)# 解码部分(调转顺序,将激活未池化特征值与反卷积层融合)for i in reversed(range(depth)):fetures //= 2# 输入,激活未池化特征值,通道数x = upsampling_block(x, skips[i], fetures)# 1x1卷积x = Conv2D(filters=classes, kernel_size=1, padding='same')(x)# 激活outputs = Activation('softmax')(x)return keras.Model(inputs=inputs, outputs=outputs)
# 实例化网络模型
model = unet(img_size, num_classes)

2.6 模型训练

2.6.1 数据集划分

数据集中的图像是按顺序进行存储的,在这里我们将数据集打乱后,验证集的数量1200,剩余的为训练集,划分训练集和验证集:

# 验证集数量设置
val_samples = 1200
# 打乱数据集(随机数种子设置)
random.Random(1).shuffle(input_img_path)
random.Random(1).shuffle(target_img_path)
# 划分数据集
# 训练集
train_input_img_paths = input_img_path[:-val_samples]
train_target_img_paths = target_img_path[:-val_samples]
# 验证集
val_input_img_paths = input_img_path[-val_samples:]
val_target_img_paths = target_img_path[-val_samples:]

2.6.2 数据集获取

train_gen = OxfordPets(batch_size, img_size, train_input_img_paths, train_target_img_paths)
val_gen = OxfordPets(batch_size, img_size, val_input_img_paths, val_target_img_paths)

2.6.3 模型编译和训练

model.compile(optimize='rmprop', loss='sparse_categorical_crossentropy')
model.fit(train_gen, epochs=10, validation_data=val_gen)

2.7模型预测

# 获取验证集数据,并进行预测
val_gen = OxfordPets(batch_size, img_size, val_input_img_paths, val_target_img_paths)
val_preds = model.predict(val_gen)
# 定义预测结果显示的方法
# 图像显示
def display_mask(i):# 获取到第i个样本的预测结果mask = np.argmax(val_preds[i], axis=-1)# 维度调整mask = np.expand_dims(mask, axis=-1)# 转换为图像,并进行显示img = PIL.ImageOps.autocontrast(keras.preprocessing.image.array_to_img(mask))display(img)
# 选择某一个图像进行预测
# 选中验证集的第10个图像
i = 10
# 原图像展示
# 输入图像显示
display(Image(filename=val_input_img_paths[i]))

# 目标值展示
# 真实值显示
img = PIL.ImageOps.autocontrast(load_img(val_target_img_paths[i]))
display(img)

# 模型预测结果
# 显示预测结果
display_mask(i)

深度学习:目标分割|UNet网络模型及案例实现相关推荐

  1. 深度学习语义分割理论与实战指南

    本文来自微信公众号[机器学习实验室] 深度学习语义分割理论与实战指南 1 语义分割概述 2 关键技术组件 2.1 编码器与分类网络 2.2 解码器与上采样 2.2.1 双线性插值(Bilinear I ...

  2. 姿态检测 树莓派_怎样在树莓派上轻松实现深度学习目标检测?

    原标题:怎样在树莓派上轻松实现深度学习目标检测? 雷锋网按:本文为 AI 研习社编译的技术博客,原标题 How to easily Detect Objects with Deep Learning ...

  3. 值得收藏!基于激光雷达数据的深度学习目标检测方法大合集(下)

    作者 | 黄浴 来源 | 转载自知乎专栏自动驾驶的挑战和发展 [导读]在近日发布的<值得收藏!基于激光雷达数据的深度学习目标检测方法大合集(上)>一文中,作者介绍了一部分各大公司和机构基于 ...

  4. 【NLP】博士笔记 | 深入理解深度学习语义分割

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自|机器学习初学者 引言:最近自动驾驶项目需要学习一些语义分 ...

  5. 笔记 | 深入理解深度学习语义分割

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:计算机视觉联盟 本文内容概述王博Kings最近的语义分割 ...

  6. 博士笔记 | 深入理解深度学习语义分割

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达本文转自|机器学习初学者 本文内容概述王博Kings最近的语义分割学 ...

  7. 【深度学习】深度学习语义分割理论与实战指南.pdf

    图像分类.目标检测和图像分割是基于深度学习的计算机视觉三大核心任务.三大任务之间明显存在着一种递进的层级关系,图像分类聚焦于整张图像,目标检测定位于图像具体区域,而图像分割则是细化到每一个像素. 基于 ...

  8. 深度学习语义分割理论与实战指南.pdf

    深度学习语义分割理论与实战指南 V1.0 版本已经完成,主要包括语义分割概述.关键技术组件.数据模块.经典分割网络与架构.PyTorch基本实战方法等五个部分. 获取方式: 扫描关注下方公众号回复 语 ...

  9. 深度学习目标检测 RCNN F-RCNN SPP yolo-v1 v2 v3 残差网络ResNet MobileNet SqueezeNet ShuffleNet

    深度学习目标检测--结构变化顺序是RCNN->SPP->Fast RCNN->Faster RCNN->YOLO->SSD->YOLO2->Mask RCNN ...

  10. 深度学习目标检测之RCNN、SPP-net、Fast RCNN、Faster RCNN

    一.目标检测介绍 目标检测(目标提取)是一种基于目标几何和统计特征的图像分割,将目标的分割和识别合二为一,主要是明确从图中看到了什么物体.它们分别在什么位置.传统的目标检测方法一般分为三个阶段:首先在 ...

最新文章

  1. 数据中心运营之4P标准化运维规程
  2. ASP.NET中过滤HTML字符串的两个方法
  3. DataTable的Merge\COPY\AcceptChange使用说明
  4. 自己的阿里云部署了django发现连不上, 一下是网上查的解决方法,亲测可用
  5. Atitit.web 视频播放器classid clsid 大总结quicktime,vlc 1. Classid的用处。用来指定播放器 1 2. object 标签用于包含对象,比如图像、音
  6. 符合 Qi 规范的移动设备无线充电解决方案
  7. 【算法】1436. 旅行终点站(多语言实现)
  8. curl DNS解析失败crash问题
  9. 生活大爆炸版石头剪刀布-简单模拟
  10. 虚拟内存设置调整图解
  11. 网络号和主机号具体计算原理-ipv4篇
  12. 【办公协作软件】万彩办公大师教程丨图片OCR工具的应用
  13. 2018-11-5-win10-uwp-异步转同步
  14. 【bzoj3698】【XWW的难题】【有上下界的网络流】
  15. PDF如何设置注释字体大小
  16. LTspice使用教程笔记
  17. BZOJ 3165 Heoi2013 Segment 线段树
  18. 车厢调度(4种方法)
  19. Ubuntu16.04 LTS安装友善之臂smart4418交叉编译器
  20. MFC MDI 遍历打开的所有文档

热门文章

  1. 数据凌乱,埋点差,难以归因?数据治理有妙招!
  2. 11.4.1 CURDATE()函数
  3. MySQL 导出 CSV 乱码
  4. python生成csv乱码
  5. redhat 7中配置与管理WEB服务器
  6. 19c RAC Duplicate方式静默安装ADG从库
  7. 打开word文档提示文件未找到_打开CAD图纸或文档提示缺少SHX文件,2850种CAD字体大全资源分享...
  8. 基于FPGA的浮点数运算
  9. 关于微信跳转H5页面,背景色显示灰色问题
  10. SkeyeVSS垃圾回收站视频智能分析系统助力垃圾分类 共享美好生活