作者 | 李秋键

责编 | Carol

出品 | AI科技大本营(ID:rgznai100)

近几天一个GitHub项目火遍了朋友圈,那就是卡通头像AI生成小程序。如下图所见:

而这个项目的基本原理是用python搭建的GAN算法模型,进行训练得出。

而所谓的GAN就是指生成对抗网络深度学习模型。网络中有生成器G(generator)和鉴别器(Discriminator)。有两个数据域分别为X,Y。G 负责把X域中的数据拿过来拼命地模仿成真实数据并把它们藏在真实数据中,而 D 就拼命地要把伪造数据和真实数据分开。经过二者的博弈以后,G 的伪造技术越来越厉害,D 的鉴别技术也越来越厉害。直到 D 再也分不出数据是真实的还是 G 生成的数据的时候,这个对抗的过程达到一个动态的平衡。

而CycleGAN本质上是两个镜像对称的GAN,构成了一个环形网络。

两个GAN共享两个生成器,并各自带一个判别器,即共有两个判别器和两个生成器。一个单向GAN两个loss,两个即共四个loss。

可以实现无配对的两个图片集的训练是CycleGAN与Pixel2Pixel相比的一个典型优点。但是我们仍然需要通过训练创建这个映射来确保输入图像和生成图像间存在有意义的关联,即输入输出共享一些特征。

简而言之,该模型通过从域DA获取输入图像,该输入图像被传递到第一个生成器GeneratorA→B,其任务是将来自域DA的给定图像转换到目标域DB中的图像。然后这个新生成的图像被传递到另一个生成器GeneratorB→A,其任务是在原始域DA转换回图像,这里可与自动编码器作对比。这个输出图像必须与原始输入图像相似,用来定义非配对数据集中原来不存在的有意义映射。

在本次的项目中就是利用了CycleGAN进行搭建模型。模型训练数据集如下:

实验前的准备

首先我们使用的python版本是3.6.5所用到的库有pytorch和TensorFlow,用来训练和加载神经网络常见的框架;face-alignment用来是用来提取人脸特征的常用库;

dlib是一个机器学习的开源库,包含了机器学习的很多算法,使用起来很方便,直接包含头文件即可,并且不依赖于其他库(自带图像编解码库源码)。Dlib可以帮助您创建很多复杂的机器学习方面的软件来帮助解决实际问题。目前Dlib已经被广泛的用在行业和学术领域,包括机器人,嵌入式设备,移动电话和大型高性能计算环境。

模型的训练

1、数据集处理和准备:

训练数据包括真实照片和卡通画像,为降低训练复杂度,我们对两类数据进行了如下预处理:

· 检测人脸及关键点。

· 根据关键点旋转校正人脸。

· 将关键点边界框按固定的比例扩张并裁剪出人脸区域。

· 使用人像分割模型将背景置白。

为了形成匹配效果,需要准备一些卡通人物图片和真实的人脸图片进行训练

2、模型的训练:

模型的训练使用python train.py --dataset photo2cartoon进行训练即可。

3、神经网络结构搭建:

整个算法的搭建正如上面可见,需要有生成器和判别器。使用论文提出的一种Soft-AdaLIN(Soft Adaptive Layer-Instance Normalization)归一化方法,在反规范化时将编码器的均值方差(照片特征)与解码器的均值方差(卡通特征)相融合。

模型结构方面,在U-GAT-IT的基础上,在编码器之前和解码器之后各增加了2个hourglass模块,渐进地提升模型特征抽象和重建能力。

部分代码如下:

class ResnetGenerator(nn.Module):def __init__(self, ngf=64, img_size=256, light=False):super(ResnetGenerator, self).__init__()self.light = lightself.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d(3),nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),nn.InstanceNorm2d(ngf),nn.ReLU(True))self.HourGlass1 = HourGlass(ngf, ngf)self.HourGlass2 = HourGlass(ngf, ngf)# Down-Samplingself.DownBlock1 = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=0, bias=False),nn.InstanceNorm2d(ngf * 2),nn.ReLU(True))self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=0, bias=False),nn.InstanceNorm2d(ngf*4),nn.ReLU(True))# Encoder Bottleneckself.EncodeBlock1 = ResnetBlock(ngf*4)self.EncodeBlock2 = ResnetBlock(ngf*4)self.EncodeBlock3 = ResnetBlock(ngf*4)self.EncodeBlock4 = ResnetBlock(ngf*4)# Class Activation Mapself.gap_fc = nn.Linear(ngf*4, 1)self.gmp_fc = nn.Linear(ngf*4, 1)self.conv1x1 = nn.Conv2d(ngf*8, ngf*4, kernel_size=1, stride=1)self.relu = nn.ReLU(True)# Gamma, Beta blockif self.light:self.FC = nn.Sequential(nn.Linear(ngf*4, ngf*4),nn.ReLU(True),nn.Linear(ngf*4, ngf*4),nn.ReLU(True))else:self.FC = nn.Sequential(nn.Linear(img_size//4*img_size//4*ngf*4, ngf*4),nn.ReLU(True),nn.Linear(ngf*4, ngf*4),nn.ReLU(True))# Decoder Bottleneckself.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf*4)self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf*4)self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf*4)self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf*4)# Up-Samplingself.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor=2),nn.ReflectionPad2d(1),nn.Conv2d(ngf*4, ngf*2, kernel_size=3, stride=1, padding=0, bias=False),LIN(ngf*2),nn.ReLU(True))self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor=2),nn.ReflectionPad2d(1),nn.Conv2d(ngf*2, ngf, kernel_size=3, stride=1, padding=0, bias=False),LIN(ngf),nn.ReLU(True))self.HourGlass3 = HourGlass(ngf, ngf)self.HourGlass4 = HourGlass(ngf, ngf, False)self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d(3),nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),nn.Tanh())def forward(self, x):x = self.ConvBlock1(x)x = self.HourGlass1(x)x = self.HourGlass2(x)x = self.DownBlock1(x)x = self.DownBlock2(x)x = self.EncodeBlock1(x)content_features1 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)x = self.EncodeBlock2(x)content_features2 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)x = self.EncodeBlock3(x)content_features3 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)x = self.EncodeBlock4(x)content_features4 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)gap = F.adaptive_avg_pool2d(x, 1)gap_logit = self.gap_fc(gap.view(x.shape[0], -1))gap_weight = list(self.gap_fc.parameters())[0]gap = x * gap_weight.unsqueeze(2).unsqueeze(3)gmp = F.adaptive_max_pool2d(x, 1)gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))gmp_weight = list(self.gmp_fc.parameters())[0]gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)cam_logit = torch.cat([gap_logit, gmp_logit], 1)x = torch.cat([gap, gmp], 1)x = self.relu(self.conv1x1(x))heatmap = torch.sum(x, dim=1, keepdim=True)if self.light:x_ = F.adaptive_avg_pool2d(x, 1)style_features = self.FC(x_.view(x_.shape[0], -1))else:style_features = self.FC(x.view(x.shape[0], -1))x = self.DecodeBlock1(x, content_features4, style_features)x = self.DecodeBlock2(x, content_features3, style_features)x = self.DecodeBlock3(x, content_features2, style_features)x = self.DecodeBlock4(x, content_features1, style_features)x = self.UpBlock1(x)x = self.UpBlock2(x)x = self.HourGlass3(x)x = self.HourGlass4(x)out = self.ConvBlock2(x)return out, cam_logit, heatmap

4、提取人脸特征:

为了提取人脸特征以达到加载到网络中的目的,我们需要正确框出人脸同时计算特征距离,以方便后面训练模型师损失函数的调用。

代码如下:

class FaceFeatures(object):def __init__(self, weights_path, device):self.device = deviceself.model = MobileFaceNet(512).to(device)self.model.load_state_dict(torch.load(weights_path))self.model.eval()def infer(self, batch_tensor):# crop faceh, w = batch_tensor.shape[2:]top = int(h / 2.1 * (0.8 - 0.33))bottom = int(h - (h / 2.1 * 0.3))size = bottom - topleft = int(w / 2 - size / 2)right = left + sizebatch_tensor = batch_tensor[:, :, top: bottom, left: right]batch_tensor = F.interpolate(batch_tensor, size=[112, 112], mode='bilinear', align_corners=True)features = self.model(batch_tensor)return featuresdef cosine_distance(self, batch_tensor1, batch_tensor2):feature1 = self.infer(batch_tensor1)feature2 = self.infer(batch_tensor2)return 1 - torch.cosine_similarity(feature1, feature2)

模型测试

在训练好模型后,我们使用python test.py --photo_path ./images/1.jpg --save_path ./images/2.png测试生成图片。其中1.jpg是原始图片,最终会生成2.jpg图片。

使用python data_process.py --data_path YourPhotoFolderPath --save_path YourSaveFolderPath批量生成

1、调用模型:

调用模型首先要使用torch进行加载模型,读取神经网络参数。在对原始图片提取人脸特征的基础上,加载进网络进行生成即可。因为这里我们还需要对生成的数据进行转换成图片,我们这里还需要使用numpy和opencv进行图片的转化。因为加载如模型和模型生成的必然是数据,而我们需要将生成器产生的数据再转换为图片,就用到了这两个库。

代码如下:

class Photo2Cartoon:def __init__(self):self.pre = Preprocess()self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")self.net = ResnetGenerator(ngf=32, img_size=256, light=True).to(self.device)params = torch.load('./models/photo2cartoon_weights.pt', map_location=self.device)self.net.load_state_dict(params['genA2B'])def inference(self, img):# face alignment and segmentationface_rgba = self.pre.process(img)if face_rgba is None:print('can not detect face!!!')return Noneface_rgba = cv2.resize(face_rgba, (256, 256), interpolation=cv2.INTER_AREA)face = face_rgba[:, :, :3].copy()mask = face_rgba[:, :, 3][:, :, np.newaxis].copy() / 255.face = (face*mask + (1-mask)*255) / 127.5 - 1face = np.transpose(face[np.newaxis, :, :, :], (0, 3, 1, 2)).astype(np.float32)face = torch.from_numpy(face).to(self.device)# inferencewith torch.no_grad():cartoon = self.net(face)[0][0]# post-processcartoon = np.transpose(cartoon.cpu().numpy(), (1, 2, 0))cartoon = (cartoon + 1) * 127.5cartoon = (cartoon * mask + 255 * (1 - mask)).astype(np.uint8)cartoon = cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)return cartoon
if __name__ == '__main__':img = cv2.cvtColor(cv2.imread(args.photo_path), cv2.COLOR_BGR2RGB)c2p = Photo2Cartoon()cartoon = c2p.inference(img)if cartoon is not None:cv2.imwrite(args.save_path, cartoon)

到这里,我们整体的程序就搭建完成,下面为我们程序的运行结果:

在这里附上源码地址:

链接:https://pan.baidu.com/s/1jYVt8T0IPqpYmuNIRyvNGg

提取码:54vp

作者简介

李秋键,CSDN 博客专家,CSDN达人课作者。硕士在读于中国矿业大学,开发有taptap安卓武侠游戏一部,vip视频解析,文意转换工具,写作机器人等项目,发表论文若干,多次高数竞赛获奖等等。

推荐阅读

  • 干货 | 时间序列预测类问题下的建模方案探索实践

  • 饿了么交易系统 5 年演化史

  • 360金融首席科学家张家兴:别指望AI Lab做成中台

  • 十六位顶尖专家齐聚,解密阿里云最新核心技术竞争力!

  • 利用 Docker 在不同宿主机做 CentOS 系统容器 | 原力计划

  • 从技术原理解析区块链为何列入新基建

  • 你点的每个“在看”,我都认真当成了AI

CycleGan人脸转为漫画脸,牛掰的知识又增加了!| 附代码相关推荐

  1. 人脸变漫画脸!AI 教你轻松 Pick 胡歌漫画头像

    作者 | 李秋键 责编 | Carol 出品 | AI科技大本营(ID:rgznai100) 近几天一个GitHub项目火遍了朋友圈,那就是卡通头像AI生成小程序.如下图所见: 而这个项目的基本原理是 ...

  2. 怎样P漫画脸?这三个简单方法分享给你

    大家在看朋友圈或者一些社交平台时,有没有看到有人分享了一些自己漫画脸的照片,不知道你们是否和我一样看到这些有趣的照片时,也会产生一点好奇心,想知道他们是怎么制作出来的.后来自己研究后发现,其实我们也可 ...

  3. 刷新纪录,揭秘漫画脸背后的AI技术

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要8分钟 Follow小博主,每天更新前沿干货 作者 | 贝爽,本文转自雷锋网 昨晚做了一个梦,梦里的我变成漫画里的人物,正在为参与选秀苦练舞蹈 ...

  4. 漫画脸特效工具有什么?这些软件值得收藏

    相信小伙伴们在社交平台刷动态的时候,经常会看到不少人的头像是使用漫画脸吧.你们会不会也想要使用这样的漫画脸头像呢?有些动手能力强的小伙伴,可以自己画出这些漫画脸头像,.画不出来的小伙伴也不用担心,其实 ...

  5. 你知道漫画脸怎么生成的吗

    随着现在信息科技发展迅速,大家对漫画的喜欢程度有增加了,许多爱好者都有对漫画世界的向往,会对自己在漫画世界的形象产生好奇.那么有什么软件能实现漫画脸转换呢?漫画脸怎么生成呢?以下软件能解决你所需的问题 ...

  6. 有“肌肉”有“血管”!波兰团队耗时5年研发超逼真仿生机械臂,网友:很怪异也很牛掰...

    来源:大数据文摘 本文约1800字,建议阅读5分钟 本文为你介绍波兰团队研发的真仿生机械臂. 1970年,日本机器人专家森政弘提出,一个看上去和人类无异的实体可能会让人们产生一种寒冷和怪异的感觉. 那 ...

  7. 想非常牛掰的在WORD中调中多级符号吗

    今天从同事那学了一招,简单.快捷还牛掰! 方法是:1.你要先调出一个一级符号,然后在把光标指向这个一级符号上并单击选定 2.华丽的按下"TAB"键. 3.完工了! 你会发现第二级符 ...

  8. 为什么公司的HR这么牛掰

    前言 作为开发岗的程序员,我们在公司印象中最深的行政人员恐怕就是HR了,毕竟其他的公司行政人员,像那种财务或者采购之类,可能根本就接触不到,唯一除外的就是前台小姐姐. 网上投简历的时候,给你聊天打电话 ...

  9. java json解码器_Jackson:我是最牛掰的 Java JSON 解析器(有点虚)

    在当今的编程世界里,JSON 已经成为将信息从客户端传输到服务器端的首选协议,可以好不夸张的说,XML 就是那个被拍死在沙滩上的前浪. 很不幸的是,JDK 没有 JSON 库,不知道为什么不搞一下.L ...

最新文章

  1. 微服务之consul(一) - 诗码者 - 博客园
  2. 00002-两数之和-leetcode-1.暴力法(枚举法),2.哈希表法,目前更新了枚举法
  3. 【Java】基于IDE的JUnit软件测试入门
  4. 【蓝桥杯单片机】定时器和中断
  5. net4.0 程序没反应_笔记本触摸板没反应原因 笔记本触摸板没反应解决方法【详解】...
  6. ubuntu Mendeley Desktop 安装
  7. POJ3581 后缀数组
  8. ST电机库5.0完全开源对电机控制软件工程师有何影响?
  9. win7计算机管理没有用户模块,Win7系统下安装ipx协议提示找不到相应的模块如何解决...
  10. ShuffleNet原理
  11. ubuntu中文智能拼音输入法配置
  12. 什么是SAAS模式网站?
  13. 大数据创造智慧城市的未来之光!
  14. php doctrine datetime,php – doctrine和Symfony 2中的DateTime字段
  15. 【身体这些部位不舒服的时候,你知道意味着什么吗?】
  16. MobIM 使用总结
  17. 点石成金:“硅业报国”不仅是理念
  18. linux中execve函数的用法
  19. 分享:提升你工作幸福感的11个工具软件
  20. python读取idx_通过Python从.idx3-ubyte文件或GZIP中提取图像 - python

热门文章

  1. 无线网卡+kali实现wifi暴力破解(密码爆破方式)
  2. batT脚本如何自动执行 adb shell 以后的命令(android抓包)
  3. 全球变暖迫在眉睫碳中和势在必行 碳森羿建议提前布局
  4. 创新型中小企业认定评定标准
  5. android小错误:Failure retrieving text 0x7f050001 in package
  6. Python 之 列表推导式
  7. 图像相似度匹配——距离大全
  8. Java中的数值计算
  9. Java 输入汉字姓名 输出 姓名拼音 首字母缩写组合
  10. 机器学习基本概念知识汇