文章目录

  • 摘要
    • 目标
    • 贡献点
  • 介绍
  • 相关工作
    • 经典方法
    • 深度学习
    • 基于用户输入的图像分割
    • 基于CNN的弱监督分割
    • 端到端的可微分割CNN
  • 方法
    • 问题建模
    • 网络结构
    • 损失函数
    • 网络更新
  • 实验结果
    • 连续性损失的有效性
    • 用户输入涂鸦的分割
    • 参考图像预训练的效果
  • 源代码解析
  • 加scribble的运行结果

摘要

目标

  • 特征相似的像素应该被分为同一类别
  • 空间上距离近的像素应该被分为同一类别
  • 类别的数量应尽可能多
    三个指标有互相排斥的地方,但应该做到一个平衡

贡献点

  • 通过normalization和argmax实现可微的聚类
  • 空间连续性损失函数
  • 扩展了可以让用户输入的涂鸦,使结果更精确
  • 扩展一个预训练方法

介绍

把摘要再扩展地讲了一遍,提到之前的工作超像素提取+线性迭代聚类只满足了空间连续性

相关工作

经典方法

  • K-means:矢量量化的标准方法
  • GS(基于图的方法):利用特定的区域比较函数,针对全局或局部特征进行简单的贪心算法选择

深度学习

  • MsLRR:有监督和无监督通用,但因为基于超像素,和本篇之前的工作一样具有边界的限制
  • W-Net:无监督,估计分割后恢复原图,无惧边界
  • Unsupervised learning of foreground object segmentation:无监督,但只是前后景分割

基于用户输入的图像分割

  • Graph cut:最小化图像像素对应于节点的图的代价,可以用于涂鸦或者锚定框的输入
  • Image matting:抠图是像素标签软分配,图切割对每个像素就是前后景的分割
  • Constrained random walks:可以根据涂鸦划分给出前/后景的种子

以上方法都只能产生一个二值mask,基于多标签的无监督分割,有以下扩展算法

α-expansion:找到一个局部最小值使得α标签的像素不再轻易增加
α-β swap:找到一个局部最小值使得α和β标签不再被轻易地交换

基于CNN的弱监督分割

常用的语义分割弱监督标签:物体检测锚定框、图片的分类结果、涂鸦
一般流程:根据弱监督标签生成一个训练目标,使用这个训练目标训练网络,两个步骤交替进行

  • ScribbleSup:使用超像素把涂鸦扩展到整张图,再进行训练
  • e-SVM:用CPMC分割从锚定框得到像素级别的标记,再进行训练
  • Distinct class-specific saliency maps for weakly supervised semantic segmentation:根据图片类标签生成class saliency map再送入全卷积CRF网络训练

但这些弱监督方法有可能无法收敛到正确的结果,有以下改进版的端到端CNN方法

端到端的可微分割CNN

关于图像分割的深度学习研究一致围绕对图像特征的理解和提取。

  • deep embedded clustering (DEC):最小化 KL divergence loss,本文提出的方法只是简单地最小化 softmax loss
  • maximum margin clustering:半监督
  • discriminative clustering:半监督

方法

问题建模

f:提取特征
g:分配标签
c:标签
无监督方法,f和g是固定的,c待学习
有监督方法,c是固定的,f和g待学习
分解成两个子问题

  1. 用固定的f和g优化c
  2. 再用固定的c优化f和g

网络结构


一张RGB图像提取特征后通过一个1x1的卷积转换到一个q维的聚类空间(图中q=3),沿着这个空间的q个轴,通过batch norm把这个q维的特征向量归一化,使用argmax确定每个像素的标签是q维中的哪一维,根据这个确定的伪标签计算特征相似度损失和空间连续性损失,再反向传播。


不考虑batch,我理解的维度变化是这样的:HxW>提取特征后并转换空间>HxWxq>确定伪标签后>HxW

在训练网络时,先设一个较大的q,随着损失下降,q会逐渐变小,为了防止q变成1,所以需要对response map做一个归一化。

损失函数

  • 基本:
  • 加入涂鸦:
  • 特征相似误差:

    c n c_n cn是根据 r n r_n rn进行argmax得到的,所以当i遍历到得到 c n c_n cnr n r_n rn时才有ln值累加,因为归一化后 r n r_n rn都是0到1,所以前面有个负号。
  • 空间连续性损失:

    计算每个像素上下左右的的response map上的值的差别
  • 涂鸦损失:

网络更新

前面提到的分解为两个子问题,实际上就是CNN前向计算和反向传播的过程。
使用随机梯度下降、Xavier初始化

Xavier初始化的基本思想是,若对于一层网络的输出和输出可以保持正态分布且方差相近,这样就可以避免输出趋向于0,从而避免梯度弥散情况。https://www.zhihu.com/search?type=content&q=Xavier%20%E5%88%9D%E5%A7%8B%E5%8C%96

另外一个重要的点就是与有确切标签的有监督学习不同,本方法在最后一个卷积层和argmax层之间加的batch norm非常关键,把response map中的每一个轴都归一化到0均值,单位方差,这样每个轴才能平均地进行比较,进而得到正确的类标签。

实验结果

连续性损失的有效性

关于特征相似性损失和连续性损失之间的比例,根据数据集需要的分割精度,设置不同的比例可以达到更好的效果。



用户输入涂鸦的分割

参考图像预训练的效果

源代码解析

github地址:https://github.com/kanezaki/pytorch-unsupervised-segmentation-tip/blob/master/demo.py

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import cv2
import sys
import numpy as np
import torch.nn.init
import randomuse_cuda = torch.cuda.is_available()parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
parser.add_argument('--scribble', action='store_true', default=False, help='use scribbles')
parser.add_argument('--nChannel', metavar='N', default=100, type=int, help='number of channels')
parser.add_argument('--maxIter', metavar='T', default=1000, type=int, help='number of maximum iterations')
parser.add_argument('--minLabels', metavar='minL', default=3, type=int, help='minimum number of labels')
parser.add_argument('--lr', metavar='LR', default=0.1, type=float, help='learning rate')
parser.add_argument('--nConv', metavar='M', default=2, type=int, help='number of convolutional layers')
parser.add_argument('--visualize', metavar='1 or 0', default=1, type=int, help='visualization flag')
parser.add_argument('--input', metavar='FILENAME',help='input image file name', required=True)
parser.add_argument('--stepsize_sim', metavar='SIM', default=1, type=float,help='step size for similarity loss', required=False)
parser.add_argument('--stepsize_con', metavar='CON', default=1, type=float, help='step size for continuity loss')
parser.add_argument('--stepsize_scr', metavar='SCR', default=0.5, type=float, help='step size for scribble loss')
args = parser.parse_args()# CNN model
class MyNet(nn.Module):def __init__(self,input_dim):super(MyNet, self).__init__()self.conv1 = nn.Conv2d(input_dim, args.nChannel, kernel_size=3, stride=1, padding=1 )self.bn1 = nn.BatchNorm2d(args.nChannel)self.conv2 = nn.ModuleList()self.bn2 = nn.ModuleList()# 参数里面nConv设为2,所以这里的conv2也只包含一个卷积层,输入和输出通道都是100for i in range(args.nConv-1):self.conv2.append( nn.Conv2d(args.nChannel, args.nChannel, kernel_size=3, stride=1, padding=1 ) )self.bn2.append( nn.BatchNorm2d(args.nChannel) )# 最后一层是1x1的卷积核,输出为100,即q=100,分类数如参数设置为3-100,随着网络更新动态变化self.conv3 = nn.Conv2d(args.nChannel, args.nChannel, kernel_size=1, stride=1, padding=0 )self.bn3 = nn.BatchNorm2d(args.nChannel)def forward(self, x):x = self.conv1(x)x = F.relu( x )x = self.bn1(x)for i in range(args.nConv-1):x = self.conv2[i](x)x = F.relu( x )x = self.bn2[i](x)x = self.conv3(x)x = self.bn3(x)return x# load image
im = cv2.imread(args.input)
# 把图像(H,W,C)变成(C,H,W),并把像素值归一化到0-1之间
data = torch.from_numpy( np.array([im.transpose( (2, 0, 1) ).astype('float32')/255.]) )
if use_cuda:data = data.cuda()
data = Variable(data)# load scribble
if args.scribble:# 这里是读取之前准备好的二值的涂鸦图片mask = cv2.imread(args.input.replace('.'+args.input.split('.')[-1],'_scribble.png'),-1)# reshape成一维,长度=HxWmask = mask.reshape(-1)# 去除重复数字mask_inds = np.unique(mask)# 删掉255,剩下的就是涂鸦上的颜色mask_inds = np.delete( mask_inds, np.argwhere(mask_inds==255) )# 返回mask中=255的索引(空白)inds_sim = torch.from_numpy( np.where( mask == 255 )[ 0 ] )# 返回mask中!=255的索引(画了涂鸦的像素)inds_scr = torch.from_numpy( np.where( mask != 255 )[ 0 ] )# mask的int型,这里要把源代码的np.int改成np.int64(因为我会报一个类型不统一的错)target_scr = torch.from_numpy( mask.astype(np.int64) )if use_cuda:inds_sim = inds_sim.cuda()inds_scr = inds_scr.cuda()target_scr = target_scr.cuda()target_scr = Variable( target_scr )# set minLabels# 根据涂鸦上的颜色类别确定最小的,按照readme里面的测试,剩下0和8两种值,我猜想8应该是代表那些涂鸦线条的边缘args.minLabels = len(mask_inds)# train
model = MyNet( data.size(1) )
if use_cuda:model.cuda()
model.train()# similarity loss definition
loss_fn = torch.nn.CrossEntropyLoss()# scribble loss definition
loss_fn_scr = torch.nn.CrossEntropyLoss()# continuity loss definition
loss_hpy = torch.nn.L1Loss(size_average = True)
loss_hpz = torch.nn.L1Loss(size_average = True)HPy_target = torch.zeros(im.shape[0]-1, im.shape[1], args.nChannel)
HPz_target = torch.zeros(im.shape[0], im.shape[1]-1, args.nChannel)
if use_cuda:HPy_target = HPy_target.cuda()HPz_target = HPz_target.cuda()optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
# 随机生成100种颜色作为标签颜色
label_colours = np.random.randint(255,size=(100,3))for batch_idx in range(args.maxIter):# forwardingoptimizer.zero_grad()output = model( data )[ 0 ]# 从(C,H,W)转到(H,W,C),再用contiguous做一个拷贝和之前的output区分开,再view成(HxW,C)的形状output = output.permute( 1, 2, 0 ).contiguous().view( -1, args.nChannel )# 不太明白这里为什么用reshape而不是permute,这里得到的就是100个通道的response mapoutputHP = output.reshape( (im.shape[0], im.shape[1], args.nChannel) )# 求空间连续性误差,先y和z方向分别求出右边像素-左边像素(左右指的是索引)HPy = outputHP[1:, :, :] - outputHP[0:-1, :, :]HPz = outputHP[:, 1:, :] - outputHP[:, 0:-1, :]lhpy = loss_hpy(HPy,HPy_target)lhpz = loss_hpz(HPz,HPz_target)# 返回的ignore代表每个像素对应的100个通道中值最大的那个通道的值,target返回的是那个通道对应的索引(也就是标签)ignore, target = torch.max( output, 1 )# target的形状应该是(HxW,1),即每一个像素都被分配了标签im_target = target.data.cpu().numpy()# 去除重复的标签得到标签总数nLabels = len(np.unique(im_target))if args.visualize:# 按照HxW的形状,对每一个像素按照标签赋值之前随机初始化的100种颜色im_target_rgb = np.array([label_colours[ c % args.nChannel ] for c in im_target])# 把(HxW,3)变成(H,W,3)im_target_rgb = im_target_rgb.reshape( im.shape ).astype( np.uint8 )cv2.imshow( "output", im_target_rgb )cv2.waitKey(10)# lossif args.scribble:loss = args.stepsize_sim * loss_fn(output[ inds_sim ], target[ inds_sim ]) + args.stepsize_scr * loss_fn_scr(output[ inds_scr ], target_scr[ inds_scr ]) + args.stepsize_con * (lhpy + lhpz)else:loss = args.stepsize_sim * loss_fn(output, target) + args.stepsize_con * (lhpy + lhpz)loss.backward()optimizer.step()print (batch_idx, '/', args.maxIter, '|', ' label num :', nLabels, ' | loss :', loss.item())if nLabels <= args.minLabels:print ("nLabels", nLabels, "reached minLabels", args.minLabels, ".")break# save output image
if not args.visualize:output = model( data )[ 0 ]output = output.permute( 1, 2, 0 ).contiguous().view( -1, args.nChannel )ignore, target = torch.max( output, 1 )im_target = target.data.cpu().numpy()im_target_rgb = np.array([label_colours[ c % args.nChannel ] for c in im_target])im_target_rgb = im_target_rgb.reshape( im.shape ).astype( np.uint8 )
cv2.imwrite( "output.png", im_target_rgb )

加scribble的运行结果

python demo.py --input ./PASCAL_VOC_2012/2007_001774.jpg --scribble


后面打算用别的图片加scribble测试一下,发现scribble其实是比较特殊的,不是简单的在图片上画一些线段,scribble上面的线段应该是0,然后有一个大于0小于255的值代表轮廓,区分前景和后景。暂时对这一部分不太了解,后面再慢慢看吧。

【论文阅读】Unsupervised Learning of Image Segmentation Based on Differentiable Feature Clustering相关推荐

  1. 论文阅读笔记:Retinal vessel segmentation based on Fully Convolutional Neural Networks

    基于全卷积神经网络的视网膜血管分割 关键词:全卷积神经网络.平稳小波变换.视网膜眼底图像.血管分割.深度学习 摘要 本文提出了一种新的方法,将平稳小波变换提供的多尺度分析与多尺度全卷积神经网络相结合, ...

  2. 【论文阅读】Learning Traffic as Images: A Deep Convolutional ... [将交通作为图像学习: 用于大规模交通网络速度预测的深度卷积神经网络](2)

    [论文阅读]Learning Traffic as Images: A Deep Convolutional Neural Network for Large-Scale Transportation ...

  3. 【论文阅读】Learning Traffic as Images: A Deep Convolutional ... [将交通作为图像学习: 用于大规模交通网络速度预测的深度卷积神经网络](1)

    [论文阅读]Learning Traffic as Images: A Deep Convolutional Neural Network for Large-Scale Transportation ...

  4. 【论文阅读】Learning Spatiotemporal Features with 3D Convolutional Networks

    [论文阅读]Learning Spatiotemporal Features with 3D Convolutional Networks 这是一篇15年ICCV的论文,本篇论文提出的C3D卷积网络是 ...

  5. 论文阅读06——《CaEGCN: Cross-Attention Fusion based Enhanced Graph Convolutional Network for Clustering》

    欢迎到我的个人博客看原文 论文阅读06--<CaEGCN: Cross-Attention Fusion based Enhanced Graph Convolutional Network f ...

  6. 【论文阅读】Learning Spatio-Temporal Representation with Pseudo-3D Residual Networks

    [论文阅读]Learning Spatio-Temporal Representation with Pseudo-3D Residual Networks 虽然这是一篇17年ICCV的论文,但是这篇 ...

  7. 【论文阅读】Learning Semantically Enhanced Feature for Fine-Grained Image Classification

    [论文阅读] Learning Semantically Enhanced Feature for Fine-Grained Image Classification 摘要 具体实现 语义分组模块 特 ...

  8. 论文阅读|struc2vec: Learning Node Representations from Structural Identity

    论文阅读|struc2vec: Learning Node Representations from Structural Identity 文章目录 论文阅读|struc2vec: Learning ...

  9. 【论文阅读】22-GMS: Grid-based Motion Statistics for Fast, Ultra-robust Feature Correspondence

    [论文阅读]22-GMS: Grid-based Motion Statistics for Fast, Ultra-robust Feature Correspondence 0.basic inf ...

最新文章

  1. 华谊兄弟出现什么问题_曾经的影视龙头一哥华谊兄弟,为什么如今混得那么惨?...
  2. 2019最新版本的PanDownload纯净版,网盘满速下载和搜索神器,追剧和动漫新番必不可少的下载工具【亲测有效】
  3. (iOS)Storyboard/xib小技巧
  4. Elasticsearch+Kibana 设置连接密码
  5. python3网络爬虫(4):python3安装Scrapy
  6. java enum分析
  7. 找出数组中第i小元素(时间复杂度Θ(n)--最坏情况为线性的选择算法
  8. linux /proc 详解
  9. 【数据库基础知识】数据库表格——主键和外键
  10. C语言排序方法-----二分插入排序
  11. jsonready onload 与_漫谈JSONP以及img的onLoad和onEr
  12. python约瑟夫环_Python语言之如何实现约瑟夫环问题
  13. 维恩图是什么?如何使用维恩图?
  14. 两用图片视频压缩软件
  15. 网络对抗 Exp5 MSF基础应用 20154311 王卓然
  16. 常用hadoop dfs命令
  17. Apache端口占用解决办法
  18. [写Bug记录] Maven出现 Library xxx has broken classes paths
  19. 语音助手——垂类永动机——自动化迭代框架
  20. 数据传输速率(传码速率和传信速率)

热门文章

  1. Linux定时器jiffies学习
  2. 不用 Flash 观看 bilibili 直播
  3. HTML爱心网页制作[樱花+爱心+炫彩文字]
  4. 弱网测试----苹果手机
  5. Android面试必问的Activity,初阶,中高阶问法,你都掌握了吗?(要求熟读并背诵全文)
  6. java泛型类的作用_【Java-泛型系列一-泛型的作用】
  7. 曙光天阔H系列服务器,曙光天阔 A620r- H服务器系统支持 AMD最新推出的
  8. 安踏跑步发布全新氢跑鞋ZERO,挑战世界纪录的轻
  9. netcraft 查询网络数据结构
  10. 五步教你改变窗体背景色