一、U-Net简介

U-Net 是最初为医学影像分割而提出的一种语义分割技术。 它是较早的深度学习分割模型之一,U-Net 架构也用于许多 GAN 变体,例如 Pix2Pix 生成器。

U-Net 在论文 U-Net: Convolutional Networks for Biomedical Image Segmentation 中进行了介绍。 模型架构相当简单:一个编码器(用于下采样)和一个解码器(用于上采样),带有跳跃连接。 如图 1 所示,它的形状像字母 U,因此得名 U-Net。

二、数据集说明

我们将使用作为 TensorFlow 数据集 (TFDS) 的一部分提供的 Oxford-IIIT 宠物数据集。 它可以很容易地用 TFDS 加载,然后进行一些数据预处理,为训练分割模型做好准备。

可以使用 tfds 通过指定数据集的名称来加载数据集,并通过设置 with_info=True 来获取数据集信息:

代码如下,如果多次运行程序,第一次下载完之后可以添加download=False参数,会自动从已经下载好的文件夹下读取数据。

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

使用 print(info) 打印数据集信息,我们将看到牛津宠物数据集的各种详细信息。 例如,在图 2 中,我们可以看到共有 7349 张图像,其中包含内置的测试/训练拆分。

三、相关代码

1、unet模型

U-Net 的架构相当简单; 然而,为了在编码器和解码器之间创建跳跃连接,我们需要连接一些层。 所以 Keras 函数式 API 最适合这个目的。

首先,我们创建一个 build_unet_model 函数,指定输入、编码器层、瓶颈、解码器层,最后是带有激活 softmax 的 Conv2D 的输出层。 注意输入图像的形状是 128x128x3。 输出具有三个通道,对应于模型将为每个像素分类的三个类:背景、前景对象和对象轮廓。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np# 在编码器和U-Net的瓶颈中使用
def double_conv_block(x, n_filters):# Conv2D then ReLU activationx = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)# Conv2D then ReLU activationx = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)return x# 用于在编码器中进行下采样或特征提取
def downsample_block(x, n_filters):f = double_conv_block(x, n_filters)p = layers.MaxPool2D(2)(f)p = layers.Dropout(0.3)(p)return f, p# 上采样函数 upsample_block
def upsample_block(x, conv_features, n_filters):# upsamplex = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)# concatenatex = layers.concatenate([x, conv_features])# dropoutx = layers.Dropout(0.3)(x)# Conv2D twice with ReLU activationx = double_conv_block(x, n_filters)return x# 创建模型
def build_unet_model():# inputsinputs = layers.Input(shape=(128, 128, 3))# encoder: contracting path - downsample# 1 - downsamplef1, p1 = downsample_block(inputs, 64)# 2 - downsamplef2, p2 = downsample_block(p1, 128)# 3 - downsamplef3, p3 = downsample_block(p2, 256)# 4 - downsamplef4, p4 = downsample_block(p3, 512)# 5 - bottleneckbottleneck = double_conv_block(p4, 1024)# decoder: expanding path - upsample# 6 - upsampleu6 = upsample_block(bottleneck, f4, 512)# 7 - upsampleu7 = upsample_block(u6, f3, 256)# 8 - upsampleu8 = upsample_block(u7, f2, 128)# 9 - upsampleu9 = upsample_block(u8, f1, 64)# outputsoutputs = layers.Conv2D(3, 1, padding="same", activation="softmax")(u9)# unet model with Keras Functional APIunet_model = tf.keras.Model(inputs, outputs, name="U-Net")return unet_model

2、训练代码

运行train函数进行训练。

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True, download=False)
train_dataset = dataset["train"].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = dataset["test"].map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)BATCH_SIZE = 32
BUFFER_SIZE = 1000
train_batches = train_dataset.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_batches = train_batches.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
validation_batches = test_dataset.take(3000).batch(BATCH_SIZE)
test_batches = test_dataset.skip(3000).take(669).batch(BATCH_SIZE)def train():unet_model = build_unet_model()unet_model.compile(optimizer=tf.keras.optimizers.Adam(),loss="sparse_categorical_crossentropy",metrics="accuracy")NUM_EPOCHS = 40TRAIN_LENGTH = info.splits["train"].num_examplesSTEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZEVAL_SUBSPLITS = 5TEST_LENTH = info.splits["test"].num_examplesVALIDATION_STEPS = TEST_LENTH // BATCH_SIZE // VAL_SUBSPLITSmodel_history = unet_model.fit(train_batches,epochs=NUM_EPOCHS,steps_per_epoch=STEPS_PER_EPOCH,validation_steps=VALIDATION_STEPS,validation_data=test_batches)unet_model.save('unet.h5')

3、其它函数

# 修改大小
def resize(input_image, input_mask):input_image = tf.image.resize(input_image, (128, 128), method="nearest")input_mask = tf.image.resize(input_mask, (128, 128), method="nearest")return input_image, input_mask# 水平翻转
def augment(input_image, input_mask):if tf.random.uniform(()) > 0.5:# Random flipping of the image and maskinput_image = tf.image.flip_left_right(input_image)input_mask = tf.image.flip_left_right(input_mask)return input_image, input_mask# 规范化数据集
def normalize(input_image, input_mask):input_image = tf.cast(input_image, tf.float32) / 255.0input_mask -= 1return input_image, input_mask# 加载训练集
def load_image_train(datapoint):input_image = datapoint["image"]input_mask = datapoint["segmentation_mask"]input_image, input_mask = resize(input_image, input_mask)input_image, input_mask = augment(input_image, input_mask)input_image, input_mask = normalize(input_image, input_mask)return input_image, input_mask# 加载测试集
def load_image_test(datapoint):input_image = datapoint["image"]input_mask = datapoint["segmentation_mask"]input_image, input_mask = resize(input_image, input_mask)input_image, input_mask = normalize(input_image, input_mask)return input_image, input_mask# 创建mask
def create_mask(pred_mask):pred_mask = tf.argmax(pred_mask, axis=-1)pred_mask = pred_mask[..., tf.newaxis]return pred_mask[0]# 显示预测结果
def show_predictions(dataset=None, num=1, unet_model=None):if dataset:for image, mask in dataset.take(num):pred_mask = unet_model.predict(image)display([image[0], mask[0], create_mask(pred_mask)])else:display([sample_image, sample_mask, create_mask(unet_model.predict(sample_image[tf.newaxis, ...]))])# 可视化
def display(display_list):plt.figure(figsize=(15, 15))title = ["Input Image", "True Mask", "Predicted Mask"]for i in range(len(display_list)):plt.subplot(1, len(display_list), i+1)plt.title(title[i])plt.imshow(tf.keras.utils.array_to_img(display_list[i]))plt.axis("off")plt.show()# sample_batch = next(iter(train_batches))
# random_index = np.random.choice(sample_batch[0].shape[0])
# sample_image, sample_mask = sample_batch[0][random_index], sample_batch[1][random_index]
# display([sample_image, sample_mask])

4、调用模型进行测试

加载训练好的模型,调用上面的函数,可以进行测试,测试结果如下图

model = load_model('unet.h5')
show_predictions(test_batches, 1, model)

四、其他参考

机器学习笔记 - Keras + TensorFlow2.0 + Unet进行语义分割_bashendixie5的博客-CSDN博客https://blog.csdn.net/bashendixie5/article/details/115795171

机器学习笔记 - 使用Keras + Unet 进行图像分割相关推荐

  1. 机器学习笔记: Upsampling, U-Net, Pyramid Scene Parsing Net

    前言 在CNN-based 的 模型中,我们可能会用到downsampling 操作来减少模型参数,以及扩大感受野的效果. 下图是一个graph segmentation的例子,就先使用 downsa ...

  2. 机器学习笔记 - 使用Keras Tuner进行自动化超参数调整

    一.什么是超参数调整? 在机器学习工作流程中,您已经根据对数据集的先验分析为模型选择或提取了特征和目标 - 可能使用了PCA等降维技术. 训练机器学习模型时会进行如下迭代: 在训练开始之前,以随机或几 ...

  3. 机器学习笔记 - 使用Keras和深度学习进行乳腺癌分类

    一.数据集简介 乳腺组织病理学图像 浸润性导管癌 (IDC) 是所有乳腺癌中最常见的亚型. 为了给整个样本分配侵袭性等级,病理学家通常关注包含 IDC 的区域. 因此,自动侵略性分级的常见预处理步骤之 ...

  4. 机器学习笔记:auto encoder

    1 autoencoder 介绍 这是一个无监督学习问题,旨在从原始数据x中学习一个低维的特征向量(没有任何标签) encoder 最早是用线性函数+非线性单元构成(比如Linear+nonlinea ...

  5. 李弘毅机器学习笔记:第十一章—Keras Demo

    李弘毅机器学习笔记:第十一章-Keras Demo 创建网络 配置 选择最好的方程 使用模型 创建网络 假设我们要做的事情是手写数字辨识,那我们要建一个Network scratch,input是28 ...

  6. Python机器学习笔记:使用Keras进行回归预测

    Keras是一个深度学习库,包含高效的数字库Theano和TensorFlow.是一个高度模块化的神经网络库,支持CPU和GPU. 本文学习的目的是学习如何加载CSV文件并使其可供Keras使用,如何 ...

  7. 医学图像笔记(四)医学图像分割

    医学图像笔记(四)医学图像分割 1.医学图像分割的开源工具 2.其他分割 2.1.3D VNet 2.2.PE-VNet 3.医学图像数据集 3.1.百度AI studio 数据集 3.2.Githu ...

  8. 使用U-Net 进行图像分割

    最近做病理AI的细胞计数问题,需要对图像中的各个细胞进行分类,若采用普通的CNN+普通图像分割,估计实现效果不佳.为了解决这个问题,大致有两种方案:目标检测 和 图像分割.目标检测的算法以Faster ...

  9. 【学习打卡03】可解释机器学习笔记之CAM类激活热力图

    可解释机器学习笔记之CAM类激活热力图 文章目录 可解释机器学习笔记之CAM类激活热力图 CAM介绍 CAM算法原理 GAP全局平均池化 GAP VS GMP CAM算法的缺点及改进 CAM可视化 同 ...

  10. 李弘毅机器学习笔记:第十三章—CNN

    李弘毅机器学习笔记:第十三章-CNN 为什么用CNN Small region Same Patterns Subsampling CNN架构 Convolution Propetry1 Propet ...

最新文章

  1. java 数据库数据脱敏_Sharding-JDBC-数据脱敏
  2. DataGrid中添加背景
  3. 使用selenium进行密码破解(绕过账号密码JS加密)
  4. python pandas dataframe 排序,如何按两列或更多列对python pandas中的dataFrame进行排序?...
  5. 网上科学计算机,【图片】计算机-科普—都是从网上找的【计算机科学与技术吧】_百度贴吧...
  6. Zookeeper集群详解
  7. java留言板功能齐全源码_各类Java微信开发框架源码对比(建议收藏)
  8. ubuntu 安装 swift 64位
  9. bat处理中的管道[|]
  10. [转载] python怎么将十进制转换为二进制_python十进制和二进制的转换方法(含浮点数)
  11. 谜题40:不情愿的构造器
  12. 这就是搜索引擎--读书笔记四--索引基础
  13. 华为笔试题大全(史上最齐全)
  14. 豆瓣fm android,豆瓣 FM
  15. C#+Halcon调用Basler相机
  16. def demo什么意思python_你知道Python的所有入门级知识吗?,这些,都,会,了
  17. mysql common是什么_MySQL概述及入门(一)
  18. 重看经典动漫《火影忍者》的一些感受
  19. 通过 ICMP 协议实现 Ping Tunnel 建立可穿透网络隧道
  20. 关键链项目管理汇总贴

热门文章

  1. Unity3D利用代码生成脚本模板
  2. oracle数据库服务器名称修改,oracle数据库服务器名称修改
  3. 国内稳定的暗黑2服务器,国内暗黑2战网的基本概念介绍
  4. Android 高德地图No implementation found for long com.autonavi.amap.mapcore.MapCore
  5. unity shader可视化工具——Shader Graph
  6. linux手机刷机包制作工具_安卓10刷机包
  7. 怎么把照片做成计算机主题,windows10主题制作怎么操作_windows10电脑主题如何自己制作...
  8. 怎样用计算机制作思维导图,如何使用电脑制作成思维导图,这个方法简单又实在...
  9. SpeedFan 控制风扇转速
  10. 批处理(BAT)教程