This article covers four parts: dataset creation, training, inference, and test images with results.

Contents

Dataset creation

Training

Inference

Test images and results

Dataset creation

The dataset contains 420 images of size 224*400, drawn with a paint tool. There are three classes (circle, rectangle, and background), each filled with a different color. Some of the training images and their label images are shown in the figure below.
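The original images and labels were drawn by hand in a paint tool. Purely as an illustration of what an image/label pair looks like, the following Pillow sketch (my own assumption, with hypothetical file names, not the actual workflow) draws one grayscale training image and its color-filled label using the same class colors as colorDict in the code below ([0, 0, 0] background, [34, 177, 76] circle, [237, 28, 36] rectangle):

from PIL import Image, ImageDraw

def make_sample(path_img="sample.bmp", path_label="sample_label.bmp"):
    # training image: grayscale shapes on a black 400x224 (width x height) canvas
    img = Image.new("L", (400, 224), 0)
    d = ImageDraw.Draw(img)
    d.ellipse((40, 40, 140, 140), fill=180)       # a circle
    d.rectangle((220, 60, 340, 160), fill=90)     # a rectangle
    img.save(path_img)

    # label image: the same shapes filled with the class colors
    label = Image.new("RGB", (400, 224), (0, 0, 0))
    dl = ImageDraw.Draw(label)
    dl.ellipse((40, 40, 140, 140), fill=(34, 177, 76))
    dl.rectangle((220, 60, 340, 160), fill=(237, 28, 36))
    label.save(path_label)

make_sample()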

Training

Based on the fill colors, each label image is converted into a rows*cols*class_nums one-hot array.

from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img, save_img
import numpy as np
import os

colorDict = [[0, 0, 0], [34, 177, 76], [237, 28, 36]]  # fill colors for background, circle, rectangle
colorDict_RGB = np.array(colorDict)
colorDict_GRAY = colorDict_RGB[:, 0]
num_classes = 3

def data_preprocess(label, class_num):
    for i in range(colorDict_GRAY.shape[0]):
        label[label == colorDict_GRAY[i]] = i
    new_label = np.zeros(label.shape + (class_num,))
    for i in range(class_num):
        new_label[label == i, i] = 1
    label = new_label
    return label

def visual(array):
    for j in range(num_classes):
        vis = array[:, :, j]
        vis = vis * 255
        vis = vis.reshape(224, 400, 1)
        vis_out = array_to_img(vis)
        vis_out.show()

class dataProcess(object):
    def __init__(self, out_rows, out_cols, data_path="../train1", label_path="../label1",
                 test_path="../test1", npy_path="../npydata", img_type="bmp"):
        # data-processing class: initialization
        self.out_rows = out_rows
        self.out_cols = out_cols
        self.data_path = data_path
        self.label_path = label_path
        self.img_type = img_type
        self.test_path = test_path
        self.npy_path = npy_path
        self.num_classes = num_classes

    # create training data
    def create_train_data(self):
        print('-' * 30)
        print('Creating training images...')
        print('-' * 30)
        img_list = os.listdir(self.data_path)
        imgdatas = np.ndarray((len(img_list), self.out_rows, self.out_cols, 1), dtype=np.uint8)
        imglabels = np.ndarray((len(img_list), self.out_rows, self.out_cols, self.num_classes), dtype=np.uint8)
        for i in range(len(img_list)):
            img = load_img(self.data_path + "/" + img_list[i], color_mode="grayscale")
            img = img_to_array(img)
            imgdatas[i] = img
            label = load_img(self.label_path + "/" + img_list[i])
            label = img_to_array(label)[:, :, 0]
            label = data_preprocess(label, num_classes)
            # visual(label)
            imglabels[i] = label
        np.save(self.npy_path + '/imgs_train.npy', imgdatas)
        np.save(self.npy_path + '/imgs_mask_train.npy', imglabels)
        print('Saving to .npy files done.')

    def load_train_data(self):
        print('-' * 30)
        print('load train images...')
        print('-' * 30)
        imgs_train = np.load(self.npy_path + "/imgs_train.npy")
        imgs_mask_train = np.load(self.npy_path + "/imgs_mask_train.npy")
        imgs_train = imgs_train.astype('float32')
        imgs_mask_train = imgs_mask_train.astype('float32')
        imgs_train /= 255.0
        imgs_mask_train /= 255.0
        return imgs_train, imgs_mask_train

    def create_test_data(self):
        test_list = []
        print('-' * 30)
        print('Creating test images...')
        print('-' * 30)
        img_list = os.listdir(self.test_path)
        testdatas = np.ndarray((len(img_list), self.out_rows, self.out_cols, 1), dtype=np.uint8)
        for i in range(len(img_list)):
            img = load_img(self.test_path + "/" + img_list[i], color_mode="grayscale")
            img = img_to_array(img)
            testdatas[i] = img
            test_list.append(img_list[i])
        np.save(self.npy_path + '/imgs_test.npy', testdatas)
        print('Saving to .npy files done.')
        return test_list

    def load_test_data(self):
        print('-' * 30)
        print('load test images...')
        print('-' * 30)
        imgs_test = np.load(self.npy_path + "/imgs_test.npy")
        imgs_test = imgs_test.astype('float32')
        imgs_test /= 255.0
        return imgs_test

if __name__ == "__main__":
    mydata = dataProcess(224, 400)
    mydata.create_train_data()
    imgs_train, imgs_mask_train = mydata.load_train_data()
    print(imgs_train.shape, imgs_mask_train.shape)
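As a quick sanity check of data_preprocess, a toy label array (my own example, not from the dataset) with the grayscale fill values 0/34/237 maps to one-hot vectors as expected:

import numpy as np

toy = np.array([[0, 34, 237],
                [0,  0,  34]], dtype=np.float32)
one_hot = data_preprocess(toy, 3)
print(one_hot.shape)   # (2, 3, 3): rows x cols x class_nums
print(one_hot[0, 1])   # [0. 1. 0.] -> this pixel belongs to the circle class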

In conv10, change the number of output channels to class_nums, change the activation to softmax, and change the loss function to 'categorical_crossentropy'.

import numpy as np
from keras.models import *
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint
from keras import backend as keras
from my_test.data import *
from keras.models import Model

class myUnet(object):
    def __init__(self, img_rows = 224, img_cols = 400):
        self.img_rows = img_rows
        self.img_cols = img_cols

    def load_data(self):
        mydata = dataProcess(self.img_rows, self.img_cols)
        imgs_train, imgs_mask_train = mydata.load_train_data()
        return imgs_train, imgs_mask_train

    def get_unet(self):
        inputs = Input((self.img_rows, self.img_cols, 1))
        # network architecture definition
        '''
        # unet with crop (because padding = valid)
        conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(inputs)
        print "conv1 shape:", conv1.shape
        conv1 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv1)
        print "conv1 shape:", conv1.shape
        crop1 = Cropping2D(cropping=((90,90),(90,90)))(conv1)
        print "crop1 shape:", crop1.shape
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
        print "pool1 shape:", pool1.shape
        conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool1)
        print "conv2 shape:", conv2.shape
        conv2 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv2)
        print "conv2 shape:", conv2.shape
        crop2 = Cropping2D(cropping=((41,41),(41,41)))(conv2)
        print "crop2 shape:", crop2.shape
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
        print "pool2 shape:", pool2.shape
        conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool2)
        print "conv3 shape:", conv3.shape
        conv3 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv3)
        print "conv3 shape:", conv3.shape
        crop3 = Cropping2D(cropping=((16,17),(16,17)))(conv3)
        print "crop3 shape:", crop3.shape
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        print "pool3 shape:", pool3.shape
        conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool3)
        conv4 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv4)
        drop4 = Dropout(0.5)(conv4)
        crop4 = Cropping2D(cropping=((4,4),(4,4)))(drop4)
        pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
        conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(pool4)
        conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv5)
        drop5 = Dropout(0.5)(conv5)
        up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
        merge6 = merge([crop4, up6], mode = 'concat', concat_axis = 3)
        conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge6)
        conv6 = Conv2D(512, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv6)
        up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
        merge7 = merge([crop3, up7], mode = 'concat', concat_axis = 3)
        conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge7)
        conv7 = Conv2D(256, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv7)
        up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
        merge8 = merge([crop2, up8], mode = 'concat', concat_axis = 3)
        conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge8)
        conv8 = Conv2D(128, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv8)
        up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
        merge9 = merge([crop1, up9], mode = 'concat', concat_axis = 3)
        conv9 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(merge9)
        conv9 = Conv2D(64, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv9)
        conv9 = Conv2D(2, 3, activation = 'relu', padding = 'valid', kernel_initializer = 'he_normal')(conv9)
        '''
        conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
        conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
        pool1 = MaxPooling2D((2, 2))(conv1)
        conv2 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
        conv2 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
        pool2 = MaxPooling2D((2, 2))(conv2)
        conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
        conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
        conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
        conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
        drop4 = Dropout(0.5)(conv4)
        pool4 = MaxPooling2D((2, 2))(drop4)
        conv5 = Conv2D(1024, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
        conv5 = Conv2D(1024, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
        drop5 = Dropout(0.5)(conv5)
        up6 = Conv2D(512, (2, 2), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2, 2))(drop5))
        merge6 = concatenate([drop4, up6], axis = 3)
        conv6 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
        conv6 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
        up7 = Conv2D(256, (2, 2), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2, 2))(conv6))
        merge7 = concatenate([conv3, up7], axis = 3)
        conv7 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
        conv7 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
        up8 = Conv2D(128, (2, 2), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2, 2))(conv7))
        merge8 = concatenate([conv2, up8], axis = 3)
        conv8 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
        conv8 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
        up9 = Conv2D(64, (2, 2), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2, 2))(conv8))
        merge9 = concatenate([conv1, up9], axis = 3)
        conv9 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
        conv9 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
        conv9 = Conv2D(2, (3, 3), activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
        # conv10 = Conv2D(1, (1, 1), activation = 'sigmoid')(conv9)
        conv10 = Conv2D(class_nums, (1, 1), activation = 'softmax')(conv9)
        model = Model(inputs = inputs, outputs = conv10)
        model.summary()
        # model.compile(optimizer = Adam(lr = 1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])
        model.compile(optimizer = Adam(lr = 1e-4), loss = 'categorical_crossentropy', metrics = ['accuracy'])
        return model

    def train(self):
        print("loading data")
        imgs_train, imgs_mask_train = self.load_data()
        print("loading data done")
        model = self.get_unet()
        print("got unet")
        model_checkpoint = ModelCheckpoint('my_unet.hdf5', monitor='loss', verbose=1, save_best_only=True)
        print('Fitting model...')
        model.fit(imgs_train, imgs_mask_train, batch_size=4, epochs=10, verbose=1,
                  validation_split=0.2, shuffle=True, callbacks=[model_checkpoint])

if __name__ == '__main__':
    class_nums = 3
    myunet = myUnet()
    myunet.train()
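As a small check that the modified head behaves as described, the output shape can be printed (this snippet is my own addition and assumes it is run in the same module as the training script, so that class_nums and myUnet are in scope):

class_nums = 3
net = myUnet()
model = net.get_unet()
print(model.output_shape)  # expected (None, 224, 400, 3), i.e. rows x cols x class_nums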

Inference

Split the inference result into 3 single-channel images and view each channel separately.

from my_test.data import *
import numpy as np
from keras.models import load_model
from keras.preprocessing.image import array_to_img

def save_img(test_list):
    print("array to image")
    imgs = np.load('../11/imgs_mask_test.npy')
    for i in range(imgs.shape[0]):
        img = imgs[i]
        for j in range(class_num):
            out = img[:, :, j]
            out = out.reshape(224, 400, 1)
            out = array_to_img(out)
            out.save("../11/" + str(j) + '_' + test_list[i])

unet_model_path = 'my_unet.hdf5'
model = load_model(unet_model_path)
class_num = 3
mydata = dataProcess(224, 400)
test_list = mydata.create_test_data()  # build imgs_test.npy first, then load it
imgs_test = mydata.load_test_data()
imgs_mask_test = model.predict(imgs_test, batch_size=1, verbose=1)
np.save('../11/imgs_mask_test.npy', imgs_mask_test)
save_img(test_list)

Test images and results

The network input is rows*cols*1 and the output is rows*cols*class_nums. During data preprocessing, the background is set as the mask region in channel 0, the circle in channel 1, and the rectangle in channel 2. Splitting the three output channels therefore gives: channel 0 is the segmentation result for the background, channel 1 for the circle, and channel 2 for the rectangle.
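If a single color-coded result image is preferred over three separate channel images, one possible post-processing step (my own sketch, not part of the scripts above; the output file name merged_0.bmp is hypothetical) is to take the per-pixel argmax over the class axis and map each class index back to its fill color from colorDict:

import numpy as np
from keras.preprocessing.image import array_to_img

colorDict = [[0, 0, 0], [34, 177, 76], [237, 28, 36]]  # background, circle, rectangle

def colorize(mask):                       # mask: rows x cols x class_nums
    class_map = np.argmax(mask, axis=-1)  # rows x cols, values 0/1/2
    return np.array(colorDict, dtype=np.uint8)[class_map]  # rows x cols x 3

imgs = np.load('../11/imgs_mask_test.npy')
array_to_img(colorize(imgs[0])).save('../11/merged_0.bmp')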
