前言

先看下效果,我实在没有拍过学校的照片,随便谷歌了一张,学校是哈尔滨理工大学荣成校区。

Github代码我已经开源在文末,环境我使用的是Colab pro,下载直接运行。(别忘了Star~)大家可以用好看的照片哦!

输出图像:

输入图像:

样式迁移

如果你是一位摄影爱好者,你也许接触过滤镜。它能改变照片的颜色样式,从而使风景照更加锐利或者令人像更加美白。但一个滤镜通常只能改变照片的某个方面。如果要照片达到理想中的样式,你可能需要尝试大量不同的组合。这个过程的复杂程度不亚于模型调参。

这里我们需要两张输入图像:一张是内容图像,另一张是样式图像。

我们将使用神经网络修改内容图像,使其在样式上接近样式图像。

例如,图像为本书作者在西雅图郊区的雷尼尔山国家公园拍摄的风景照,而样式图像则是一幅主题为秋天橡树的油画。

最终输出的合成图像应用了样式图像的油画笔触让整体颜色更加鲜艳,同时保留了内容图像中物体主体的形状。

2.1 方法

简单的例子阐述了基于卷积神经网络的样式迁移方法。

首先,我们初始化合成图像,例如将其初始化为内容图像。该合成图像是样式迁移过程中唯一需要更新的变量,即样式迁移所需迭代的模型参数。然后,我们选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新。

这个深度卷积神经网络凭借多个层逐级抽取图像的特征,我们可以选择其中某些层的输出作为内容特征或样式特征。

接下来,我们通过正向传播(实线箭头方向)计算样式迁移的损失函数,并通过反向传播(虚线箭头方向)迭代模型参数,即不断更新合成图像。

样式迁移常用的损失函数由3部分组成:

  1. 内容损失使合成图像与内容图像在内容特征上接近;

  2. 样式损失使合成图像与样式图像在样式特征上接近;

  3. 总变差损失则有助于减少合成图像中的噪点。

最后,当模型训练结束时,我们输出样式迁移的模型参数,即得到最终的合成图像。

这里选取的预训练的神经网络含有3个卷积层,其中第二层输出内容特征,第一层和第三层输出样式特征。

和无监督学习不是一个道理哈。

这句话很重要哦。

2.2 数据处理和网络实现

第一步,统一图像尺寸。

下面,定义图像的预处理函数和后处理函数。预处理函数preprocess对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入格式。后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。由于图像打印函数要求每个像素的浮点数值在0到1之间,我们对小于0和大于1的值分别取0和1。

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])def preprocess(img, image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])return transforms(img).unsqueeze(0)def postprocess(img):img = img[0].to(rgb_std.dnet = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])evice)img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

基于ImageNet数据集预训练的VGG-19模型来抽取图像特征

pretrained_net = torchvision.models.vgg19(pretrained=True)

使用VGG层抽取特征时,我们只需要用到从输入层到最靠近输出层的内容层或样式层之间的所有层。下面构建一个新的网络net,它只保留需要用到的VGG的所有层。

net = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])

下面定义两个函数:get_contents函数对内容图像抽取内容特征;get_styles函数对样式图像抽取样式特征。因为在训练时无须改变预训练的VGG的模型参数,所以我们可以在训练开始之前就提取出内容特征和样式特征。由于合成图像是样式迁移所需迭代的模型参数,我们只能在训练过程中通过调用extract_features函数来抽取合成图像的内容特征和样式特征。

def get_contents(image_shape, device):content_X = preprocess(content_img, image_shape).to(device)contents_Y, _ = extract_features(content_X, content_layers, style_layers)return content_X, contents_Ydef get_styles(image_shape, device):style_X = preprocess(style_img, image_shape).to(device)_, styles_Y = extract_features(style_X, content_layers, style_layers)return style_X, styles_Y

2.3 训练

在训练模型进行样式迁移时,我们不断抽取合成图像的内容特征和样式特征,然后计算损失函数。下面定义了训练循环。

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs],legend=['content', 'style', 'TV'],ncols=2, figsize=(7, 2.5))for epoch in range(num_epochs):trainer.zero_grad()contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers)contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)l.backward()trainer.step()scheduler.step()if (epoch + 1) % 10 == 0:animator.axes[1].imshow(postprocess(X))animator.add(epoch + 1, [float(sum(contents_l)),float(sum(styles_l)), float(tv_l)])return X

现在我们训练模型:首先将内容图像和样式图像的高和宽分别调整为300和450像素,用内容图像来初始化合成图像。

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

device, image_shape = d2l.try_gpu(), (300, 450)net = net.to(device)content_X, contents_Y = get_contents(image_shape, device)_, styles_Y = get_styles(image_shape, device)output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

  • 样式迁移常用的损失函数由3部分组成:(i) 内容损失使合成图像与内容图像在内容特征上接近;(ii) 样式损失令合成图像与样式图像在样式特征上接近;(iii) 总变差损失则有助于减少合成图像中的噪点。

  • 我们可以通过预训练的卷积神经网络来抽取图像的特征,并通过最小化损失函数来不断更新合成图像来作为模型参数。

  • 我们使用格拉姆矩阵表达样式层输出的样式。

开源代码

代码地址:https://github.com/lixiang007666/Style-Transfer-pytorch

运行style-pytorch.ipynb:

训练300个epoch结果:

epoch  10, content loss 25.22, style loss 3014.07, TV loss 1.16, 0.01 sec
epoch  20, content loss 29.34, style loss 740.11, TV loss 1.31, 0.00 sec
epoch  30, content loss 30.87, style loss 383.17, TV loss 1.36, 0.00 sec
epoch  40, content loss 31.51, style loss 250.63, TV loss 1.40, 0.01 sec
epoch  50, content loss 31.39, style loss 190.49, TV loss 1.45, 0.01 sec
epoch  60, content loss 30.82, style loss 152.23, TV loss 1.46, 0.01 sec
epoch  70, content loss 29.83, style loss 124.40, TV loss 1.49, 0.01 sec
epoch  80, content loss 29.00, style loss 108.24, TV loss 1.50, 0.01 sec
epoch  90, content loss 28.27, style loss 92.64, TV loss 1.52, 0.01 sec
epoch 100, content loss 27.65, style loss 82.47, TV loss 1.53, 0.00 sec
epoch 110, content loss 27.15, style loss 73.10, TV loss 1.54, 0.01 sec
epoch 120, content loss 26.44, style loss 65.02, TV loss 1.56, 0.01 sec
epoch 130, content loss 25.90, style loss 58.60, TV loss 1.57, 0.01 sec
epoch 140, content loss 25.44, style loss 53.61, TV loss 1.58, 0.01 sec
epoch 150, content loss 24.98, style loss 49.11, TV loss 1.59, 0.00 sec
epoch 160, content loss 24.60, style loss 45.28, TV loss 1.60, 0.01 sec
epoch 170, content loss 24.11, style loss 42.02, TV loss 1.61, 0.01 sec
epoch 180, content loss 23.78, style loss 39.58, TV loss 1.61, 0.01 sec
epoch 190, content loss 23.41, style loss 37.26, TV loss 1.62, 0.01 sec
epoch 200, content loss 23.05, style loss 35.32, TV loss 1.62, 0.00 sec
epoch 210, content loss 22.81, style loss 33.80, TV loss 1.62, 0.00 sec
epoch 220, content loss 22.49, style loss 32.43, TV loss 1.62, 0.00 sec
epoch 230, content loss 22.19, style loss 31.25, TV loss 1.62, 0.01 sec
epoch 240, content loss 21.94, style loss 29.98, TV loss 1.62, 0.00 sec
epoch 250, content loss 21.65, style loss 28.75, TV loss 1.62, 0.00 sec
epoch 260, content loss 21.44, style loss 27.63, TV loss 1.62, 0.01 sec
epoch 270, content loss 21.19, style loss 26.77, TV loss 1.62, 0.01 sec
epoch 280, content loss 20.97, style loss 25.81, TV loss 1.62, 0.01 sec
epoch 290, content loss 20.81, style loss 24.97, TV loss 1.62, 0.01 sec
epoch 300, content loss 20.57, style loss 24.25, TV loss 1.62, 0.01 sec

参考

[1].https://zh-v2.d2l.ai/index.html

注:本文仅代表作者个人观点。如有不同看法,欢迎留言反馈/讨论。

‍‍‍‍‍‍‍作者:李响Superb,CSDN百万访问量博主,普普通通男大学生,深度学习算法、医学图像处理专攻,偶尔也搞全栈开发,没事就写文章。

博客地址:lixiang.blog.csdn.net

‍‍‍‍‍‍‍

—End—


往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑黄海广老师《机器学习课程》视频课
本站qq群851320808,加入微信群请扫码:

【深度学习】用Pytorch给你的母校做一个样式迁移吧!相关推荐

  1. 用Pytorch给你的母校做一个样式迁移吧!

    个人简介:李响Superb,CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章. 博客地址:lixiang.blog.csdn.net 文章目录 ...

  2. 深度学习准「研究僧」预习资料:图灵奖得主Yann LeCun《深度学习(Pytorch)》春季课程...

    视学算法报道 编辑:蛋酱 转载自公众号:机器之心 开学进入倒计时,深度学习方向的准「研究僧」们,你们准备好了吗? 转眼 2020 年已经过半,又一届深度学习方向的准研究生即将踏上「炼丹」之路.对于这一 ...

  3. 深度学习之PyTorch物体检测

    深度学习之PyTorch物体检测 董洪义 著 ISBN:9787111641742 包装:平装 开本:16开 用纸:胶版纸 出版社:机械工业出版社 出版时间:2020-01-01

  4. 深度学习框架PyTorch快速开发与实战

    深度学习框架PyTorch快速开发与实战 邢梦来,王硕,孙洋洋 著 ISBN:9787121345647 包装:平装 开本:16开 用纸:胶版纸 正文语种:中文 出版社:电子工业出版社 出版时间:20 ...

  5. 《动手学深度学习》PyTorch版GitHub资源

    之前,偶然间看到过这个PyTorch版<动手学深度学习>,当时留意了一下,后来,着手学习pytorch,发现找不到这个资源了.今天又看到了,赶紧保存下来. <动手学深度学习>P ...

  6. 【深度学习】Pytorch的深度神经网络剪枝应用

    [深度学习]Pytorch的深度神经网络剪枝应用 文章目录 1 概述 2 pytorch基于卷积层通道剪枝的方法 3 模型剪枝:Learning Efficient Convolutional Net ...

  7. Pytorch:深度学习中pytorch/torchvision版本和CUDA版本最正确版本匹配、对应版本安装之详细攻略

    Pytorch:深度学习中pytorch/torchvision版本和CUDA版本最正确版本匹配.对应版本安装之详细攻略 目录 深度学习中pytorch/torchvision版本和CUDA版本最正确 ...

  8. DL:深度学习框架Pytorch、 Tensorflow各种角度对比

    DL:深度学习框架Pytorch. Tensorflow各种角度对比 目录 先看两个框架实现同样功能的代码 1.Pytorch.Tensorflow代码比较 2.Tensorflow(数据即是代码,代 ...

  9. DL框架之PyTorch:深度学习框架PyTorch的简介、安装、使用方法之详细攻略

    DL框架之PyTorch:PyTorch的简介.安装.使用方法之详细攻略 DL框架之PyTorch:深度学习框架PyTorch的简介.安装.使用方法之详细攻略 目录 PyTorch的简介 1.pyto ...

最新文章

  1. 【 Verilog HDL 】HDL的三种描述方式
  2. 关于read的例子和条件测试
  3. jmeter操作练习
  4. Java正则表达式简单用法
  5. unix linux 命令参考,Unix/Linux 命令参考
  6. 使用sikuli和Arquillian测试HTML5 canvas应用程序
  7. 到达什么水平才能算是学会了数学?
  8. C++ opengl 对OpenGL中矩阵设置的初步认识
  9. mysql ip 访问_MySql通过ip地址进行访问的方法
  10. java获取不到ipv6的网卡
  11. 麻烦缠身的高通“向前看”:关注服务器市场和5G
  12. t检验、t分布、t值
  13. 安装Aras Innovator12 sp9全过程
  14. 鸡兔同笼——算法详解
  15. 网络带宽相关知识和计算
  16. python in arcgis_终于晓得arcgis-python入门教程
  17. 如何设计可靠性UDP传输协议?
  18. 真正的高手,都有对抗“熵增”的底层思维
  19. 配置Logback日志
  20. 百年孤独——雪融化的时刻之达人论战

热门文章

  1. BZOJ4653 尺取法 + 线段树
  2. elk集群配置配置文件中节点数配多少
  3. layui富文本编译器添加图片
  4. oracle字符集查看、修改、版本查看
  5. Inherits、CodeFile、CodeBehind
  6. 关于Jquery中ajax方法data参数用法的总结
  7. C++虚函数表解析(转) ——写的真不错
  8. 想要一篇高分SCI,这些临床统计的诀窍你要知道
  9. pycharm中无法安装scipy、imread、GDAL等库
  10. Vue-Cli 学习整理【转载】