Street-scene semantic segmentation with TensorFlow 2.3

The Cityscapes benchmark dataset ("urban scenes") was released in 2015 with backing from Daimler/Mercedes-Benz and is widely regarded as one of the most authoritative and professional image-segmentation datasets in machine vision. It provides semantic-level, instance-level, and dense pixel annotations for 30 classes in 8 groups (flat surfaces, humans, vehicles, constructions, objects, nature, sky, and void). Cityscapes contains 5,000 finely annotated images of driving scenes in urban environments (2,975 train, 500 val, 1,525 test), with dense pixel annotations for 19 evaluation classes (97% coverage), 8 of which also carry instance-level segmentation. The data was collected across 50 cities over several months, at different times of day and in good weather. It was originally recorded as video, so frames were manually selected with the following criteria: a large number of dynamic objects, varied scene layouts, and varied backgrounds.

Code

Import the packages

import os
import glob
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

Configure the visible GPU (for adaptive GPU memory allocation) and check the TensorFlow version

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
tf.__version__
  • '2.3.0'
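The section heading promises adaptive GPU memory allocation, but the line above only restricts which GPU is visible. A minimal sketch of what enabling memory growth looks like in TF 2.x, so TensorFlow claims GPU memory on demand instead of reserving it all; this must run before any GPU is initialized:

```python
import os
import tensorflow as tf

os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # same device selection as above

# Ask TensorFlow to grow GPU memory usage on demand instead of reserving it all.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
```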

Load the data: image files

img = glob.glob('./dataset/cityscapes/leftImg8bit/train/*/*.png')
print(len(img))
img[:5]
  • 2975
  • ['./dataset/cityscapes/leftImg8bit/train/dusseldorf/dusseldorf_000128_000019_leftImg8bit.png',
    './dataset/cityscapes/leftImg8bit/train/dusseldorf/dusseldorf_000113_000019_leftImg8bit.png',
    './dataset/cityscapes/leftImg8bit/train/dusseldorf/dusseldorf_000014_000019_leftImg8bit.png',
    './dataset/cityscapes/leftImg8bit/train/dusseldorf/dusseldorf_000207_000019_leftImg8bit.png',
    './dataset/cityscapes/leftImg8bit/train/dusseldorf/dusseldorf_000216_000019_leftImg8bit.png']

Label files

label = glob.glob('./dataset/cityscapes/gtFine/train/*/*_gtFine_labelIds.png')
print(len(label))
label[:5]
  • 2975
  • ['./dataset/cityscapes/gtFine/train/dusseldorf/dusseldorf_000015_000019_gtFine_labelIds.png',
    './dataset/cityscapes/gtFine/train/dusseldorf/dusseldorf_000213_000019_gtFine_labelIds.png',
    './dataset/cityscapes/gtFine/train/dusseldorf/dusseldorf_000164_000019_gtFine_labelIds.png',
    './dataset/cityscapes/gtFine/train/dusseldorf/dusseldorf_000050_000019_gtFine_labelIds.png',
    './dataset/cityscapes/gtFine/train/dusseldorf/dusseldorf_000072_000019_gtFine_labelIds.png']

To keep the image files and label files in one-to-one correspondence, sort both lists by filename.

img.sort(key=lambda x: x.split('/')[-1].split('.png')[0])
label.sort(key=lambda x: x.split('/')[-1].split('.png')[0])
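Why sorting both lists aligns them: within each list the suffix (`_leftImg8bit` vs `_gtFine_labelIds`) is constant, so the sort order is driven entirely by the shared city/sequence/frame prefix. A toy sketch with made-up filenames:

```python
# Hypothetical filenames: same city/frame prefixes, different fixed suffixes.
imgs = ['b_02_leftImg8bit.png', 'a_01_leftImg8bit.png']
labels = ['a_01_gtFine_labelIds.png', 'b_02_gtFine_labelIds.png']

# Same sort key as in the post: the basename without the .png extension.
imgs.sort(key=lambda x: x.split('/')[-1].split('.png')[0])
labels.sort(key=lambda x: x.split('/')[-1].split('.png')[0])

# After sorting, position i of each list refers to the same frame.
pairs = [(i.split('_left')[0], l.split('_gtFine')[0]) for i, l in zip(imgs, labels)]
print(pairs)  # [('a_01', 'a_01'), ('b_02', 'b_02')]
```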

Create a shuffled index

index = np.random.permutation(len(img))

Shuffle, then inspect the images and labels

img = np.array(img)[index]
label = np.array(label)[index]

After shuffling, the images and labels still correspond one-to-one.

img[:5]
  • array(['./dataset/cityscapes/leftImg8bit/train/stuttgart/stuttgart_000195_000019_leftImg8bit.png',
    './dataset/cityscapes/leftImg8bit/train/tubingen/tubingen_000047_000019_leftImg8bit.png',
    './dataset/cityscapes/leftImg8bit/train/monchengladbach/monchengladbach_000000_019682_leftImg8bit.png',
    './dataset/cityscapes/leftImg8bit/train/dusseldorf/dusseldorf_000075_000019_leftImg8bit.png',
    './dataset/cityscapes/leftImg8bit/train/monchengladbach/monchengladbach_000000_010733_leftImg8bit.png'], dtype='<U158')
label[:5]
  • array(['./dataset/cityscapes/gtFine/train/stuttgart/stuttgart_000195_000019_gtFine_labelIds.png',
    './dataset/cityscapes/gtFine/train/tubingen/tubingen_000047_000019_gtFine_labelIds.png',
    './dataset/cityscapes/gtFine/train/monchengladbach/monchengladbach_000000_019682_gtFine_labelIds.png',
    './dataset/cityscapes/gtFine/train/dusseldorf/dusseldorf_000075_000019_gtFine_labelIds.png',
    './dataset/cityscapes/gtFine/train/monchengladbach/monchengladbach_000000_010733_gtFine_labelIds.png'], dtype='<U157')
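Indexing both arrays with the same permutation is what preserves the pairing. A tiny check with hypothetical names (using a seeded generator instead of `np.random.permutation` so the sketch is reproducible):

```python
import numpy as np

imgs = np.array(['img_a', 'img_b', 'img_c'])
masks = np.array(['mask_a', 'mask_b', 'mask_c'])

rng = np.random.default_rng(0)
index = rng.permutation(len(imgs))

# Apply the SAME index to both arrays.
imgs, masks = imgs[index], masks[index]

# Whatever the new order, each image still sits next to its own mask.
for i, m in zip(imgs, masks):
    assert i.split('_')[1] == m.split('_')[1]
```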

Create the validation set

img_val = glob.glob('./dataset/cityscapes/leftImg8bit/val/*/*.png')
label_val = glob.glob('./dataset/cityscapes/gtFine/val/*/*_gtFine_labelIds.png')
len(img_val), len(label_val)
  • (500, 500)

Sort the validation images and labels by filename

img_val.sort(key=lambda x: x.split('/')[-1].split('.png')[0])
label_val.sort(key=lambda x: x.split('/')[-1].split('.png')[0])

Number of validation samples

val_count = len(img_val)
val_count
  • 500

Number of training samples

train_count = len(img)
train_count
  • 2975

Build the training dataset

dataset_train = tf.data.Dataset.from_tensor_slices((img, label))
dataset_train
  • <TensorSliceDataset shapes: ((), ()), types: (tf.string, tf.string)>

Build the validation dataset

dataset_val = tf.data.Dataset.from_tensor_slices((img_val, label_val))
dataset_val
  • <TensorSliceDataset shapes: ((), ()), types: (tf.string, tf.string)>

Wrap image loading and decoding in a function

def read_png(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=3)
    return img

Wrap label loading and decoding in a function

def read_png_label(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=1)
    return img

Test the data in the dataset

img_1 = read_png(img[0])
label_1 = read_png_label(label[0])
img_1.shape
label_1.shape
  • TensorShape([1024, 2048, 3])
  • TensorShape([1024, 2048, 1])
plt.imshow(img_1)

plt.imshow(label_1)


Data augmentation

concat = tf.concat([img_1, label_1], axis=-1)
concat.shape
  • TensorShape([1024, 2048, 4])

After stacking the image and label along the channel axis with tf.concat, the result has 3 + 1 = 4 channels.
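The channel arithmetic can be checked with plain NumPy, since concatenation along the last axis behaves the same way as `tf.concat`:

```python
import numpy as np

img = np.zeros((1024, 2048, 3))   # RGB image
mask = np.zeros((1024, 2048, 1))  # single-channel label

stacked = np.concatenate([img, mask], axis=-1)
print(stacked.shape)  # (1024, 2048, 4) — 3 image channels + 1 label channel
```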

Custom data-augmentation function

def crop_img(img, mask):
    concat_img = tf.concat([img, mask], axis=-1)  # stack the two images so they are cropped together
    concat_img = tf.image.resize(concat_img, (280, 280),
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    crop_img = tf.image.random_crop(concat_img, [256, 256, 4])  # random crop
    return crop_img[:, :, :3], crop_img[:, :, 3:]

The function returns the image and the label, split apart again by slicing the channel dimension.
Test it:

img_1, label_1 = crop_img(img_1, label_1)
img_1.shape, label_1.shape
  • (TensorShape([256, 256, 3]), TensorShape([256, 256, 1]))
plt.subplot(1,2,1)
plt.imshow(img_1.numpy())
plt.subplot(1,2,2)
plt.imshow(label_1.numpy())   # or: plt.imshow(np.squeeze(label_1.numpy()))
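The reason for stacking before `tf.image.random_crop` is that a single crop of the 4-channel tensor uses one random offset, so the image window and the mask window always match. A NumPy sketch of the same idea, drawing the shared offsets explicitly:

```python
import numpy as np

img = np.arange(10 * 10 * 3).reshape(10, 10, 3)
mask = np.arange(10 * 10 * 1).reshape(10, 10, 1)

stacked = np.concatenate([img, mask], axis=-1)  # (10, 10, 4)

# One random offset for the stacked tensor == identical windows for both parts.
rng = np.random.default_rng(0)
y = rng.integers(0, 10 - 4 + 1)
x = rng.integers(0, 10 - 4 + 1)
crop = stacked[y:y + 4, x:x + 4, :]

img_crop, mask_crop = crop[:, :, :3], crop[:, :, 3:]
assert np.array_equal(img_crop, img[y:y + 4, x:x + 4, :])
assert np.array_equal(mask_crop, mask[y:y + 4, x:x + 4, :])
```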


Resizing and normalization

def normal(img, mask):
    img = tf.cast(img, tf.float32) / 127.5 - 1
    mask = tf.cast(mask, tf.int32)
    return img, mask
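Dividing by 127.5 and subtracting 1 maps uint8 pixel values from [0, 255] into [-1, 1]:

```python
import numpy as np

pixels = np.array([0, 127.5, 255], dtype=np.float32)
normalized = pixels / 127.5 - 1
print(normalized)  # [-1.  0.  1.]
```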

Wrap the preprocessing for training images

def load_image_train(img_path, mask_path):
    img = read_png(img_path)
    mask = read_png_label(mask_path)
    img, mask = crop_img(img, mask)
    if tf.random.uniform(()) > 0.5:
        img = tf.image.flip_left_right(img)
        mask = tf.image.flip_left_right(mask)
    img, mask = normal(img, mask)
    return img, mask

Wrap the preprocessing for validation images; no augmentation is applied to the validation set

def load_image_val(img_path, mask_path):
    img = read_png(img_path)
    mask = read_png_label(mask_path)
    img = tf.image.resize(img, (256, 256))
    mask = tf.image.resize(mask, (256, 256))
    img, mask = normal(img, mask)
    return img, mask

Map the preprocessing functions over the datasets

dataset_train = dataset_train.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset_val = dataset_val.map(load_image_val, num_parallel_calls=tf.data.experimental.AUTOTUNE)

Set the batch size and step counts

BATCH_SIZE = 32
BUFFER_SIZE = 128
STEP_PER_EPOCH = train_count //BATCH_SIZE
VALIDATION_STEP = val_count //BATCH_SIZE

Configure the training and validation input pipelines

dataset_train = dataset_train.cache().repeat().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
dataset_train = dataset_train.prefetch(tf.data.experimental.AUTOTUNE)
dataset_val = dataset_val.cache().batch(BATCH_SIZE)
dataset_train
dataset_val
  • <PrefetchDataset shapes: ((None, 256, 256, 3), (None, 256, 256, 1)), types: (tf.float32, tf.int32)>
  • <BatchDataset shapes: ((None, 256, 256, 3), (None, 256, 256, 1)), types: (tf.float32, tf.int32)>

Define the model

Classes present in label_1

np.unique(label_1.numpy())
  • array([ 1, 4, 7, 11, 14, 15, 17, 20, 21, 22, 23, 26, 27], dtype=uint8)
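This single label only contains 13 of the ids. The labelIds annotations use the full Cityscapes id scheme, whose ids run from 0 to 33 (34 values in total, including void classes); that is why the output layer of the model below has 34 channels. A quick sanity check on the ids observed above:

```python
import numpy as np

# Ids observed in label_1 above; full Cityscapes labelIds span 0..33.
observed = np.array([1, 4, 7, 11, 14, 15, 17, 20, 21, 22, 23, 26, 27])
num_classes = 34
assert observed.min() >= 0 and observed.max() < num_classes
```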
def creat_model():
    inputs = tf.keras.layers.Input(shape=(256, 256, 3))

    # Downsampling path
    x = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    x = tf.keras.layers.BatchNormalization()(x)        # 256*256*64

    x1 = tf.keras.layers.MaxPooling2D()(x)             # 128*128*64
    x1 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x1)
    x1 = tf.keras.layers.BatchNormalization()(x1)      # 128*128*128
    x1 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x1)
    x1 = tf.keras.layers.BatchNormalization()(x1)      # 128*128*128

    x2 = tf.keras.layers.MaxPooling2D()(x1)            # 64*64*128
    x2 = tf.keras.layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x2)
    x2 = tf.keras.layers.BatchNormalization()(x2)      # 64*64*256
    x2 = tf.keras.layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x2)
    x2 = tf.keras.layers.BatchNormalization()(x2)      # 64*64*256

    x3 = tf.keras.layers.MaxPooling2D()(x2)            # 32*32*256
    x3 = tf.keras.layers.Conv2D(512, (3, 3), padding='same', activation='relu')(x3)
    x3 = tf.keras.layers.BatchNormalization()(x3)      # 32*32*512
    x3 = tf.keras.layers.Conv2D(512, (3, 3), padding='same', activation='relu')(x3)
    x3 = tf.keras.layers.BatchNormalization()(x3)      # 32*32*512

    x4 = tf.keras.layers.MaxPooling2D()(x3)            # 16*16*512
    x4 = tf.keras.layers.Conv2D(1024, (3, 3), padding='same', activation='relu')(x4)
    x4 = tf.keras.layers.BatchNormalization()(x4)      # 16*16*1024
    x4 = tf.keras.layers.Conv2D(1024, (3, 3), padding='same', activation='relu')(x4)
    x4 = tf.keras.layers.BatchNormalization()(x4)      # 16*16*1024

    # Upsampling path
    x5 = tf.keras.layers.Conv2DTranspose(512, (2, 2), padding='same', strides=2,
                                         activation='relu')(x4)    # 32*32*512
    x5 = tf.keras.layers.BatchNormalization()(x5)
    x6 = tf.concat([x5, x3], axis=-1)                  # 32*32*1024
    x6 = tf.keras.layers.Conv2D(512, (3, 3), padding='same', activation='relu')(x6)   # 32*32*512
    x6 = tf.keras.layers.BatchNormalization()(x6)
    x6 = tf.keras.layers.Conv2D(512, (3, 3), padding='same', activation='relu')(x6)   # 32*32*512
    x6 = tf.keras.layers.BatchNormalization()(x6)

    x7 = tf.keras.layers.Conv2DTranspose(256, (2, 2), padding='same', strides=2,
                                         activation='relu')(x6)    # 64*64*256
    x7 = tf.keras.layers.BatchNormalization()(x7)
    x8 = tf.concat([x7, x2], axis=-1)                  # 64*64*512
    x8 = tf.keras.layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x8)   # 64*64*256
    x8 = tf.keras.layers.BatchNormalization()(x8)
    x8 = tf.keras.layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x8)   # 64*64*256
    x8 = tf.keras.layers.BatchNormalization()(x8)

    x9 = tf.keras.layers.Conv2DTranspose(128, (2, 2), padding='same', strides=2,
                                         activation='relu')(x8)    # 128*128*128
    x9 = tf.keras.layers.BatchNormalization()(x9)
    x10 = tf.concat([x9, x1], axis=-1)                 # 128*128*256
    x10 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x10) # 128*128*128
    x10 = tf.keras.layers.BatchNormalization()(x10)
    x10 = tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x10) # 128*128*128
    x10 = tf.keras.layers.BatchNormalization()(x10)

    x11 = tf.keras.layers.Conv2DTranspose(64, (2, 2), padding='same', strides=2,
                                          activation='relu')(x10)  # 256*256*64
    x11 = tf.keras.layers.BatchNormalization()(x11)
    x12 = tf.concat([x11, x], axis=-1)                 # 256*256*128
    x12 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x12)  # 256*256*64
    x12 = tf.keras.layers.BatchNormalization()(x12)
    x12 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x12)  # 256*256*64
    x12 = tf.keras.layers.BatchNormalization()(x12)

    output = tf.keras.layers.Conv2D(34, (1, 1), padding='same',
                                    activation='softmax')(x12)     # 256*256*34
    return tf.keras.Model(inputs=inputs, outputs=output)

Build the model

model = creat_model()

tf.keras.metrics.MeanIoU(num_classes=34)  # expects predictions that are already class indices
Our labels use ordinal (integer) encoding while the model outputs per-class probabilities, so the class needs a small modification:

class MeanIoU(tf.keras.metrics.MeanIoU):
    def __call__(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().__call__(y_true, y_pred, sample_weight=sample_weight)
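What the metric computes: per-class IoU = intersection / union, read off the confusion matrix and averaged over classes. A minimal NumPy version for two classes:

```python
import numpy as np

y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 1, 1, 1])

num_classes = 2
cm = np.zeros((num_classes, num_classes), dtype=np.int64)
for t, p in zip(y_true, y_pred):
    cm[t, p] += 1

# IoU per class: diagonal / (row sum + column sum - diagonal)
intersection = np.diag(cm)
union = cm.sum(axis=1) + cm.sum(axis=0) - intersection
iou = intersection / union
print(iou.mean())  # mean of [0.5, 0.6667] ≈ 0.5833
```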

Compile the model

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])
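sparse_categorical_crossentropy is the right loss here because the masks hold integer class ids rather than one-hot vectors; per pixel it is simply the negative log-probability that the softmax assigns to the true class. A hand computation for one pixel:

```python
import numpy as np

probs = np.array([0.7, 0.2, 0.1])  # softmax output for one pixel over 3 classes
true_class = 0                     # integer label, no one-hot needed

loss = -np.log(probs[true_class])
print(round(loss, 4))  # 0.3567
```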

Set the number of epochs

EPOCHS=60

Train the model

history = model.fit(dataset_train,
                    epochs=EPOCHS,
                    steps_per_epoch=STEP_PER_EPOCH,
                    validation_data=dataset_val,
                    validation_steps=VALIDATION_STEP)

Plot the training and validation loss

plt.plot(history.epoch, history.history.get('loss'), 'r', label='Training loss')
plt.plot(history.epoch, history.history.get('val_loss'), 'b', label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

Plot the training and validation accuracy

plt.plot(history.epoch, history.history.get('acc'), 'r', label='Training acc')
plt.plot(history.epoch, history.history.get('val_acc'), 'b', label='Validation acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.show()

Testing

num = 3  # number of examples to display
for image, mask in dataset_train.take(1):
    pred_mask = model.predict(image)
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    plt.figure(figsize=(10, 10))
    for i in range(num):
        plt.subplot(num, 3, i * 3 + 1)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(image[i]))
        plt.subplot(num, 3, i * 3 + 2)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(mask[i]))
        plt.subplot(num, 3, i * 3 + 3)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(pred_mask[i]))
