从零开始的图像语义分割:FCN复现教程(Pytorch+CityScapes数据集)

  • 前言
  • 一、图像分割开山之作FCN
  • 二、代码及数据集获取
    • 1.源项目代码
    • 2.CityScapes数据集
  • 三、代码复现
    • 1.数据预处理
    • 2.代码修改
    • 3.运行结果
  • 总结
  • 参考网站

前言

摆了两周,突然觉得不能一直再颓废下去了,应该利用好时间,并且上个月就读了一些经典的图像分割论文比如FCN、UNet和Mask R-CNN,但仅仅只是读了论文并且大概了解了图像分割是在做什么任务的,于是今天就拉动手复现一下,因为只有代码运行起来了,才能进行接下来的代码阅读以及其他改进迁移等后续工作。
本文着重在于代码的复现,其他相关知识会涉及得较少,需要读者自行了解。
看完这篇文章,您将收获一个完整的图像分割项目(一个通用的图像分割数据集及一份可正常执行的代码)。

一、图像分割开山之作FCN


图来自FCN,Jonathan Long,Evan Shelhamer,Trevor Darrell CVPR2015

图像分割可以大致为实例分割、语义分割,其中语义分割(Semantic Segmentation)是对图像中每一个像素点进行分类,确定每个点的类别(如属于背景、人或车等),从而进行区域划分。目前,语义分割已经被广泛应用于自动驾驶、无人机落点判定等场景中。
FCN全程Fully Convolutional Networks,最早发表于CVPR2015,原论文链接如下:
FCN论文链接:https://arxiv.org/abs/1411.4038
正如其名称全卷积网络,实则是将早年的网络比如VGG的全连接层代替为卷积层,这样做的目的是让模型可以输入不同尺寸的图像,因为全连接层一旦被创建输入输出维度都是固定的,追根溯源就是输入图片的尺寸固定,并且语义分割是像素级别操作,替换为卷积层也更加合理(卷积操作就是像素级别,这些都是后话了)。
更具体的学习视频可以跳转到b站FCN网络结构详解(语义分割)


二、代码及数据集获取

1.源项目代码


进入FCN论文链接,点击Code&Data再进入Community Code跳转到paperwithcode网站。

很神奇地是会发现有两个FCN的检索链接,本文所需要的pytorch项目代码在红框这个链接中

Star最高的就是本文所需项目,这个大佬还有自己的个人网页,而且号称是FCN最简单的实现,我可以作证此言不虚,的确是众多代码中最简洁明朗的。

2.CityScapes数据集

CityScapes数据集官方下载链接:CityScapes Download
然而下载这个数据集需要注册账号,而且需要的是教育邮箱,可能是按照是否带edu.cn域名判断的吧,本人使用学校邮箱成功注册下载了数据集。读者若有不便可以上网其他途径获取或淘宝买个账号。


只需下载前3个数据集即可,gtFine_trainvaltest是精确标注(最主要最关键部分),gtCoarse是粗略标注,leftimg8bit_trainvaltest是原图。虽然模型训练的时候只需要用到gtFine但是因为接下来还需要预处理数据集,因此要将三个数据集下载好,才能执行官方给的预处理代码。
重构数据集

将三个zip解压然后新建一个文件夹命名为CityScapes,然后将三个解压文件里的内容按上图目录放置好,为数据集预处理做准备。


三、代码复现

1.数据预处理

这里需要先下载官方的脚本:cityscapesScripts
接下来对其中的一些地方进行修改,最重要的两个文件为项目下cityscapesscripts\helpers\labels.py和cityscapesscripts\preparation\createTrainIdLabelImgs.py。


蓝色框为原本的代码,直接注释掉添加红框处代码,即指定自己本地的数据集目录,比如我就将CityScapes放到了E盘的dataset目录下。

然后是在label.py文件里按照训练的需要更改trainid,255为不被模型所需要的id,因为FCN中为19类+背景板,所以为20类,刚好符合所以不需要更改label文件中任何内容。

最后运行createTrainIdLabelImgs.py,如果报错的话大概率是因为缺少上图蓝框所示的库,将其直接注释掉就可以了。

2.代码修改

之所以需要修改是因为原本的代码里面数据预处理那块太慢了,Cityscapes_utils.py要将trainId写入npy文件,运行速度极慢,这也是先前用官方预处理脚本cityscapesScripts来预处理的原因,预处理的目的其实也只是生成TrainIds的mask图片,和labelIds的png图片是同理的,只是每个像素所对应类别按照label.py里面的label表进行改变。
其实pytorch官方有给出加载CityScapes的数据集代码,但其直接拿来用并不能满足我们要求,所以需要修改一下,就原项目代码的Cityscapes_loader.py和torchvision.datasets.Cityscapes的代码结合,得到如下可执行代码。读者只需用其替换train.py文件即可。

# -*- coding: utf-8 -*-
# Author: Reganzhxfrom __future__ import print_functionimport random
from tqdm import tqdm # 由于训练缓慢,添加进度条方便观察
import imageio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import DataLoaderfrom fcn import VGGNet, FCN32s, FCN16s, FCN8s, FCNs
# from Cityscapes_loader import CityScapesDataset
from CamVid_loader import CamVidDataset
from torchvision.datasets import Cityscapes
from matplotlib import pyplot as plt
import numpy as np
import time
import sys
import os
from PIL import Imageclass CityScapesDataset(Cityscapes):def __init__(self, root: str,split: str = "train",mode: str = "fine",target_type="semantic",transform=None,target_transform=None,transforms=None):super(CityScapesDataset, self).__init__(root,split,mode,target_type,transform,target_transform,transforms)self.means = np.array([103.939, 116.779, 123.68]) / 255.self.n_class = 20self.new_h = 512 # 数据集图片过大,需要剪裁self.new_w = 1024def __getitem__(self, index):img = imageio.imread(self.images[index], pilmode='RGB')targets = []for i, t in enumerate(self.target_type):if t == "polygon":target = self._load_json(self.targets[index][i])else:target = imageio.imread(self.targets[index][i])targets.append(target)target = tuple(targets) if len(targets) > 1 else targets[0] # 针对多目标 可不关注h, w, _ = img.shapetop = random.randint(0, h - self.new_h)left = random.randint(0, w - self.new_w)img = img[top:top + self.new_h, left:left + self.new_w]label = target[top:top + self.new_h, left:left + self.new_w]# reduce meanimg = img[:, :, ::-1]  # switch to BGRimg = np.transpose(img, (2, 0, 1)) / 255.img[0] -= self.means[0]img[1] -= self.means[1]img[2] -= self.means[2]# convert to tensorimg = torch.from_numpy(img.copy()).float()label = torch.from_numpy(label.copy()).long()# create one-hot encodingh, w = label.size()target = torch.zeros(self.n_class, h, w)for c in range(self.n_class):target[c][label == c] = 1sample = {'X': img, 'Y': target, 'l': label}return sampledef __len__(self) -> int:return len(self.images)def _get_target_suffix(self, mode: str, target_type: str) -> str:if target_type == "instance":return f"{mode}_instanceIds.png"elif target_type == "semantic": # 让其指向预处理好的target图片return f"{mode}_labelTrainIds.png"elif target_type == "color":return f"{mode}_color.png"else:return f"{mode}_polygons.json"n_class = 20
batch_size = 2 # 根据测试,1batch需要2G显存,请按实际设置
epochs = 500
lr = 1e-4
momentum = 0
w_decay = 1e-5
step_size = 50
gamma = 0.5
configs = "FCNs-BCEWithLogits_batch{}_epoch{}_RMSprop_scheduler-step{}-gamma{}_lr{}_momentum{}_w_decay{}".format(batch_size, epochs, step_size, gamma, lr, momentum, w_decay)
print("Configs:", configs)# create dir for model
model_dir = "models"
if not os.path.exists(model_dir):os.makedirs(model_dir)
model_path = os.path.join(model_dir, configs)use_gpu = torch.cuda.is_available()
num_gpu = list(range(torch.cuda.device_count()))# 自行更改root
train_data = CityScapesDataset(root='E:/datasets/CityScapes', split='train', mode='fine',target_type='semantic')train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)val_data = CityScapesDataset(root='E:/datasets/CityScapes', split='val', mode='fine',target_type='semantic')val_loader = DataLoader(val_data, batch_size=1)vgg_model = VGGNet(requires_grad=True, remove_fc=True)
fcn_model = FCNs(pretrained_net=vgg_model, n_class=n_class)if use_gpu:ts = time.time()vgg_model = vgg_model.cuda()fcn_model = fcn_model.cuda()fcn_model = nn.DataParallel(fcn_model, device_ids=num_gpu)print("Finish cuda loading, time elapsed {}".format(time.time() - ts))criterion = nn.BCEWithLogitsLoss()
optimizer = optim.RMSprop(fcn_model.parameters(), lr=lr, momentum=momentum, weight_decay=w_decay)
scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size,gamma=gamma)  # decay LR by a factor of 0.5 every 30 epochs# create dir for score
score_dir = os.path.join("scores", configs)
if not os.path.exists(score_dir):os.makedirs(score_dir)
IU_scores = np.zeros((epochs, n_class))
pixel_scores = np.zeros(epochs)def train():for epoch in range(epochs):scheduler.step()ts = time.time()for iter, batch in enumerate(tqdm(train_loader)):optimizer.zero_grad()if use_gpu:inputs = Variable(batch['X'].cuda())labels = Variable(batch['Y'].cuda())else:inputs, labels = Variable(batch['X']), Variable(batch['Y'])outputs = fcn_model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()if iter % 10 == 0:print("epoch{}, iter{}, loss: {}".format(epoch, iter, loss.item()))print("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))torch.save(fcn_model, model_path)val(epoch)def val(epoch):fcn_model.eval()total_ious = []pixel_accs = []for iter, batch in enumerate(val_loader):if use_gpu:inputs = Variable(batch['X'].cuda())else:inputs = Variable(batch['X'])output = fcn_model(inputs)output = output.data.cpu().numpy()N, _, h, w = output.shapepred = output.transpose(0, 2, 3, 1).reshape(-1, n_class).argmax(axis=1).reshape(N, h, w)target = batch['l'].cpu().numpy().reshape(N, h, w)for p, t in zip(pred, target):total_ious.append(iou(p, t))pixel_accs.append(pixel_acc(p, t))# Calculate average IoUtotal_ious = np.array(total_ious).T  # n_class * val_lenious = np.nanmean(total_ious, axis=1)pixel_accs = np.array(pixel_accs).mean()print("epoch{}, pix_acc: {}, meanIoU: {}, IoUs: {}".format(epoch, pixel_accs, np.nanmean(ious), ious))IU_scores[epoch] = iousnp.save(os.path.join(score_dir, "meanIU"), IU_scores)pixel_scores[epoch] = pixel_accsnp.save(os.path.join(score_dir, "meanPixel"), pixel_scores)# borrow functions and modify it from https://github.com/Kaixhin/FCN-semantic-segmentation/blob/master/main.py
# Calculates class intersections over unions
def iou(pred, target):ious = []for cls in range(n_class):pred_inds = pred == clstarget_inds = target == clsintersection = pred_inds[target_inds].sum()union = pred_inds.sum() + target_inds.sum() - intersectionif union == 0:ious.append(float('nan'))  # if there is no ground truth, do not include in evaluationelse:ious.append(float(intersection) / max(union, 1))# print("cls", cls, pred_inds.sum(), target_inds.sum(), intersection, float(intersection) / max(union, 1))return iousdef pixel_acc(pred, target):correct = (pred == target).sum()total = (target == target).sum()return correct / totalif __name__ == "__main__":val(0)  # show the accuracy before trainingtrain()

3.运行结果

分别在自己办公电脑1030显卡(显存4G)和3060显卡(显存12G)上测试,根据两台电脑运行上看每增加1batch就需要消耗2G显存,因为3060上最大只能将batch size设置为6。3060显卡上1个epoch需要8min,也就是说训练完500epoch需要三天时间,可见图像分割真的是极其消耗资源。而1030上1代竟然耗时2h20min,所以按照时间来看首选设备是3090,这样才可能在一天之内进行完一次完整500epoch训练。

第1轮迭代后pixel accuracy就有75%,目前到第25轮pixel accuracy达到85%,随着epoch数增加,pixel acc也越来越高,希望其最终能突破90%,原论文中可是达到96%pixel准确率。

下图为3060上训练150epoch的结果,每5epoch进行一次val评估。最后使用matplotlib绘制如下曲线,pixel_acc和meanIoU的获取请读者自行额外编写代码获得,此处仅提供绘图代码。
第135epoch取得最高pixel accuracy=0.8766716842651368,meanIoU=0.3268041800950261

from matplotlib import pyplot as pltx=[i for i in range(0,151,5)] #横坐标
# 此处给出我的数据,浮点数都用round函数取到小数点后7位
pix_acc_list=[0.7520696,0.7918097,0.6557526,0.8310604,0.8453417,0.8509236,0.8534471,0.8378322,0.8489639,0.8563263,0.8538324,0.8572157,0.860767,0.8660216,0.8631711,0.8631837,0.8670352,0.8597714,0.8689239,0.8647407,0.8698506,0.8712046,0.8719427,0.8722804,0.8732114,0.871852,0.8714358,0.8766717,0.86854,0.8661136,0.8761132]
meanIoU_list=[0.1333057,0.185366,0.1383637,0.2432535,0.2634509,0.2799635,0.2831553,0.2642947,0.2924905,0.3027259,0.3123738,0.2976701,0.3113799,0.3239229,0.3163488,0.3170467,0.3246953,0.3236825,0.3242375,0.3262411,0.3355112,0.3285704,0.3388148,0.328427,0.3378653,0.3385619,0.3358321,0.3268042,0.3297385,0.3347885,0.3379351]
plt.figure()
plt.plot(x,pix_acc_list,color='blue',label='pixel acc')
plt.plot(x,meanIoU_list,color='red',label='meanIoU')plt.xticks(fontsize=16)
plt.yticks(fontsize=16)plt.xlabel('Epoch',fontsize=20)
plt.ylabel('Score',fontsize=20)
plt.legend(fontsize=16)
plt.show()

总结

希望您读到这里能有所收获,本文所参考资料也在文末给出,大家可以查阅获取更多知识细节,后续还将不断完善本文内容,敬请期待……


参考网站

https://bbs.huaweicloud.com/blogs/306716
https://developer.aliyun.com/article/797607
https://www.cnblogs.com/dotman/p/cityscapes_dataset_tips.html
https://zhuanlan.zhihu.com/p/147195575
https://codeantenna.com/a/uD5sJceaS1
https://blog.csdn.net/zz2230633069/article/details/84591532
https://www.zhihu.com/question/276325769/answer/2418207657
https://blog.csdn.net/zz2230633069/article/details/84668984
https://blog.csdn.net/yumaomi/article/details/124847721

从零开始的图像语义分割:FCN快速复现教程(Pytorch+CityScapes数据集)相关推荐

  1. Pytorch:图像语义分割-FCN, U-Net, SegNet, 预训练网络

    Pytorch: 图像语义分割-FCN, U-Net, SegNet, 预训练网络 Copyright: Jingmin Wei, Pattern Recognition and Intelligen ...

  2. 遥感图像语义分割——从原始图像开始制作自己的数据集(以高分二号为例)

    遥感图像语义分割--从原始图像开始制作自己的数据集(以高分二号为例) 文章目录 遥感图像语义分割--从原始图像开始制作自己的数据集(以高分二号为例) 1.遥感影像获取 2.遥感数据预处理(影像融合) ...

  3. 图像语义分割 -- FCN

    一:图像语义分割 最简答理解图像语义分割呢就是看下面的图片例子: 像素级别的分类: 假如像素有五个类别,那么最后输出的结果在长度和宽度上是一样的,只不过通道数就是类别个数了.拆解开各个通道就是如下所示 ...

  4. 36.图像语义分割-FCN

    图像语义分割是计算机读懂图像的基础,所以叫图像语义分割,左侧是图像语义分割,右侧是实例分割,语义分割关注种类,实例分割关注个体,像我们左侧的语义分割,分割后机器就能大致了解,图里有5只羊,1个人,1条 ...

  5. 图像语义分割实战:TensorFlow Deeplabv3+ 训练自己数据集

    文章目录 前言 一.环境配置 二.训练过程 1.引入库 2.数据集准备 转换为 VOC 格式的数据集 Convert to 灰度图 Convert to tfrecord 3.训练前代码准备 4.主要 ...

  6. 图像语义分割方法研究进展

    全监督学习的图像语义分割方法研究进展 简介 1 全监督学习的图像语义分割方法 1.1 基于全卷积的图像语义分割方法 1.2 基于编码器解码器结构的图像语义分割方法 1.3 基于注意力机制的图像语义分割 ...

  7. 深度学习:使用UNet做图像语义分割,训练自己制作的数据集,详细教程

    语义分割(Semantic Segmentation)是图像处理和机器视觉一个重要分支.与分类任务不同,语义分割需要判断图像每个像素点的类别,进行精确分割.语义分割目前在自动驾驶.自动抠图.医疗影像等 ...

  8. 图像语义分割样本制作——使用Matlab模块Image Labeler 标记样本

    在进行图像语义分割的时候,需要自己制作数据集,目前开源的标记软件很多,但个人觉得最好用的还是MATLAB中的Image Labeler,下面是简单的使用介绍. 1. 定义各类别说明: 2. 工具栏介绍 ...

  9. 视频教程-DeepLabv3+图像语义分割实战:训练自己的数据集-计算机视觉

    DeepLabv3+图像语义分割实战:训练自己的数据集 大学教授,美国归国博士.博士生导师:人工智能公司专家顾问:长期从事人工智能.物联网.大数据研究:已发表学术论文100多篇,授权发明专利10多项 ...

最新文章

  1. 瑞士军刀——Pandoc
  2. python项目主界面_python项目案例
  3. Interface 的本质用处
  4. 第十天2017/04/23(1、企业财富库:“循环单链表”的设计与实现)
  5. Abp vNext 切换MySql数据库
  6. linux mysql 分区_Linux :linux磁盘分区(普通分区2T以内),安装免安装版mysql(tar.gz)...
  7. dbc数据库 与 mysql_【图片】DBC2000安装及数据库详细解析(不断更行中......)【dbc2000吧】_百度贴吧...
  8. centos下载和安装mongodb
  9. DOS命令打开一个软件,以及在python中的使用
  10. O(lgn)计算斐波那契数
  11. UMTS到LTE的系统架构演进(学习整理:LTE完全指南-LTE、LTE-Advanced、SAE、VolTE和4G移动通信)
  12. 查找网络计算机步骤,如何查找到局域网中指定IP地址的是哪一台电脑
  13. Pytorch节省显存、加速训练的小技巧
  14. DevExpress v18.2版本亮点——Reporting篇(三)
  15. 【下载Tomcat旧版本】
  16. gstreamer+qgc+aarch64
  17. python elementtree乱码_Python中使用ElementTree解析xml
  18. HMC7043和HMC7044芯片配置使用
  19. 中科大辅修计算机,中科大新生入学第二考来了——校规考试!(一不小心就挂)...
  20. linux 电子书阅读器_3个适用于Linux桌面的电子书阅读器

热门文章

  1. Mysql根据身份证更新出生日期及年龄sql语句
  2. FANUC UIO 发那科专用IO定义说明
  3. SIMetrix教程-005.SIMetrix导入第三方库;SIMetrix导入模型
  4. QPSK Matlab仿真
  5. Java 代码基于开源组件生成带头像的二维码
  6. Invictus -- 不可征服
  7. 【腾讯TMQ】APP省流量更新监控最佳实践
  8. 服务定位器 - Caliburn.Micro 文档系列
  9. 【新书推荐】【2020.02】可穿戴设备传感器和天线的设计与优化
  10. 电脑课破解学生端控屏软件