Further Modifications to the PSGAN Network

After testing PSGAN, the results look good on quite a few cases, but some pain points remain. For example, when the reference image is a profile face, lighting and shadow make the makeup intensity differ noticeably between the left and right eyes, and this difference is then reflected on the left and right eyes of a frontal source image as well.

As shown below:


PSGAN computes the attention map using the relative pixel distances Diff_A and diff_B, and one unwelcome side effect of this is: the lips in the right image do look very close to the left image, but no lipstick is applied at the left and right edges of the lips, probably because the attention weight given to relative_position is too large.

As shown below:
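For context, my reading of the PSGAN attention is roughly the sketch below: visual features are concatenated with relative-position features (pixel offsets to facial landmarks) before the similarity is computed. The landmark handling, the normalization, and pos_weight here are illustrative assumptions rather than the paper's exact code, but they show how a large position weight lets geometry dominate appearance, which would explain the bare lip edges.

import torch
import torch.nn.functional as F

def relative_position_features(landmarks, height, width):
    # landmarks: (num_landmarks, 2) (y, x) coordinates on the feature-map grid.
    ys = torch.arange(height, dtype=torch.float32).view(-1, 1).expand(height, width)
    xs = torch.arange(width, dtype=torch.float32).view(1, -1).expand(height, width)
    grid = torch.stack([ys, xs], dim=0)                      # (2, h, w)
    diffs = grid.unsqueeze(0) - landmarks.view(-1, 2, 1, 1)  # (num_landmarks, 2, h, w)
    return diffs.flatten(0, 1) / float(width)                # crude normalization (assumption)

def fused_attention(fm_source, fm_reference, rel_source, rel_reference, pos_weight=1.0):
    # fm_*: (1, c, h, w) visual features; rel_*: (d, h, w) position features.
    # A large pos_weight makes the attention follow geometry instead of
    # appearance, which is my guess at why the lip edges were left bare.
    _, _, h, w = fm_source.size()
    q = torch.cat([fm_source[0], pos_weight * rel_source], dim=0).view(-1, h * w)
    k = torch.cat([fm_reference[0], pos_weight * rel_reference], dim=0).view(-1, h * w)
    energy = q.t() @ k                                       # (h*w, h*w)
    return F.softmax(energy, dim=-1)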

Based on the problems above, I made some changes to the network structure and the training code. When computing the makeup loss during training, the makeup of the left and right eyes is no longer evaluated separately; instead, a histogram makeup loss is computed over both eyes at once. For the network structure, the attention map is now computed directly from the two feature maps. I will train this first and see whether the results improve. The loss change is sketched right below.
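Here is a minimal sketch of the changed eye loss, assuming batch size 1 and binary part masks; masked_histogram_match is a simple sort-based stand-in for the histogram-matching code in the training repo, not its actual API.

import torch
import torch.nn.functional as F

def masked_histogram_match(src, ref, mask):
    # Per-channel histogram matching inside the mask: each masked source pixel
    # is replaced by the reference value of the same rank (sort-based matching).
    out = src.clone()
    idx = mask[0, 0] > 0
    for ch in range(src.size(1)):
        s = src[0, ch][idx]
        r = ref[0, ch][idx]
        _, s_order = torch.sort(s)
        r_sorted, _ = torch.sort(r)
        matched = torch.empty_like(s)
        matched[s_order] = r_sorted
        out[0, ch][idx] = matched
    return out

def joint_eye_makeup_loss(fake, reference, mask_left_eye, mask_right_eye):
    # Key change: merge both eye masks and match one joint histogram, so both
    # eyes are pulled toward the same makeup statistics even when the reference
    # eyes differ because of lighting.
    eye_mask = ((mask_left_eye + mask_right_eye) > 0).float()
    target = masked_histogram_match(fake, reference, eye_mask).detach()
    return F.l1_loss(fake * eye_mask, target * eye_mask)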

The network structure code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

from ops.spectral_norm import spectral_norm as SpectralNorm


# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class ResidualBlock(nn.Module):
    """Residual Block."""

    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True))

    def forward(self, x):
        return x + self.main(x)


class Discriminator(nn.Module):
    """Discriminator. PatchGAN."""

    def __init__(self, image_size=128, conv_dim=64, repeat_num=3, norm='SN'):
        super(Discriminator, self).__init__()
        layers = []
        if norm == 'SN':
            layers.append(SpectralNorm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2

        self.main = nn.Sequential(*layers)
        if norm == 'SN':
            self.conv1 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
        # conv1 keeps the last square size, 256*256 --> 30*30
        # self.conv2 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=k_size, bias=False))
        # conv2 outputs a single number

    def forward(self, x):
        h = self.main(x)
        out_makeup = self.conv1(h)
        return out_makeup.squeeze()


class VGG(nn.Module):
    def __init__(self, pool='max'):
        super(VGG, self).__init__()
        # vgg modules
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        if pool == 'max':
            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif pool == 'avg':
            self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x, out_keys):
        out = {}
        out['r11'] = F.relu(self.conv1_1(x))
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['r34'] = F.relu(self.conv3_4(out['r33']))
        out['p3'] = self.pool3(out['r34'])
        out['r41'] = F.relu(self.conv4_1(out['p3']))
        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['r44'] = F.relu(self.conv4_4(out['r43']))
        out['p4'] = self.pool4(out['r44'])
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        out['r54'] = F.relu(self.conv5_4(out['r53']))
        out['p5'] = self.pool5(out['r54'])
        return [out[key] for key in out_keys]


# Makeup Apply Network (MANet)
class Generator(nn.Module):
    """Generator. Encoder-Decoder Architecture."""

    def __init__(self, conv_dim=64, repeat_num=6):
        super(Generator, self).__init__()
        encoder_layers = [nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
                          nn.InstanceNorm2d(conv_dim, affine=False),  # the MANet uses no affine parameters
                          nn.ReLU(inplace=True)]

        # Down-Sampling
        curr_dim = conv_dim
        for i in range(2):
            encoder_layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            encoder_layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=False))
            encoder_layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck
        for i in range(3):
            encoder_layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        decoder_layers = []
        for i in range(3):
            decoder_layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        # Up-Sampling
        for i in range(2):
            decoder_layers.append(nn.ConvTranspose2d(curr_dim, curr_dim // 2, kernel_size=4, stride=2, padding=1, bias=False))
            decoder_layers.append(nn.InstanceNorm2d(curr_dim // 2, affine=True))
            decoder_layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        decoder_layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        decoder_layers.append(nn.Tanh())

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)
        self.MDNet = MDNet()
        self.AMM = AMM()

    def forward(self, source_image, reference_image, mask_source, mask_ref, gamma=None, beta=None, ret=False,
                mode='train'):
        fm_source = self.encoder(source_image)
        fm_reference = self.MDNet(reference_image)
        if ret:
            gamma, beta = self.AMM(fm_source, fm_reference, mask_source, mask_ref, gamma=gamma, beta=beta,
                                   ret=ret, mode=mode)
            return [gamma, beta]
        morphed_fm = self.AMM(fm_source, fm_reference, mask_source, mask_ref, gamma=gamma, beta=beta,
                              ret=ret, mode=mode)
        result = self.decoder(morphed_fm)
        return result


class MDNet(nn.Module):
    """Generator. Encoder-Decoder Architecture."""
    # MDNet is similar to the encoder of StarGAN

    def __init__(self, conv_dim=64, repeat_num=3):
        super(MDNet, self).__init__()
        layers = [nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
                  nn.InstanceNorm2d(conv_dim, affine=True), nn.ReLU(inplace=True)]

        # Down-Sampling
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck
        for i in range(repeat_num):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))

        self.main = nn.Sequential(*layers)

    def forward(self, reference_image):
        fm_reference = self.main(reference_image)
        return fm_reference


class AMM(nn.Module):
    """Attentive Makeup Morphing module"""

    def __init__(self):
        super(AMM, self).__init__()
        self.gamma_matrix_conv = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)
        self.beta_matrix_conv = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)

    def forward(self, fm_source, fm_reference, mask_source, mask_ref, gamma=None, beta=None, ret=False, mode='train'):
        old_gamma_matrix = self.gamma_matrix_conv(fm_reference)
        old_beta_matrix = self.beta_matrix_conv(fm_reference)
        old_gamma_matrix_source = self.gamma_matrix_conv(fm_source)
        old_beta_matrix_source = self.beta_matrix_conv(fm_source)

        if gamma is None:
            attention_map = self.raw_attention_map(fm_source, fm_reference)
            # gamma, beta = self.raw_atten_feature(mask_source, attention_map, old_gamma_matrix, old_beta_matrix,
            #                                      old_gamma_matrix_source, old_beta_matrix_source)
            gamma, beta = self.pure_atten_feature(attention_map, old_gamma_matrix, old_beta_matrix)
        if ret:
            return [gamma, beta]

        morphed_fm_source = fm_source * (1 + gamma) + beta
        return morphed_fm_source

    @staticmethod
    def raw_attention_map(fm_source, fm_reference):
        batch_size, channels, width, height = fm_reference.size()
        # after the reshape, fm has shape C x (H*W)
        temp_fm_reference = fm_reference.view(batch_size, -1, height * width)
        # fm_source must additionally be transposed to (H*W) x C after the reshape
        temp_fm_source = fm_source.view(batch_size, -1, height * width).permute(0, 2, 1)
        # energy should have shape N x N, where N = H*W
        energy = torch.bmm(temp_fm_source, temp_fm_reference)
        energy *= 200  # hyper parameters for visual feature
        attention_map = F.softmax(energy, dim=-1)
        return attention_map

    @staticmethod
    def raw_atten_feature(mask_source, attention_map, old_gamma_matrix, old_beta_matrix, old_gamma_matrix_source,
                          old_beta_matrix_source):
        batch_size, channels, width, height = old_gamma_matrix.size()
        old_gamma_matrix = old_gamma_matrix.view(batch_size, -1, width * height)
        old_beta_matrix = old_beta_matrix.view(batch_size, -1, width * height)
        new_gamma_matrix = torch.bmm(old_gamma_matrix, attention_map.permute(0, 2, 1))
        new_beta_matrix = torch.bmm(old_beta_matrix, attention_map.permute(0, 2, 1))
        new_gamma_matrix = new_gamma_matrix.view(-1, 1, width, height)
        new_beta_matrix = new_beta_matrix.view(-1, 1, width, height)

        reverse_mask_source = 1 - mask_source
        new_mask_source = F.interpolate(mask_source, size=new_gamma_matrix.shape[2:]).repeat(1, channels, 1, 1)
        new_reverse_mask_source = F.interpolate(reverse_mask_source, size=new_gamma_matrix.shape[2:]).repeat(1, channels, 1, 1)
        gamma = new_gamma_matrix * new_mask_source + old_gamma_matrix_source * new_reverse_mask_source
        beta = new_beta_matrix * new_mask_source + old_beta_matrix_source * new_reverse_mask_source
        return gamma, beta

    # modify gamma_matrix using only the attention computed between the two feature maps
    @staticmethod
    def pure_atten_feature(attention_map, old_gamma_matrix, old_beta_matrix):
        batch_size, channels, width, height = old_gamma_matrix.size()
        old_gamma_matrix = old_gamma_matrix.view(batch_size, -1, width * height)
        old_beta_matrix = old_beta_matrix.view(batch_size, -1, width * height)
        new_gamma_matrix = torch.bmm(old_gamma_matrix, attention_map.permute(0, 2, 1))
        new_beta_matrix = torch.bmm(old_beta_matrix, attention_map.permute(0, 2, 1))
        new_gamma_matrix = new_gamma_matrix.view(-1, 1, width, height)
        new_beta_matrix = new_beta_matrix.view(-1, 1, width, height)
        gamma = new_gamma_matrix
        beta = new_beta_matrix
        return gamma, beta

    # Below is PSGAN's original way of computing attention; it needs too much GPU
    # memory, and some parts of it feel unreasonable to me
    @staticmethod
    def get_attention_map(mask_source, mask_ref, fm_source, fm_reference, mode='train'):
        HW = 64 * 64
        batch_size = 3
        # get the features of the 3 facial parts using the mask
        channels = fm_reference.shape[1]
        mask_source_re = F.interpolate(mask_source, size=64).repeat(1, channels, 1, 1)  # (3, c, h, w)
        fm_source = fm_source.repeat(3, 1, 1, 1)  # (3, c, h, w)
        # when computing attention, we only consider the pixels belonging to the same facial region
        fm_source = fm_source * mask_source_re  # (3, c, h, w) 3 stands for 3 parts

        mask_ref_re = F.interpolate(mask_ref, size=64).repeat(1, channels, 1, 1)
        fm_reference = fm_reference.repeat(3, 1, 1, 1)
        fm_reference = fm_reference * mask_ref_re

        theta_input = fm_source
        phi_input = fm_reference
        theta_target = theta_input.view(batch_size, -1, HW)
        theta_target = theta_target.permute(0, 2, 1)
        phi_source = phi_input.view(batch_size, -1, HW)
        weight = torch.bmm(theta_target, phi_source)  # (3, HW, HW)
        if mode == 'train':
            weight = weight.cpu()
            weight_ind = torch.LongTensor(weight.detach().numpy().nonzero())
            weight = weight.cuda()
            weight_ind = weight_ind.cuda()
        else:
            weight_ind = torch.LongTensor(weight.numpy().nonzero())
        weight *= 200  # hyper parameters for visual feature
        weight = F.softmax(weight, dim=-1)
        weight = weight[weight_ind[0], weight_ind[1], weight_ind[2]]
        return torch.sparse.FloatTensor(weight_ind, weight, torch.Size([3, HW, HW]))

    @staticmethod
    def atten_feature(mask_ref, attention_map, old_gamma_matrix, old_beta_matrix):
        # The paper says the idea behind gamma and beta comes from style transfer, but not
        # general style transfer, so the mask is used here to compute the style of each
        # facial region
        batch_size, channels, width, height = old_gamma_matrix.size()
        mask_ref_re = F.interpolate(mask_ref, size=old_gamma_matrix.shape[2:]).repeat(1, channels, 1, 1)
        gamma_ref_re = old_gamma_matrix.repeat(3, 1, 1, 1)
        old_gamma_matrix = gamma_ref_re * mask_ref_re  # (3, c, h, w)
        beta_ref_re = old_beta_matrix.repeat(3, 1, 1, 1)
        old_beta_matrix = beta_ref_re * mask_ref_re

        old_gamma_matrix = old_gamma_matrix.view(3, 1, -1)
        old_beta_matrix = old_beta_matrix.view(3, 1, -1)
        old_gamma_matrix = old_gamma_matrix.permute(0, 2, 1)
        old_beta_matrix = old_beta_matrix.permute(0, 2, 1)
        new_gamma_matrix = torch.bmm(attention_map.to_dense(), old_gamma_matrix)
        new_beta_matrix = torch.bmm(attention_map.to_dense(), old_beta_matrix)
        gamma = new_gamma_matrix.view(-1, 1, width, height)  # (3, c, h, w)
        beta = new_beta_matrix.view(-1, 1, width, height)

        gamma = (gamma[0] + gamma[1] + gamma[2]).unsqueeze(0)  # (c, h, w) combine the three parts
        beta = (beta[0] + beta[1] + beta[2]).unsqueeze(0)
        return gamma, beta
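A quick smoke test of the modified forward path on random tensors, as a minimal sketch: the 256x256 input size and the (3, 1, H, W) part-mask shapes are my assumptions, and the masks are only consumed by the older masked-attention path, so random placeholders are enough here.

import torch

if __name__ == '__main__':
    net = Generator(conv_dim=64)
    source = torch.randn(1, 3, 256, 256)
    reference = torch.randn(1, 3, 256, 256)
    # one single-channel map per facial part; unused by the pure-attention path
    mask_source = torch.rand(3, 1, 256, 256)
    mask_ref = torch.rand(3, 1, 256, 256)
    with torch.no_grad():
        result = net(source, reference, mask_source, mask_ref)
    print(result.shape)  # expected: torch.Size([1, 3, 256, 256])

The pure-attention path trades the facial-region constraint for a much smaller memory footprint: a single dense (H*W) x (H*W) map at the 64x64 feature resolution replaces the three per-region sparse maps of the original PSGAN computation.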
