deepfillv2的动机

​ 结合了几乎所有的目前先进的图像修复技术,基于部分卷积提出了门控卷积,结合了CA中的注意力机制,根据 Adversarial Edge图像修复中的边缘信息先验提出了用户可交互的草图先验信息。基于spectral-normallized GAN 提出了 SN-PatchGAN 鉴别器,本文所用的损失函数只有l1 重建损失和 SN-PatchGAN损失.

1. Gated Convolution

为了介绍门控卷积,得先提提部分卷积,对于分类、分割等任务,网络的输入像素是全部有效的,而对于修复任务,孔洞区域的像素是无效像素,如果将其当成和其他区域的像素一样处理,那么必然会造成修复结果的模糊,颜色不一致等情况,基于这种原因,部分卷积(partial convolution)被提出。它的实现机制在我上一篇Image Inpainting for Irregular Holes Using有被提到。它的目的在于,使得卷积的结果尽量只依赖与有效像素。部分卷积有效提高了非规则掩模上的图像修复质量。但是仍然还存在一些问题:

  1. 在跟新mask时,它启发式地将所有空间位置分类为有效或无效。无论前一层的过滤范围覆盖了多少像素,下一层的掩码都将被设置为1(例如,1个有效的像素和9个有效的像素被当作相同的来更新当前的掩码),这样显得不太合理。
  2. 如果模型是需要与用户进行交互的,那么用户输入的稀疏草图掩模作为条件通道。在这种情况下,应该认为这些像素位置是有效的还是无效的?如何正确更新下一层的mask
  3. 对于部分卷积来说,如果网络加深到一定程度那么mask最终会被全部更新为1(即全部都是有效像素),本文的作者提出应该让网络自动学习最优的掩码,网络将软掩码值分配给每个空间位置
  4. 部分卷积中,每个层中的所有通道都共享同一个掩码mask,这限制了灵活性。本质上,部分卷积可以被视为不可学习的单通道特征硬门控。

部分卷积与门控卷积的图示区别如下图:

基于上述部分卷积的一些问题,本文作者提出了门控卷积。取代了部分卷积的硬门控的掩码mask更新规则,门控卷积从数据中自动学习软掩码mask.更新的数学表达如下:

这里的I是特征图,σ\sigmaσ是sigmoid()函数,ϕ\phiϕ是激活函数,可以是ReLU、ELU、LeakyReLU。实际就是对I分别做两次卷积,然后其中一个卷积用sigmoid()函数,将其值全部限制在0-1之间,然后与另外一个卷积得到的特征图进行逐像素的相乘。

门控卷积的代码实现非常简单,如下:

#1.门控卷积的模块
class Gated_Conv(nn.Module):def __init__(self,in_ch,out_ch,ksize=3,stride=1,rate=1,activation=nn.ELU):super(Gated_Conv, self).__init__()padding=int(rate*(ksize-1)/2)#通过卷积将通道数变成输出两倍,其中一半用来做门控,学习self.conv=nn.Conv2d(in_ch,2*out_ch,kernel_size=ksize,stride=stride,padding=padding,dilation=rate)self.activation=activationdef forward(self,x):raw=self.conv(x)x1=raw.split(int(raw.shape[1]/2),dim=1)#将特征图分成两半,其中一半是做学习gate=torch.sigmoid(x1[0])#将值限制在0-1之间out=self.activation(x1[1])*gatereturn out

2. SN-PatchGAN

对于孔洞单一为矩形的,local GAN 使用提升了修复结果,但是对于自由形式孔洞区域,这种局部鉴别器显然不太适用。基于 global and local GANs、MarkovianGAN、perceptual loss 和spectral-normalized loss.。作者提出了简单高效的SN-PatchGAN,可以应对自由形式的空洞破损。网络结构如下图所示:

网络的输入包括:破损图片、孔洞掩码mask、用户指导的先验草图信息。网络的输出是3D的feature map.而不是传统鉴别器输出的了一个打分标量。网络堆叠了6个卷积为kernel size为5,stride=2去捕获Markovian patches的特征统计信息。值得注意的是输出特征图的每一个元素的感受野都是包含了整个输入图。因此全局鉴别器也就不需要了。同时也采用了spectral normalizetion (借鉴的是SN-GANs)来进一步稳定GAN的训练。为了鉴别出真图还是假图,采用了hinge loss作为目标函数,对于生成器G:
lossG=−Ez−pz(z)[Dsn(G(z))]loss_G=-E_{z-p_z(z)}[D^{sn}(G(z))] lossG​=−Ez−pz​(z)​[Dsn(G(z))]
对于鉴别器:
lossD=Ex−Pdata(x)[ReLU(1−Dsn(x))]+Ez−pz(z)[ReLU(1+Dsn(G(z)))]loss_D=E_{x-P_{data}(x)}[ReLU(1-D^{sn}(x))]+E_{z-p_z(z)}[ReLU(1+D^{sn}(G(z)))] lossD​=Ex−Pdata​(x)​[ReLU(1−Dsn(x))]+Ez−pz​(z)​[ReLU(1+Dsn(G(z)))]
这里的DsnD^{sn}Dsn代表spectral-normalized discriminator ,G是修复网络。

鉴别器网络结构实现如下:

#1.
class SpectralNorm(nn.Module):'''spectral normalization,modified from https://github.com/christiancosgrove/pytorch-spectral-normalization-gan/blob/master/spectral_normalization.py'''def __init__(self, module, name='weight', power_iterations=1):super(SpectralNorm, self).__init__()self.module = moduleself.name = nameself.power_iterations = power_iterationsif not self._made_params():self._make_params()def _update_u_v(self):u = getattr(self.module, self.name + "_u")v = getattr(self.module, self.name + "_v")w = getattr(self.module, self.name + "_bar")height = w.data.shape[0]for _ in range(self.power_iterations):v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))sigma = u.dot(w.view(height, -1).mv(v))setattr(self.module, self.name, w / sigma.expand_as(w))def _made_params(self):try:u = getattr(self.module, self.name + "_u")v = getattr(self.module, self.name + "_v")w = getattr(self.module, self.name + "_bar")return Trueexcept AttributeError:return Falsedef _make_params(self):w = getattr(self.module, self.name)height = w.data.shape[0]width = w.view(height, -1).data.shape[1]u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)u.data = l2normalize(u.data)v.data = l2normalize(v.data)w_bar = Parameter(w.data)del self.module._parameters[self.name]self.module.register_parameter(self.name + "_u", u)self.module.register_parameter(self.name + "_v", v)self.module.register_parameter(self.name + "_bar", w_bar)def forward(self, *args):self._update_u_v()return self.module.forward(*args)#2.SN卷积层实现class SN_Conv(nn.Module):def __init__(self,in_ch,out_ch,ksize=3,stride=1,rate=1,activation=nn.LeakyReLU()):super(SN_Conv,self).__init__()padding = int(rate * (ksize - 1) / 2)conv = nn.Conv2d(in_ch,out_ch, kernel_size=ksize, stride=stride, padding=padding, dilation=rate)self.snconv = SpectralNorm(conv)self.activation = activationdef forward(self,x):x1 = self.snconv(x)if self.activation is not None:x1 = self.activation(x1)return x1#3.sn鉴别器网络
class SNDiscriminator(nn.Module):def __init__(self,in_ch=5,cnum=64):super(SNDiscriminator,self).__init__()disconv_layer = OrderedDict()disconv_layer['conv1'] = SN_Conv(in_ch=in_ch,out_ch=cnum,ksize=5,stride=2)disconv_layer['conv2'] = SN_Conv(in_ch=cnum, out_ch=2*cnum, ksize=5, stride=2)disconv_layer['conv3'] = SN_Conv(in_ch=2*cnum, out_ch=4*cnum, ksize=5, stride=2)disconv_layer['conv4'] = SN_Conv(in_ch=4 * cnum, out_ch=4 * cnum, ksize=5, stride=2)disconv_layer['conv5'] = SN_Conv(in_ch=4 * cnum, out_ch=4 * cnum, ksize=5, stride=2)disconv_layer['conv6'] = SN_Conv(in_ch=4 * cnum, out_ch=4 * cnum, ksize=5, stride=2)self.dislayer = nn.Sequential(disconv_layer)def forward(self,x):x1 = self.dislayer(x)#print(x1.shape)out = x1.view(x1.shape[0],-1)return out

3. inpainting Network Architecture

整个修复网络分为两个阶段(粗阶段和细化阶段),卷积部分都采用了门控卷积:

#1.粗阶段,输入是5通道(破损图片3,掩码mask,用户指导草图),输出为3通道
class CoarseNet(nn.Module):def __init__(self,in_ch=5,cnum=48):super(CoarseNet,self).__init__()self.conv1 = Gated_Conv(in_ch=in_ch,out_ch=cnum,ksize=5)self.conv2_down = Gated_Conv(in_ch=cnum,out_ch=2*cnum,stride=2)self.conv3 = Gated_Conv(in_ch=2*cnum,out_ch=2*cnum)self.conv4_down = Gated_Conv(in_ch=2*cnum,out_ch=4*cnum,stride=2)self.conv5 = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)self.conv6 = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)self.conv7 = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum,rate=2)self.conv8 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=4)self.conv9 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=8)self.conv10 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=16)self.conv11 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)self.conv12 = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)self.conv13_up = Gated_Deconv(in_ch=4*cnum,out_ch=2*cnum)self.conv14 = Gated_Conv(in_ch=2*cnum,out_ch=2*cnum)self.conv15_up = Gated_Deconv(in_ch=2*cnum,out_ch=cnum)self.conv16 = Gated_Conv(in_ch=cnum,out_ch=cnum//2)self.conv17 = nn.Conv2d(in_channels=cnum//2,out_channels=3,kernel_size=3,stride=1,padding=1)def forward(self,x):x1 = self.conv1(x)x2 = self.conv2_down(x1)x3 = self.conv3(x2)x4 = self.conv4_down(x3)x5 = self.conv5(x4)x6 = self.conv6(x5)x7 = self.conv7(x6)x8 = self.conv8(x7)x9 = self.conv9(x8)x10 = self.conv10(x9)x11 = self.conv11(x10)x12 = self.conv12(x11)x13 = self.conv13_up(x12)x14 = self.conv14(x13)x15 = self.conv15_up(x14)x16 = self.conv16(x15)x17 = self.conv17(x16)x_stage1 = F.tanh(x17)return x_stage1#2,细化阶段的输入为粗阶段的输出结果,该阶段有两个分支(卷积分支和注意力机制分支)
class RefineNet(nn.Module):def __init__(self,in_ch=3,cnum=48):super(RefineNet,self).__init__()#1.conv branchxconv_layer = OrderedDict()xconv_layer['xconv1'] = Gated_Conv(in_ch=in_ch,out_ch=cnum,ksize=5)xconv_layer['xconv2_down'] = Gated_Conv(in_ch=cnum,out_ch=cnum,stride=2)xconv_layer['xconv3'] =  Gated_Conv(in_ch=cnum,out_ch=2*cnum)xconv_layer['xconv4_down'] = Gated_Conv(in_ch=2*cnum,out_ch=2*cnum,stride=2)xconv_layer['xconv5'] = Gated_Conv(in_ch=2*cnum,out_ch=4*cnum)xconv_layer['xconv6'] = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)xconv_layer['xconv7_atrous']  = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum,rate=2)xconv_layer['xconv8_atrous'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=4)xconv_layer['xconv9_atrous'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=8)xconv_layer['xconv10_atrous'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum, rate=16)self.xlayer = nn.Sequential(xconv_layer)#2.attention brachpmconv_layer1 = OrderedDict()pmconv_layer1['pmconv1'] = Gated_Conv(in_ch=in_ch,out_ch=cnum,ksize=5)pmconv_layer1['pmconv2_down'] = Gated_Conv(in_ch=cnum,out_ch=cnum,stride=2)pmconv_layer1['pmconv3'] = Gated_Conv(in_ch=cnum,out_ch=2*cnum)pmconv_layer1['pmconv4_down'] = Gated_Conv(in_ch=2*cnum, out_ch=4*cnum, stride=2)pmconv_layer1['pmconv5'] = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)pmconv_layer1['pmconv6'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum,activation=nn.ReLU())self.pmlayer1 = nn.Sequential(pmconv_layer1)self.CA = Contextual_Attention(rate=2)pmconv_layer2 = OrderedDict()pmconv_layer2['pmconv9'] = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)pmconv_layer2['pmconv10'] = Gated_Conv(in_ch=4*cnum,out_ch=4*cnum)self.pmlayer2 = nn.Sequential(pmconv_layer2)#confluent branchallconv_layer = OrderedDict()allconv_layer['allconv11'] = Gated_Conv(in_ch=8*cnum,out_ch=4*cnum)allconv_layer['allconv12'] = Gated_Conv(in_ch=4 * cnum, out_ch=4 * cnum)allconv_layer['allconv13_up'] = Gated_Deconv(in_ch=4 * cnum, out_ch=2 * cnum)allconv_layer['allconv14'] = Gated_Conv(in_ch=2 * cnum, out_ch=2 * cnum)allconv_layer['allconv15_up'] = Gated_Deconv(in_ch=2 * cnum, out_ch=cnum)allconv_layer['allconv16'] = Gated_Conv(in_ch=cnum, out_ch=cnum//2)allconv_layer['allconv17'] = nn.Conv2d(in_channels=cnum//2,out_channels=3,kernel_size=3,padding=1)allconv_layer['tanh'] = nn.Tanh()self.colayer = nn.Sequential(allconv_layer)def forward(self, xin, mask):x1 = self.xlayer(xin)x_hallu = x1x2 = self.pmlayer1(xin)mask_s = self.resize_mask_like(mask,x2)x3,offset_flow = self.CA(x2,x2,mask_s)x4 = self.pmlayer2(x3)pm = x4x5 = torch.cat((x_hallu,pm),dim=1)x6 = self.colayer(x5)x_stage2 = x6return x_stage2,offset_flowdef resize_mask_like(self,mask,x):sizeh = x.shape[2]sizew = x.shape[3]return down_sample(mask,size=(sizeh,sizew))#3.完整的修复网络
class CAGenerator(nn.Module):def __init__(self,in_ch=5,cnum=48,):super(CAGenerator,self).__init__()self.stage_1 = CoarseNet(in_ch=in_ch,cnum=cnum)self.stage_2 = RefineNet(in_ch=3,cnum=cnum)def forward(self,xin,mask):stage1_out = self.stage_1(xin)stage2_in = stage1_out * mask + xin[:,0:3,:,:] * (1. - mask)stage2_out,offset_flow = self.stage_2(stage2_in,mask)return stage1_out,stage2_out,offset_flow

4.总结

作者提出了一种基于端到端生成网络的新型自由形式图像修复系统,该网络具有门控卷积,并经过逐像素l1损失和SN-patchGAN训练。证明门控卷积改善了修复的质量。

参考文献

1.Free-Form Image Inpainting with Gated Convolution(ICCV2019)

2.https://github.com/KeyKy/generative-inpainting-2.0-pytorch

(门控卷积实现)DeepFillv2(图像修复):Free-Form Image Inpainting with Gated Convolution,pytroch代码实现相关推荐

  1. 【深度学习】精度超越 ConvNeXt 的新 CNN!HorNet:通过递归门控卷积实现高效高阶的空间信息交互...

    作者丨科技猛兽    编辑丨极市平台 本文提出了一种基于递归的门控卷积的通用视觉模型,是来自清华大学周杰老师,鲁继文老师团队,以及 Meta AI 的学者们在通用视觉模型方面有价值的探索. 本文目录 ...

  2. 精度超越ConvNeXt的新CNN——HorNet:通过递归门控卷积实现高效高阶的空间信息交互

    作者|科技猛兽 编辑|3D视觉开发者社区 本文目录 HorNet:通过递归门控卷积实现高效高阶的空间信息交互 1.1 HorNet 原理分析 1.1.1 背景和动机 1.1.2 HorNet 简介 1 ...

  3. 递归门控卷积HorNet(gn_conv)阅读笔记

    HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions ECCV2022 程序 视觉 T ...

  4. YOLOv7改进之二十二:涨点神器——引入递归门控卷积(gnConv)

     ​前 言:作为当前先进的深度学习目标检测算法YOLOv7,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法.此后的系列文章,将重点对YOLOv ...

  5. yolov5创新 C3GN:引荐HorNet递归门控卷积GnConv重构目标检测颈部网络

    yolov5创新 C3GN:引荐HorNet递归门控卷积GnConv重构目标检测颈部网络 1.引荐HorNet递归门控卷积思想 论文地址: https://arxiv.org/pdf/2207.142 ...

  6. 改进YOLOv5系列:10.最新HorNet结合YOLO应用首发! | 多种搭配,即插即用 | Backbone主干、递归门控卷积的高效高阶空间交互高效

  7. 结合深度学习的图像修复怎么实现?

    作者:QZhang 链接:https://www.zhihu.com/question/56801298/answer/155891603 来源:知乎 著作权归作者所有.商业转载请联系作者获得授权,非 ...

  8. 基于深度学习的Image Inpainting (图像修复)论文推荐(持续更新)

    传统的图形学和视觉的研究方法,主要还是基于数学和物理的方法.然而随着近几年深度学习在视觉领域取得的卓越的效果,视觉领域研究的前沿已经基本被深度学习占领.在这样的形势之下,越来越多的图形学研究者也开始将 ...

  9. 结构感知图像修复:ICCV2019论文解析

    结构感知图像修复:ICCV2019论文解析 StructureFlow: Image Inpainting via Structure-aware Appearance Flow 论文链接: http ...

最新文章

  1. 虽然现在没有闲也没有钱,还是建立了自己的BLOG,因为心里很痒
  2. pu learning的建模实践,半监督学习的好方法!
  3. get与post请求问题
  4. boost::movelib::unique_ptr相关用法的测试程序
  5. eclipse没有日志_技术进展 | 加强公共DHT抵抗eclipse攻击!
  6. maven下载包慢解决
  7. Linux下做一个arp欺骗程序6,LINUX下防ARP欺骗攻击
  8. thinkphp 分页出错 $page-render() 出错
  9. NSArray和NSString的联合使用
  10. 【Python-2.7】删除空格
  11. iOS 警告收录及科学快速的消除方法
  12. 下载eclipse太慢怎么办?
  13. 计量经济学(七)----自相关性Autocorrelation.
  14. 浪漫的表白(C语言)
  15. 一文读懂等级保护二级
  16. 基于java的婚恋交友动态网站
  17. android media player实现一个可手势滑动控制 + 可以调节分辨率|字幕|倍速的视频播放器(MediaPlayer + ExoPlayer实现)
  18. 在手机上图片分辨率怎么调?怎样用手机改300dpi图片?
  19. 你知道数据在内存中是如何存储的嘛?
  20. 一步搞定无法审查元素

热门文章

  1. 中国联通与阿里云达成合作,推动5G+新媒体产业发展
  2. opencv VideoCapture抓取RTSP高延迟,崩溃解决方法
  3. 【Lifelong learning】Efficient Meta Lifelong-Learning with Limited Memory
  4. vue(17) : leaflet(2) : 距离测量
  5. c语言英文信件怎么能,撰写高效英文商务邮件的方法
  6. 众创新企业落地物联网,物联网助阵台湾科技产业升级
  7. 新编剑桥商务英语初级第三版 答案
  8. 宁波大学计算机考研资料汇总
  9. FinalShell 介绍
  10. 【c#】Fedex官方API对接过程