Unet是一种U型网络,分为左右两部分卷积,左边为下采样提取高维特征,右边为上采样并与左侧融合实现图像分割。这里使用TensorFlow实现Unet网络,实现对遥感影像的道路分割。

训练数据:

标签图像:

Unet实现:

import tensorflow as tf
import numpy as np
import cv2
import glob
import itertoolsclass UNet:def __init__(self,input_width,input_height,num_classes,train_images,train_instances,val_images,val_instances,epochs,lr,lr_decay,batch_size,save_path):self.input_width = input_widthself.input_height = input_heightself.num_classes = num_classesself.train_images = train_imagesself.train_instances = train_instancesself.val_images = val_imagesself.val_instances = val_instancesself.epochs = epochsself.lr = lrself.lr_decay = lr_decayself.batch_size = batch_sizeself.save_path = save_pathdef leftNetwork(self, inputs):x = tf.keras.layers.Conv2D(64, (3, 3), padding='valid', activation='relu')(inputs)o_1 = tf.keras.layers.Conv2D(64, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2, 2))(o_1)x = tf.keras.layers.Conv2D(128, (3, 3), padding='valid', activation='relu')(x)o_2 = tf.keras.layers.Conv2D(128, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(o_2)x = tf.keras.layers.Conv2D(256, (3, 3), padding='valid', activation='relu')(x)o_3 = tf.keras.layers.Conv2D(256, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(o_3)x = tf.keras.layers.Conv2D(512, (3, 3), padding='valid', activation='relu')(x)o_4 = tf.keras.layers.Conv2D(512, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(o_4)x = tf.keras.layers.Conv2D(1024, (3, 3), padding='valid', activation='relu')(x)o_5 = tf.keras.layers.Conv2D(1024, (3, 3), padding='valid', activation='relu')(x)return [o_1, o_2, o_3, o_4, o_5]def rightNetwork(self, inputs):c_1, c_2, c_3, c_4, o_5 = inputso_5 = tf.keras.layers.UpSampling2D((2, 2))(o_5)x = tf.keras.layers.concatenate([tf.keras.layers.Cropping2D(4)(c_4), o_5], axis=3)x = tf.keras.layers.Conv2D(512, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.Conv2D(512, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.UpSampling2D((2, 2))(x)x = tf.keras.layers.concatenate([tf.keras.layers.Cropping2D(16)(c_3), x], axis=3)x = tf.keras.layers.Conv2D(256, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.Conv2D(256, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.UpSampling2D((2, 2))(x)x = tf.keras.layers.concatenate([tf.keras.layers.Cropping2D(40)(c_2), x], axis=3)x = tf.keras.layers.Conv2D(128, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.Conv2D(128, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.UpSampling2D((2, 2))(x)x = tf.keras.layers.concatenate([tf.keras.layers.Cropping2D(88)(c_1), x], axis=3)x = tf.keras.layers.Conv2D(64, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.Conv2D(64, (3, 3), padding='valid', activation='relu')(x)x = tf.keras.layers.Conv2D(self.num_classes, (1, 1), padding='valid')(x)x = tf.keras.layers.Activation('softmax')(x)return xdef build_model(self):inputs = tf.keras.Input(shape=[self.input_height, self.input_width, 3])left_output = self.leftNetwork(inputs)right_output = self.rightNetwork(left_output)model = tf.keras.Model(inputs=inputs, outputs=right_output)return modeldef train(self):G_train = self.dataGenerator(model='training')G_eval = self.dataGenerator(model='validation')#model = self.build_model()model = tf.keras.models.load_model('model.h5')model.compile(optimizer=tf.keras.optimizers.Adam(self.lr, self.lr_decay),loss='categorical_crossentropy',metrics=['accuracy'])model.fit_generator(G_train, 5, validation_data=G_eval, validation_steps=5, epochs=self.epochs)model.save(self.save_path)def dataGenerator(self, model):if model == 'training':images = glob.glob(self.train_images + '*.jpg')images.sort()instances = glob.glob(self.train_instances + '*.png')instances.sort()zipped = itertools.cycle(zip(images, instances))while True:x_train = []y_train = []for _ in range(self.batch_size):img, seg = next(zipped)img = cv2.resize(cv2.imread(img, 1), (self.input_width, self.input_height)) / 255.0seg = tf.keras.utils.to_categorical(cv2.imread(seg, 0), self.num_classes)x_train.append(img)y_train.append(seg)yield np.array(x_train), np.array(y_train)if model == 'validation':images = glob.glob(self.val_images + '*.jpg')images.sort()instances = glob.glob(self.val_instances + '*.png')instances.sort()zipped = itertools.cycle(zip(images, instances))while True:x_eval = []y_eval = []for _ in range(self.batch_size):img, seg = next(zipped)img = cv2.resize(cv2.imread(img, 1), (self.input_width, self.input_height)) / 255.0seg = tf.keras.utils.to_categorical(cv2.imread(seg, 0), self.num_classes)x_eval.append(img)y_eval.append(seg)yield np.array(x_eval), np.array(y_eval)

训练脚本:

unet = UNet(input_width=572,input_height=572,num_classes=2,train_images='./datasets/train/images/',train_instances='./datasets/train/instances/',val_images='./datasets/validation/images/',val_instances='./datasets/validation/instances/',epochs=100,lr=0.0001,lr_decay=0.00001,batch_size=100,save_path='model.h5'
)unet.train()

这里仅分割道路和背景,属于二分类,输出矩阵形状为2*388*388,进行100轮训练后保存模型进行推理验证。

推理脚本:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2model = tf.keras.models.load_model('model.h5')img = '17.jpg'
img = cv2.resize(cv2.imread(img), (572, 572)) / 255.
img = np.expand_dims(img, 0)
pred = model.predict(img)
pred = np.argmax(pred[0], axis=-1)
pred[pred == 1] = 255
cv2.imwrite('result.jpg', pred)
plt.imshow(pred)
plt.show()

测试图像:

推理结果:

将推理结果与原始图像叠加显示:

import cv2img_path = '17.jpg'
result_path = 'result.jpg'
img = cv2.imread(img_path)
height, width = img.shape[:2]
result = cv2.imread(result_path)
result = cv2.resize(result, (height, width), cv2.INTER_LINEAR)
result = cv2.Canny(result, 0, 255)
for i in range(height):for j in range(width):if result[i][j] == 255:img[i][j] = [0, 0, 255]
cv2.imwrite('temp.jpg', result)
cv2.imwrite('out.jpg', img)

TensorFlow实现Unet遥感图像分割相关推荐

  1. 基于SegNet和UNet的遥感图像分割代码解读

    基于SegNet和UNet的遥感图像分割代码解读 目录 基于SegNet和UNet的遥感图像分割代码解读 前言 概述 代码框架 代码细节分析 划分数据集gen_dataset.py UNet模型训练u ...

  2. 毕业设计 U-Net遥感图像语义分割(源码+论文)

    文章目录 0 项目说明 1 研究目的 2 研究方法 3 研究结论 4 论文目录 5 项目源码 6 最后 0 项目说明 **基于 U-Net 网络的遥感图像语义分割 ** 提示:适合用于课程设计或毕业设 ...

  3. 用NVIDIA Tensor Cores和TensorFlow 2加速医学图像分割

    用NVIDIA Tensor Cores和TensorFlow 2加速医学图像分割 Accelerating Medical Image Segmentation with NVIDIA Tensor ...

  4. 使用U-Net 进行图像分割

    最近做病理AI的细胞计数问题,需要对图像中的各个细胞进行分类,若采用普通的CNN+普通图像分割,估计实现效果不佳.为了解决这个问题,大致有两种方案:目标检测 和 图像分割.目标检测的算法以Faster ...

  5. Python基于改进FCN&VGG的高分辨率遥感图像分割(完整源码&数据集&视频教程)

    1.高分辨率遥感图像分割效果展示: 2.数据集简介: 首先介绍一下数据,我们这次采用的数据集是CCF大数据比赛提供的数据(2015年中国南方某城市的高清遥感图像),这是一个小数据集,里面包含了5张带标 ...

  6. 遥感-Deep Covariance Alignment for Domain Adaptive Remote Sensing Image Segmentation域自适应遥感图像分割中深度协方差对齐

    Deep Covariance Alignment for Domain Adaptive Remote Sensing Image Segmentation域自适应遥感图像分割中的深度协方差对齐 0 ...

  7. OpenCV C++案例实战二十九《遥感图像分割》

    OpenCV C++案例实战二十九<遥感图像分割> 前言 一.准备数据 二.K-Means分类 三.效果显示 四.源码 总结 前言 本案例基于k-means机器学习算法进行遥感图像分割.主 ...

  8. 【Matlab/CV系列】基于K-means/分水岭分割的多光谱遥感图像分割的Matlab实现

    Date:2022.4.18 文章目录 前言 1.初始界面 2.三种方法分割界面 3.光谱图 前言 在之前的时候,毕业设计中实现了基于K-means/分水岭/交叉熵分割的多光谱遥感图像分割算法,效果不 ...

  9. nnU-Net: 基于U-Net医学图像分割技术的自适应框架

    ** nnU-Net: 基于U-Net医学图像分割技术的自适应框架 ** https://arxiv.org/pdf/1809.10486.pdf 作者:Fabian Isensee 提要 U-Net ...

最新文章

  1. ImportError:cannot import name ‘display‘ File “XX“, line 5, in <module> from IPython import display
  2. 公司终于决定放弃微服务传统设计模式,全面拥抱 DDD!
  3. 利用jdom生成XML文件
  4. php编写TCP服务端和客户端程序
  5. C++ 预编译头文件
  6. Anaconda安装库
  7. 苹果罕见人事大调整:多个项目被迫暂停 员工“惊慌失措”
  8. Java中SimpleDateFormat用法详解
  9. linux计划任务与日志管理(日志分割/切割)
  10. 关于影响NodeManager执行MR任务constainer数量的设置问题
  11. 深入探索 Android 包体积优化(匠心制作)
  12. 【git】从零开始在git上部署自己的免费生日祝福网页
  13. linux 下载工具
  14. 工业控制系统漏洞检测技术(工控安全学习笔记)
  15. AARRR模型——激活:获客红海背后的蓝海(上)
  16. vue的五个小实例解析其基础功能
  17. 搜索引擎的目标是什么?
  18. 信号系统笔记(二)连续系统的时域分析
  19. java实现图片反色
  20. MTK 虚拟 sensor bring up (pick up) sensor1.0

热门文章

  1. Codeup墓地-问题 A: 还是畅通工程
  2. TreeMap方法源码
  3. 数据结构题:根据所给权值设计相应的哈夫曼树,并设计哈夫曼编码
  4. php module类,总结php artisan module常用命令
  5. mysql 5.6.23 源码包安装报错_CentOS6.5_64bit下编译安装MySQL-5.6.23
  6. 计算机本地用户删除后怎么恢复,电脑本地磁盘盘符被隐藏C盘不见了恢复方法...
  7. java不使用除号实现除法运算_LeetCode29 Medium 不用除号实现快速除法
  8. mac回退jdk版本_mac中不同jdk版本切换
  9. android7.1 shotcuts,Android N App Shotcuts 学习
  10. 单元格自适应宽度_Excel如何对表格进行自适应设置,方法很简单