机器学习笔记 - 使用Keras + Unet 进行图像分割
一、U-Net简介
U-Net 是最初为医学影像分割而提出的一种语义分割技术。 它是较早的深度学习分割模型之一,U-Net 架构也用于许多 GAN 变体,例如 Pix2Pix 生成器。
U-Net 在论文 U-Net: Convolutional Networks for Biomedical Image Segmentation 中进行了介绍。 模型架构相当简单:一个编码器(用于下采样)和一个解码器(用于上采样),带有跳跃连接。 如图 1 所示,它的形状像字母 U,因此得名 U-Net。
二、数据集说明
我们将使用作为 TensorFlow 数据集 (TFDS) 的一部分提供的 Oxford-IIIT 宠物数据集。 它可以很容易地用 TFDS 加载,然后进行一些数据预处理,为训练分割模型做好准备。
可以使用 tfds 通过指定数据集的名称来加载数据集,并通过设置 with_info=True 来获取数据集信息:
代码如下,如果多次运行程序,第一次下载完之后可以添加download=False参数,会自动从已经下载好的文件夹下读取数据。
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
使用 print(info) 打印数据集信息,我们将看到牛津宠物数据集的各种详细信息。 例如,在图 2 中,我们可以看到共有 7349 张图像,其中包含内置的测试/训练拆分。
三、相关代码
1、unet模型
U-Net 的架构相当简单; 然而,为了在编码器和解码器之间创建跳跃连接,我们需要连接一些层。 所以 Keras 函数式 API 最适合这个目的。
首先,我们创建一个 build_unet_model 函数,指定输入、编码器层、瓶颈、解码器层,最后是带有激活 softmax 的 Conv2D 的输出层。 注意输入图像的形状是 128x128x3。 输出具有三个通道,对应于模型将为每个像素分类的三个类:背景、前景对象和对象轮廓。
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np# 在编码器和U-Net的瓶颈中使用
def double_conv_block(x, n_filters):# Conv2D then ReLU activationx = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)# Conv2D then ReLU activationx = layers.Conv2D(n_filters, 3, padding = "same", activation = "relu", kernel_initializer = "he_normal")(x)return x# 用于在编码器中进行下采样或特征提取
def downsample_block(x, n_filters):f = double_conv_block(x, n_filters)p = layers.MaxPool2D(2)(f)p = layers.Dropout(0.3)(p)return f, p# 上采样函数 upsample_block
def upsample_block(x, conv_features, n_filters):# upsamplex = layers.Conv2DTranspose(n_filters, 3, 2, padding="same")(x)# concatenatex = layers.concatenate([x, conv_features])# dropoutx = layers.Dropout(0.3)(x)# Conv2D twice with ReLU activationx = double_conv_block(x, n_filters)return x# 创建模型
def build_unet_model():# inputsinputs = layers.Input(shape=(128, 128, 3))# encoder: contracting path - downsample# 1 - downsamplef1, p1 = downsample_block(inputs, 64)# 2 - downsamplef2, p2 = downsample_block(p1, 128)# 3 - downsamplef3, p3 = downsample_block(p2, 256)# 4 - downsamplef4, p4 = downsample_block(p3, 512)# 5 - bottleneckbottleneck = double_conv_block(p4, 1024)# decoder: expanding path - upsample# 6 - upsampleu6 = upsample_block(bottleneck, f4, 512)# 7 - upsampleu7 = upsample_block(u6, f3, 256)# 8 - upsampleu8 = upsample_block(u7, f2, 128)# 9 - upsampleu9 = upsample_block(u8, f1, 64)# outputsoutputs = layers.Conv2D(3, 1, padding="same", activation="softmax")(u9)# unet model with Keras Functional APIunet_model = tf.keras.Model(inputs, outputs, name="U-Net")return unet_model
2、训练代码
运行train函数进行训练。
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True, download=False)
train_dataset = dataset["train"].map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = dataset["test"].map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)BATCH_SIZE = 32
BUFFER_SIZE = 1000
train_batches = train_dataset.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_batches = train_batches.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
validation_batches = test_dataset.take(3000).batch(BATCH_SIZE)
test_batches = test_dataset.skip(3000).take(669).batch(BATCH_SIZE)def train():unet_model = build_unet_model()unet_model.compile(optimizer=tf.keras.optimizers.Adam(),loss="sparse_categorical_crossentropy",metrics="accuracy")NUM_EPOCHS = 40TRAIN_LENGTH = info.splits["train"].num_examplesSTEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZEVAL_SUBSPLITS = 5TEST_LENTH = info.splits["test"].num_examplesVALIDATION_STEPS = TEST_LENTH // BATCH_SIZE // VAL_SUBSPLITSmodel_history = unet_model.fit(train_batches,epochs=NUM_EPOCHS,steps_per_epoch=STEPS_PER_EPOCH,validation_steps=VALIDATION_STEPS,validation_data=test_batches)unet_model.save('unet.h5')
3、其它函数
# 修改大小
def resize(input_image, input_mask):input_image = tf.image.resize(input_image, (128, 128), method="nearest")input_mask = tf.image.resize(input_mask, (128, 128), method="nearest")return input_image, input_mask# 水平翻转
def augment(input_image, input_mask):if tf.random.uniform(()) > 0.5:# Random flipping of the image and maskinput_image = tf.image.flip_left_right(input_image)input_mask = tf.image.flip_left_right(input_mask)return input_image, input_mask# 规范化数据集
def normalize(input_image, input_mask):input_image = tf.cast(input_image, tf.float32) / 255.0input_mask -= 1return input_image, input_mask# 加载训练集
def load_image_train(datapoint):input_image = datapoint["image"]input_mask = datapoint["segmentation_mask"]input_image, input_mask = resize(input_image, input_mask)input_image, input_mask = augment(input_image, input_mask)input_image, input_mask = normalize(input_image, input_mask)return input_image, input_mask# 加载测试集
def load_image_test(datapoint):input_image = datapoint["image"]input_mask = datapoint["segmentation_mask"]input_image, input_mask = resize(input_image, input_mask)input_image, input_mask = normalize(input_image, input_mask)return input_image, input_mask# 创建mask
def create_mask(pred_mask):pred_mask = tf.argmax(pred_mask, axis=-1)pred_mask = pred_mask[..., tf.newaxis]return pred_mask[0]# 显示预测结果
def show_predictions(dataset=None, num=1, unet_model=None):if dataset:for image, mask in dataset.take(num):pred_mask = unet_model.predict(image)display([image[0], mask[0], create_mask(pred_mask)])else:display([sample_image, sample_mask, create_mask(unet_model.predict(sample_image[tf.newaxis, ...]))])# 可视化
def display(display_list):plt.figure(figsize=(15, 15))title = ["Input Image", "True Mask", "Predicted Mask"]for i in range(len(display_list)):plt.subplot(1, len(display_list), i+1)plt.title(title[i])plt.imshow(tf.keras.utils.array_to_img(display_list[i]))plt.axis("off")plt.show()# sample_batch = next(iter(train_batches))
# random_index = np.random.choice(sample_batch[0].shape[0])
# sample_image, sample_mask = sample_batch[0][random_index], sample_batch[1][random_index]
# display([sample_image, sample_mask])
4、调用模型进行测试
加载训练好的模型,调用上面的函数,可以进行测试,测试结果如下图
model = load_model('unet.h5')
show_predictions(test_batches, 1, model)
四、其他参考
机器学习笔记 - Keras + TensorFlow2.0 + Unet进行语义分割_bashendixie5的博客-CSDN博客https://blog.csdn.net/bashendixie5/article/details/115795171
机器学习笔记 - 使用Keras + Unet 进行图像分割相关推荐
- 机器学习笔记: Upsampling, U-Net, Pyramid Scene Parsing Net
前言 在CNN-based 的 模型中,我们可能会用到downsampling 操作来减少模型参数,以及扩大感受野的效果. 下图是一个graph segmentation的例子,就先使用 downsa ...
- 机器学习笔记 - 使用Keras Tuner进行自动化超参数调整
一.什么是超参数调整? 在机器学习工作流程中,您已经根据对数据集的先验分析为模型选择或提取了特征和目标 - 可能使用了PCA等降维技术. 训练机器学习模型时会进行如下迭代: 在训练开始之前,以随机或几 ...
- 机器学习笔记 - 使用Keras和深度学习进行乳腺癌分类
一.数据集简介 乳腺组织病理学图像 浸润性导管癌 (IDC) 是所有乳腺癌中最常见的亚型. 为了给整个样本分配侵袭性等级,病理学家通常关注包含 IDC 的区域. 因此,自动侵略性分级的常见预处理步骤之 ...
- 机器学习笔记:auto encoder
1 autoencoder 介绍 这是一个无监督学习问题,旨在从原始数据x中学习一个低维的特征向量(没有任何标签) encoder 最早是用线性函数+非线性单元构成(比如Linear+nonlinea ...
- 李弘毅机器学习笔记:第十一章—Keras Demo
李弘毅机器学习笔记:第十一章-Keras Demo 创建网络 配置 选择最好的方程 使用模型 创建网络 假设我们要做的事情是手写数字辨识,那我们要建一个Network scratch,input是28 ...
- Python机器学习笔记:使用Keras进行回归预测
Keras是一个深度学习库,包含高效的数字库Theano和TensorFlow.是一个高度模块化的神经网络库,支持CPU和GPU. 本文学习的目的是学习如何加载CSV文件并使其可供Keras使用,如何 ...
- 医学图像笔记(四)医学图像分割
医学图像笔记(四)医学图像分割 1.医学图像分割的开源工具 2.其他分割 2.1.3D VNet 2.2.PE-VNet 3.医学图像数据集 3.1.百度AI studio 数据集 3.2.Githu ...
- 使用U-Net 进行图像分割
最近做病理AI的细胞计数问题,需要对图像中的各个细胞进行分类,若采用普通的CNN+普通图像分割,估计实现效果不佳.为了解决这个问题,大致有两种方案:目标检测 和 图像分割.目标检测的算法以Faster ...
- 【学习打卡03】可解释机器学习笔记之CAM类激活热力图
可解释机器学习笔记之CAM类激活热力图 文章目录 可解释机器学习笔记之CAM类激活热力图 CAM介绍 CAM算法原理 GAP全局平均池化 GAP VS GMP CAM算法的缺点及改进 CAM可视化 同 ...
- 李弘毅机器学习笔记:第十三章—CNN
李弘毅机器学习笔记:第十三章-CNN 为什么用CNN Small region Same Patterns Subsampling CNN架构 Convolution Propetry1 Propet ...
最新文章
- java 数据库数据脱敏_Sharding-JDBC-数据脱敏
- DataGrid中添加背景
- 使用selenium进行密码破解(绕过账号密码JS加密)
- python pandas dataframe 排序,如何按两列或更多列对python pandas中的dataFrame进行排序?...
- 网上科学计算机,【图片】计算机-科普—都是从网上找的【计算机科学与技术吧】_百度贴吧...
- Zookeeper集群详解
- java留言板功能齐全源码_各类Java微信开发框架源码对比(建议收藏)
- ubuntu 安装 swift 64位
- bat处理中的管道[|]
- [转载] python怎么将十进制转换为二进制_python十进制和二进制的转换方法(含浮点数)
- 谜题40:不情愿的构造器
- 这就是搜索引擎--读书笔记四--索引基础
- 华为笔试题大全(史上最齐全)
- 豆瓣fm android,豆瓣 FM
- C#+Halcon调用Basler相机
- def demo什么意思python_你知道Python的所有入门级知识吗?,这些,都,会,了
- mysql common是什么_MySQL概述及入门(一)
- 重看经典动漫《火影忍者》的一些感受
- 通过 ICMP 协议实现 Ping Tunnel 建立可穿透网络隧道
- 关键链项目管理汇总贴
热门文章
- Unity3D利用代码生成脚本模板
- oracle数据库服务器名称修改,oracle数据库服务器名称修改
- 国内稳定的暗黑2服务器,国内暗黑2战网的基本概念介绍
- Android 高德地图No implementation found for long com.autonavi.amap.mapcore.MapCore
- unity shader可视化工具——Shader Graph
- linux手机刷机包制作工具_安卓10刷机包
- 怎么把照片做成计算机主题,windows10主题制作怎么操作_windows10电脑主题如何自己制作...
- 怎样用计算机制作思维导图,如何使用电脑制作成思维导图,这个方法简单又实在...
- SpeedFan 控制风扇转速
- 批处理(BAT)教程