Pspnet全名Pyramid Scene Parsing Network,论文地址:Pyramid Scene Parsing Network

论文名就是《Pyramid Scene Parsing Network》。

该模型提出是为了解决场景分析问题。针对FCN网络在场景分析数据集上存在的问题,Pspnet提出一系列改进方案,以提升场景分析中对于相似颜色、形状的物体的检测精度。

图1 ADE20k场景分析

作者在ADE20K数据集上进行实验时,主要发现有如下3个问题:

  1. 错误匹配,FCN模型把水里的船预测成汽车,但是汽车是不会在水上的。因此,作者认为FCN缺乏收集上下文能力,导致了分类错误。
  2. 作者发现相似的标签会导致一些奇怪的错误,比如earth和field,mountain和hill,wall,house,building和skyscraper。FCN模型会出现混淆。
  3. 第三是小目标的丢失问题,像一些路灯、信号牌这种小物体,很难被FCN所发现。相反的,一些特别大的物体预测中,在感受野不够大的情况下,往往会丢失一部分信息,导致预测不连续。

为了解决这些问题,作者提出了Pyramid Pooling Module。

Pyramid Pooling Module

作者在文章中提出了Pyramid Pooling Module(池化金字塔结构)这一模块。

作者提到,在深层网络中,感受野的大小大致上体现了模型能获得的上下文新消息。尽管在理论上Resnet的感受野已经大于图像尺寸,但是实际上会小得多。这就导致了很多网络不能充分的将上下文信息结合起来,于是作者就提出了一种全局的先验方法-全局平均池化。

作者在PPM模块中并联了四个不同大小的全局池化层,将原始的feature map池化生成不同级别的特征图,经过卷积和上采样恢复到原始大小。这种操作聚合了多尺度的图像特征,生成了一个“hierarchical global prior”,融合了不同尺度和不同子区域之间的信息。最后,这个先验信息再和原始特征图进行相加,输入到最后的卷积模块完成预测。

图2 Pspnet

Pspnet的核心就是PPM模块。其网络架构十分简单,backbone为resnet网络,将原始图像下采样8倍成特征图,特征图输入到PPM模块,并与其输出相加,最后经过卷积和8倍双线性差值上采样得到结果(图2)。

辅助损失

图3 辅助损失

论文中还有一个细节是辅助损失(auxiliary loss),在resnet101的res4b22层引出一条FCN分支,用于计算辅助损失。论文里设置了赋值损失loss2的权重为0.4。则最终的损失则为:

论文复现

本文主要在CamVid数据集上进行复现,数据集可以在另一篇博客中找到CamVid数据集的创建和使用。

Resnet

这里调用了pytorch官方写的ResNet101,替换最后两个layer为dialation模式,只采用8倍下采样。引出layer3的计算结果用于计算辅助损失。

from torchvision.models import resnet50, resnet101
from torchvision.models._utils import IntermediateLayerGetter
import torch
import torch.nn as nnbackbone=IntermediateLayerGetter(resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True]),return_layers={'layer3':'aux','layer4': 'stage4'})x = torch.randn(1, 3, 224, 224).cpu()
result = backbone(x)
for k, v in result.items():print(k, v.shape)

pspnet

from torchvision.models import resnet50, resnet101
from torchvision.models._utils import IntermediateLayerGetter
import torch
import torch.nn as nnclass PPM(nn.ModuleList):def __init__(self, pool_sizes, in_channels, out_channels):super(PPM, self).__init__()self.pool_sizes = pool_sizesself.in_channels = in_channelsself.out_channels = out_channelsfor pool_size in pool_sizes:self.append(nn.Sequential(nn.AdaptiveMaxPool2d(pool_size),nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1),))def forward(self, x):out_puts = []for ppm in self:ppm_out = nn.functional.interpolate(ppm(x), size=x.size()[-2:], mode='bilinear', align_corners=True)out_puts.append(ppm_out)return out_putsclass PSPHEAD(nn.Module):def __init__(self, in_channels, out_channels,pool_sizes = [1, 2, 3, 6],num_classes=3):super(PSPHEAD, self).__init__()self.pool_sizes = pool_sizesself.num_classes = num_classesself.in_channels = in_channelsself.out_channels = out_channelsself.psp_modules = PPM(self.pool_sizes, self.in_channels, self.out_channels)self.final = nn.Sequential(nn.Conv2d(self.in_channels + len(self.pool_sizes)*self.out_channels, self.out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(self.out_channels),nn.ReLU(),)def forward(self, x):out = self.psp_modules(x)out.append(x)out = torch.cat(out, 1)out = self.final(out)return out# 构建一个FCN分割头,用于计算辅助损失
class Aux_Head(nn.Module):def __init__(self, in_channels=1024, num_classes=3):super(Aux_Head, self).__init__()self.num_classes = num_classesself.in_channels = in_channelsself.decode_head = nn.Sequential(nn.Conv2d(self.in_channels, self.in_channels//2, kernel_size=3, padding=1),nn.BatchNorm2d(self.in_channels//2),nn.ReLU(),            nn.Conv2d(self.in_channels//2, self.in_channels//4, kernel_size=3, padding=1),nn.BatchNorm2d(self.in_channels//4),nn.ReLU(),            nn.Conv2d(self.in_channels//4, self.num_classes, kernel_size=3, padding=1),)def forward(self, x):return self.decode_head(x)class Pspnet(nn.Module):def __init__(self, num_classes, aux_loss = True):super(Pspnet, self).__init__()self.num_classes = num_classesself.backbone = IntermediateLayerGetter(resnet50(pretrained=False, replace_stride_with_dilation=[False, True, True]),return_layers={'layer3':"aux" ,'layer4': 'stage4'})self.aux_loss = aux_lossself.decoder = PSPHEAD(in_channels=2048, out_channels=512, pool_sizes = [1, 2, 3, 6], num_classes=self.num_classes)self.cls_seg = nn.Sequential(nn.Conv2d(512, self.num_classes, kernel_size=3, padding=1),)if self.aux_loss:self.aux_head = Aux_Head(in_channels=1024, num_classes=self.num_classes)def forward(self, x):_, _, h, w = x.size()feats = self.backbone(x) x = self.decoder(feats["stage4"])x = self.cls_seg(x)x = nn.functional.interpolate(x, size=(h, w),mode='bilinear', align_corners=True)# 如果需要添加辅助损失if self.aux_loss:aux_output = self.aux_head(feats['aux'])aux_output = nn.functional.interpolate(aux_output, size=(h, w),mode='bilinear', align_corners=True)return {"output":x, "aux_output":aux_output}return {"output":x}if __name__ == "__main__":model = Pspnet(num_classes=3, aux_loss=True)model = model.cuda()a = torch.ones([2, 3, 224, 224])a = a.cuda()for name, out in model(a).items():print(name,": ", out.shape)

数据集构建

# 导入库
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2torch.manual_seed(17)
# 自定义数据集CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):"""CamVid Dataset. Read images, apply augmentation and preprocessing transformations.Args:images_dir (str): path to images foldermasks_dir (str): path to segmentation masks folderclass_values (list): values of classes to extract from segmentation maskaugmentation (albumentations.Compose): data transfromation pipeline (e.g. flip, scale, etc.)preprocessing (albumentations.Compose): data preprocessing (e.g. noralization, shape manipulation, etc.)"""def __init__(self, images_dir, masks_dir):self.transform = A.Compose([A.Resize(448, 448),A.HorizontalFlip(),A.VerticalFlip(),A.Normalize(),ToTensorV2(),]) self.ids = os.listdir(images_dir)self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]def __getitem__(self, i):# read dataimage = np.array(Image.open(self.images_fps[i]).convert('RGB'))mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))image = self.transform(image=image,mask=mask)return image['image'], image['mask'][:,:,0]def __len__(self):return len(self.ids)# 设置数据集路径
DATA_DIR = r'dataset\camvid' # 根据自己的路径来设置
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')train_dataset = CamVidDataset(x_train_dir, y_train_dir,
)
val_dataset = CamVidDataset(x_valid_dir, y_valid_dir,
)train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True,drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True,drop_last=True)

模型训练

from d2l import torch as d2l
from tqdm import tqdm
import pandas as pdmodel = Pspnet(num_classes=32, aux_loss=True)
model = model.cuda()# training loop 100 epochs
epochs_num = 100
# 选用SGD优化器来训练
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
schedule = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50)# 损失函数选用多分类交叉熵损失函数
lossf = nn.CrossEntropyLoss(ignore_index=255)def evaluate(net, data_iter, device=torch.device('cuda:0')):net.eval()metric = d2l.Accumulator(3)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)pred = net(X)['output']metric.add(d2l.accuracy(pred, y), d2l.size(y))return metric[0] / metric[1]# 训练函数
def train_ch13(net, train_iter, test_iter, loss_func, optimizer, num_epochs, schedule, devices=d2l.try_all_gpus()):timer, num_batches = d2l.Timer(), len(train_iter)animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1], legend=['train loss', 'train acc', 'test acc'])net = nn.DataParallel(net, device_ids=devices).to(devices[0])# 用来保存一些训练参数loss_list = []train_acc_list = []test_acc_list = []epochs_list = []time_list = []lr_list = []for epoch in range(num_epochs):# metric: loss, accuracy, labels.shape[0], labels.numel(), 0.4*aux_lossmetric = d2l.Accumulator(5)for i, (X, labels) in enumerate(train_iter):timer.start()if isinstance(X, list):X = [x.to(devices[0]) for x in X]else:X = X.to(devices[0])gt = labels.long().to(devices[0])net.train()optimizer.zero_grad()result = net(X)seg_loss = loss_func(result['output'], gt)aux_loss = loss_func(result['aux_output'], gt)loss_sum = seg_loss + 0.4*aux_lossl = loss_sumloss_sum.sum().backward()optimizer.step()acc = d2l.accuracy(result['output'], gt)metric.add(l, acc, labels.shape[0], labels.numel(), 0.4*aux_loss)timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches, (metric[0] / metric[2], metric[1] / metric[3], None, None))test_acc = evaluate(net, test_iter)animator.add(epoch + 1, (None, None, test_acc)) schedule.step()print(f"epoch {epoch+1}/{epochs_num} --- loss {metric[0]/metric[2]:.3f} --- aux_loss {metric[4]/metric[2]:.3f} --- train acc {metric[1]/metric[3]:.3f} --- test acc {test_acc:.3f} --- lr {optimizer.state_dict()['param_groups'][0]['lr']} --- cost time {timer.sum()}")#---------保存训练数据---------------df = pd.DataFrame()loss_list.append(metric[0] / metric[2])train_acc_list.append(metric[1] / metric[3])test_acc_list.append(test_acc)epochs_list.append(epoch+1)time_list.append(timer.sum())lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])df['epoch'] = epochs_listdf['loss'] = loss_listdf['train_acc'] = train_acc_listdf['test_acc'] = test_acc_listdf["lr"] = lr_listdf['time'] = time_listdf.to_excel("../blork_file/savefile/PSPNET.xlsx")#----------------保存模型------------------- if np.mod(epoch+1, 5) == 0:torch.save(net, f'../blork_file/checkpoints/PSPNET{epoch+1}.pth')# 保存下最后的modeltorch.save(net, f'../blork_file/checkpoints/PSPNET.pth')

训练结果

语义分割系列5-Pspnet(pytorch实现)相关推荐

  1. 语义分割系列14-DMNet(pytorch)实现

    DMNet:<Dynamic Multi-Scale Filters for Semantic Segmentation> 发布于2019ICCV. 有意思的是,DMNet的作者和APCN ...

  2. 语义分割系列6-Unet++(pytorch实现)

    目录 Unet++网络 Dense connection deep supervision 模型复现 Unet++ 数据集准备 模型训练 训练结果 Unet++:<UNet++: A Neste ...

  3. 语义分割系列2-Unet(pytorch实现)

    Unet发布于MICCAI.其论文的名字也说得相对很明白,用于生物医学图像分割. <U-Net: Convolutional Networks for Biomedical Image Segm ...

  4. 语义分割系列24-PointRend(pytorch实现)

    PointRend: Image Segmentation as Rendering 论文链接:PointRend 本文将介绍: PointRend的原理 PointRend代码实现 PointRen ...

  5. 语义分割系列19-EMANet(pytorch实现)

    EMANet:<Expectation-Maximization Attention Networks for Semantic Segmentation> 发布于2019ICCV,一作的 ...

  6. 语义分割系列11-DAnet(pytorch实现)

    DAnet:Dual Attention Network for Scene Segmentation 发布于CVPR2019,本文将进行DAnet的论文讲解和复现工作. 论文部分 主要思想 DAne ...

  7. 语义分割系列论文 ParseNet

    语义分割系列论文 ParseNet 核心思想--Global Context 理论感受野的大小(Receptive Field) 实际感受野的大小 此文章如何扩大感受野? 疑点(读者可以忽略本节) 总 ...

  8. 憨批的语义分割重制版9——Pytorch 搭建自己的DeeplabV3+语义分割平台

    憨批的语义分割重制版9--Pytorch 搭建自己的DeeplabV3+语义分割平台 注意事项 学习前言 什么是DeeplabV3+模型 代码下载 DeeplabV3+实现思路 一.预测部分 1.主干 ...

  9. 憨批的语义分割重制版6——Pytorch 搭建自己的Unet语义分割平台

    憨批的语义分割重制版6--Pytorch 搭建自己的Unet语义分割平台 注意事项 学习前言 什么是Unet模型 代码下载 Unet实现思路 一.预测部分 1.主干网络介绍 2.加强特征提取结构 3. ...

  10. 语义分割系列1-FCN(全卷积网络)(pytorch实现)

    全卷积网络FCN(Fully Convolutional Networks)是CV中语义分割任务的开山之作.FCN网络在PASCAL VOC(2012)数据集上获得了62.2%的mIoU. 论文全名& ...

最新文章

  1. D001斯图加特~计算机
  2. 下载nodejs的mysql安装包下载_nodejs安装包下载|nodejs(javascript运行环境) v5.3.0 最新稳定版 - 软件下载 - 绿茶软件园|33LC.com...
  3. httpposterror_http请求405错误方法不被允许的解决 (Method not allowed)
  4. linux下sudo命令[转]
  5. iOS:UIView的block函数实现转场动画---双视图
  6. mysql 默认事务隔离级别_一文读懂MySQL的事务隔离级别及MVCC机制
  7. 千亿智慧照明市场背后,BLEMESH免开发方案成主流
  8. 【三维路径规划】基于matlab A_star算法无人机三维路径规划【含Matlab源码 1387期】
  9. android pdf阅读器推荐,四款好用的PDF阅读器推荐,建议收藏!
  10. DIY一个VR小钢炮
  11. ENVI操作:监督分类
  12. ASP.NET公司企业网站源码
  13. 虚拟机教程(一) 启用win10自带虚拟机
  14. camera中lookAt的理解
  15. HashMap灵魂26问
  16. Android 各版本演变特性整理
  17. python计算机体系三层结构_python学习笔记-计算机结构、操作系统
  18. svg嵌套svg_使用SVG掩盖效果
  19. 安装Oracle Instant Client
  20. javase学习笔记,学习时间一个月,发布笔记进度1/3

热门文章

  1. 【阿朱一帖看尽】2014年BAT到底干了些什么
  2. 淘宝首页原生js练习(基础练习的基础中的基础)
  3. VB二维码生成与解码的代码,特别支持中文的二维码编码译码
  4. 深入理解DRM(二)——了解Widevine与OEMCrypto
  5. 什么是AD域?域能给公司带来什么好处?哪款AD域管理工具比较好?
  6. ivx中字体显示_Windows 7 中的 SimSun-ExtB 是什么字体,为何与中易宋体 SimSun 显示出来不一样?...
  7. java并发圣经,差距不止一星半点!Github星标85K的性能优化法则圣经
  8. python修改pdf内容_用Python把PDF文件转换成Word文档
  9. java实现获取中国大学名称列表、即所在省份
  10. Java——时间日期格式化