易于使用的神经风格迁移框架 pystiche。

将内容图片与艺术风格图片进行融合,生成一张具有特定风格的新图,这种想法并不新鲜。早在 2015 年,Gatys、 Ecker 以及 Bethge 开创性地提出了神经风格迁移(Neural Style Transfer ,NST)。

不同于深度学习,目前 NST 还没有现成的库或框架。因此,新的 NST 技术要么从头开始实现所有内容,要么基于现有的方法实现。但这两种方法都有各自的缺点:前者由于可重用部分的冗长实现,限制了技术创新;后者继承了 DL 硬件和软件快速发展导致的技术债务。

最近,新项目 pystiche 很好地解决了这些问题,虽然它的核心受众是研究人员,但其易于使用的用户界面为非专业人员使用 NST 提供了可能。

pystiche 是一个用 Python 编写的 NST 框架,基于 PyTorch 构建,并与之完全兼容。相关研究由 pyOpenSci 进行同行评审,并发表在 JOSS 期刊 (Journal of Open Source Software) 上。

  • 论文地址:https://joss.theoj.org/papers/10.21105/joss.02761

  • 项目地址:https://github.com/pmeier/pystiche

在深入实现之前,我们先来回顾一下 NST 的原理。它有两种优化方式:基于图像的优化和基于模型的优化。虽然 pystiche 能够很好地处理后者,但更为复杂,因此本文只讨论基于图像的优化方法。

在基于图像的方法中,将图像的像素迭代调整训练,来拟合感知损失函数(perceptual loss)。感知损失是 NST 的核心部分,分为内容损失(content loss)和风格损失(style loss),这些损失评估输出图像与目标图像的匹配程度。与传统的风格迁移算法不同,感知损失包含一个称为编码器的多层模型,这就是 pystiche 基于 PyTorch 构建的原因。

如何使用 pystiche

让我们用一个例子介绍怎么使用 pystiche 生成神经风格迁移图片。首先导入所需模块,选择处理设备。虽然 pystiche 的设计与设备无关,但使用 GPU 可以将 NST 的速度提高几个数量级。

模块导入与设备选择:

import torchimport pystichefrom pystiche import demo, enc, loss, ops, optim
print(f"pystiche=={pystiche.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

输出:

pystiche==0.7.0

多层编码器

content_loss 和 style_loss 是对图像编码进行操作而不是图像本身,这些编码是由在不同层级的预训练编码器生成的。pystiche 定义了 enc.MultiLayerEncoder 类,该类在单个前向传递中可以有效地处理编码问题。该示例使用基于 VGG19 架构的 vgg19_multi_layer_encoder。默认情况下,它将加载 torchvision 提供的权重。

多层编码器:

multi_layer_encoder = enc.vgg19_multi_layer_encoder()print(multi_layer_encoder)

输出:

VGGMultiLayerEncoder(  arch=vgg19, framework=torch, allow_inplace=True  (preprocessing): TorchPreprocessing(   (0): Normalize(     mean=('0.485', '0.456', '0.406'),     std=('0.229', '0.224', '0.225')    )  ) (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu1_1): ReLU(inplace=True) (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu1_2): ReLU(inplace=True) (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu2_1): ReLU(inplace=True) (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu2_2): ReLU(inplace=True) (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu3_1): ReLU(inplace=True) (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu3_2): ReLU(inplace=True) (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu3_3): ReLU(inplace=True) (conv3_4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu3_4): ReLU(inplace=True) (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu4_1): ReLU(inplace=True) (conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu4_2): ReLU(inplace=True) (conv4_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu4_3): ReLU(inplace=True) (conv4_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu4_4): ReLU(inplace=True) (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu5_1): ReLU(inplace=True) (conv5_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu5_2): ReLU(inplace=True) (conv5_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu5_3): ReLU(inplace=True) (conv5_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (relu5_4): ReLU(inplace=True) (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))

感知损失

pystiche 将内容损失和风格损失定义为操作符。使用 ops.FeatureReconstructionOperator 作为 content_loss,直接与编码进行对比。如果编码器针对分类任务进行过训练,如该示例中这些编码表示内容。对于content_layer,选择 multi_layer_encoder 的较深层来获取抽象的内容表示,而不是许多不必要的细节。​​​​​​​

content_layer = "relu4_2"encoder = multi_layer_encoder.extract_encoder(content_layer)content_loss = ops.FeatureReconstructionOperator(encoder)

pystiche 使用 ops.GramOperator 作为 style_loss 的基础,通过比较编码各个通道之间的相关性来丢弃空间信息。这样就可以在输出图像中的任意区域合成风格元素,而不仅仅是风格图像中它们所在的位置。对于 ops.GramOperator,如果它在浅层和深层 style_layers 都能很好地运行,则其性能达到最佳。

style_weight 可以控制模型对输出图像的重点——内容或风格。为了方便起见,pystiche 将所有内容包装在 ops.MultiLayerEncodingOperator 中,该操作处理在同一 multi_layer_encoder 的多个层上进行操作的相同类型操作符的情况。​​​​​​​

style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")style_weight = 1e3def get_encoding_op(encoder, layer_weight):    return ops.GramOperator(encoder, score_weight=layer_weight)style_loss = ops.MultiLayerEncodingOperator(    multi_layer_encoder, style_layers, get_encoding_op, score_weight=style_weight,)

loss.PerceptualLoss 结合了 content_loss 与 style_loss,将作为优化的标准。

criterion = loss.PerceptualLoss(content_loss, style_loss).to(device)print(criterion)

输出:​​​​​​​

PerceptualLoss( (content_loss): FeatureReconstructionOperator(   score_weight=1,   encoder=VGGMultiLayerEncoder(     layer=relu4_2,     arch=vgg19,     framework=torch,     allow_inplace=True   ) ) (style_loss): MultiLayerEncodingOperator(   encoder=VGGMultiLayerEncoder(     arch=vgg19,     framework=torch,     allow_inplace=True ), score_weight=1000 (relu1_1): GramOperator(score_weight=0.2) (relu2_1): GramOperator(score_weight=0.2) (relu3_1): GramOperator(score_weight=0.2) (relu4_1): GramOperator(score_weight=0.2) (relu5_1): GramOperator(score_weight=0.2) ))

图像加载

首先加载并显在 NST 需要的目标图片。因为 NST 占用内存较多,故将图像大小调整为 500 像素。​​​​​​​

size = 500images = demo.images()​​​​​​
content_image = images["bird1"].read(size=size, device=device)criterion.set_content_image(content_image)

内容图片​​​​​​​

style_image = images["paint"].read(size=size, device=device)criterion.set_style_image(style_image)

风格图片

神经风格迁移

创建 input_image。从 content_image 开始执行 NST,这样可以实现快速收敛。image_optimization 函数是为了方便,也可以由手动优化循环代替,且不受限制。如果没有指定,则使用 torch.optim.LBFGS 作为优化器。​​​​​​​

input_image = content_image.clone()output_image = optim.image_optimization(input_image, criterion, num_steps=500)

图像风格迁移也有框架了:使用Python编写,与PyTorch完美兼容,外行也能用相关推荐

  1. 图像迁移风格保存模型_图像风格迁移也有框架了:使用Python编写,与PyTorch完美兼容,外行也能用...

    原标题:图像风格迁移也有框架了:使用Python编写,与PyTorch完美兼容,外行也能用 选自Medium 作者:Philip Meier 机器之心编译 编辑:陈萍 易于使用的神经风格迁移框架 py ...

  2. 图像风格迁移也有框架了

    选自Medium 作者:Philip Meier 机器之心编译 编辑:陈萍 易于使用的神经风格迁移框架 pystiche. 将内容图片与艺术风格图片进行融合,生成一张具有特定风格的新图,这种想法并不新 ...

  3. cnn风格迁移_快速图像风格迁移思想在无线通信中的另类应用:算法拟合

    在本文中,并不是介绍最新的一些论文,而是回顾自己在很早(半年前?)读过的几篇文章.[1]Learning to optimize: Training deep neural networks for ...

  4. 图像风格迁移_图像风格迁移—谷歌大脑团队任意图像风格化迁移论文详解

    点击蓝字关注我们 AI研习图书馆,发现不一样的世界 风格迁移 图像风格化迁移是一个很有意思的研究领域,它可以将一张图的风格迁移到另外一张图像上,由此还诞生了Prisma和Ostagram这样的商业化产 ...

  5. 【人工智能专题】基于 GAN 的艺术风格化——图像风格迁移

    原文:https://mp.weixin.qq.com/s?__biz=MzAxMzEwMDM2Mg==&mid=2652847175&idx=3&sn=51dcb41bc5c ...

  6. 图像风格迁移做了一件文化衫-【布尔艺数】

    互联网人的夏天 一定少不了件又潮又酷的文化衫. 既要潮又要酷!这可难坏了设计小伙伴- 赶紧召集大家一起出点子! 一番讨论后,大家一致认为: Hinton !是业界最潮最酷的人! Geoff Hinto ...

  7. java图像风格迁移_Python+OpenCV 图像风格迁移(模仿名画)

    现在很多人都喜欢拍照(自拍).有限的滤镜和装饰玩多了也会腻,所以就有 APP 提供了模仿名画风格的功能,比如 prisma.versa 等,可以把你的照片变成 梵高.毕加索.蒙克 等大师的风格. 这种 ...

  8. Pytorch实现图像风格迁移(一)

    图像风格迁移是图像纹理迁移研究的进一步拓展,可以理解为针对一张风格图像和一张内容图像,通过将风格图像的风格添加到内容图像上,从而对内容图像进行进一步创作,获得具有不同风格的目标图像.基于深度学习网络的 ...

  9. Pix2Pix——基于GAN的图像风格迁移模型

    Pix2Pix--基于GAN的图像风格迁移模型 写在前面 本文是文献Image-to-image translation with conditional adversarial networks的笔 ...

最新文章

  1. SharePoint 2013 表单认证使用ASP.Net配置工具添加用户
  2. Validator(二)自定义
  3. 操作系统原理第六章:进程同步
  4. R语言实战应用精讲50篇(二十七)-时空数据分析-经验空间/时间均值(latex公式+R代码绘图)
  5. python处理字符串效率_Python字符串搜索效率
  6. try catch finally 关闭流标准的写法
  7. Java路径问题最终解决方案—可定位所有资源的相对路径寻址
  8. python 捕获鼠标点击事件,在Python中的wx.Frame外部捕获鼠标事件
  9. (33)Verilog HDL缩减运算
  10. 自己手写一个Spring MVC框架
  11. UltraEdit搭建python IDE环境+设置快捷键
  12. kibana报错:No default index pattern. You must select or create one to continue.
  13. 收藏! | 入门必读:计算机视觉四大基本任务(分类、定位、检测、分割)
  14. 跨服务器导入数据sql
  15. 设计模式 AOP 面向切入编程
  16. django、tornado、flask对比
  17. JSTL不同版本和EL表达式的关联
  18. 实现音乐播放器歌词显示效果
  19. 在 COMSOL 中模拟地震波的传播
  20. 简单图文解释冯诺依曼体系结构(通俗易懂版)

热门文章

  1. 一文揭示DisCO的内幕,区块链技术才是创造未来价值的最好选择?
  2. 如何处理阿里云ECS服务器提示存在漏洞
  3. 关于《流浪地球》中的春节12响(Spring 12 biu)——C#.Net
  4. IE浏览器突然无法打开
  5. 最新Zblog博客微信小程序源码全开源完整版+带教程
  6. 处理器有k和无k有什么区别?
  7. “老票证”述时代变迁:从凭“票”买到“任意”购
  8. Java使用ffmpeg实现视频剪切、mp3剪切
  9. matlab怎么输入sin,在 MATLAB命令窗口中输入:a=sin(pi/7);save a,则下列论述中正确的是:...
  10. 商业级手术麻醉系统源码,术前分析、用药、评级,术后访视、麻醉科室管理、数据统计分析