图像风格迁移其实非常好理解,就是将一张图像的“风格”(风格图像)迁移至另外一张图像(内容图像),但是这所谓的另外一张图像只是在“风格”上与之前有所不同,图像的“内容”仍要与之前相同。Luan et al. and Gatys et al. 的工作都是利用VGGNet19作为该项任务的backbone,由于VGGNet19是一种近似“金字塔”型结构,所以随着卷积操作的加深,feature maps的感受野越来越大,提取到的图像特征从局部扩展到了全局。我们为了避免合成的图像过多地保留内容信息,选取VGGNet19中位于金字塔顶部的卷积层作为内容层。整个训练过程为将生成图像初始化为内容图像,每次循环分别抽取生成图像和内容图像的内容特征,计算mse并且使之最小化,同时抽取生成图像和风格图像的样式特征,计算mse并且使之最小化。这里注意损失函数的写法:

总损失由两部分组成:内容损失和样式损失。内容损失即为生成图像和内容图像对应特征图的均方误差,但是样式损失需要分别计算生成图像和内容图像的格拉姆矩阵再做均方误差。另外,α\alphaα和β\betaβ分别为内容损失和样式损失的各项权重,Γ\GammaΓ为样式损失的惩罚系数。我通过实发现β\betaβ和Γ\GammaΓ应该取的值大些,使得样式损失被尽可能地“惩罚”,即“放大”样式损失。

import torch
import numpy as np
from PIL import Image
from torchvision.models import vgg19
from torchvision.transforms import transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.nn.functional import mse_loss
from torch.autograd import Variabledevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 预处理:大小裁剪、转为张量、归一化
def preprocess(img_shape):transform = transforms.Compose([transforms.Resize(img_shape),transforms.ToTensor(),transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])return transformclass VGGNet19(nn.Module):def __init__(self):super(VGGNet19, self).__init__()self.vggnet19 = vgg19(pretrained=False)self.vggnet19.load_state_dict(torch.load('./vgg19-dcbb9e9d.pth'))self.content_layers = [25]self.style_layers = [0, 5, 10, 19, 28]def forward(self, x):content_features = []style_features = []for name, module in self.vggnet19.features._modules.items():x = module(x)if int(name) in self.content_layers:content_features.append(x)if int(name) in self.style_layers:style_features.append(x)return content_features, style_featuresclass GenerateImage(nn.Module):def __init__(self, img_shape):super(GenerateImage, self).__init__()self.weight = torch.nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weight# 初始化生成图像为内容图像
def generate_inits(content, device, lr):g_img = GenerateImage(content.shape).to(device)g_img.weight.data = content.dataoptimizer = torch.optim.Adam(g_img.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)return g_img(), optimizer# 计算格拉姆矩阵
def gramMatrix(x):_, c, h, w = x.shapex = x.view(c, h*w)return torch.matmul(x, x.t()) / (c*h*w)# 计算总损失:内容损失+样式损失
def compute_loss(content_g, content_y, style_g, style_y, content_weight, style_weight, gamma):contentlosses = [mse_loss(g, y)*content_weight for g, y in zip(content_g, content_y)]stylelosses = [mse_loss(gramMatrix(g), gramMatrix(y))*style_weight for g, y in zip(style_g, style_y)]total_loss = sum(contentlosses) + gamma * sum(stylelosses)return contentlosses, stylelosses, total_loss# 用于可视化的后处理
def postprocess(img_tensor):rgb_mean = np.array([0.485, 0.456, 0.406])rgb_std = np.array([0.229, 0.224, 0.225])inv_normalize = transforms.Normalize(mean=-rgb_mean/rgb_std,std=1/rgb_std)to_PIL_image = transforms.ToPILImage()return to_PIL_image(inv_normalize(img_tensor[0].detach().cpu()).clamp(0, 1))def train(lr, epoch_num, c_path, s_path, img_shape):ipt = Image.open(c_path)syl = Image.open(s_path)transform = preprocess(img_shape)content, style = transform(ipt).unsqueeze(0), transform(syl).unsqueeze(0)net = VGGNet19()net.to(device).eval()content = content.type(torch.FloatTensor)style = style.type(torch.FloatTensor)if torch.cuda.is_available():content, style = Variable(content.cuda(), requires_grad=False), Variable(style.cuda(), requires_grad=False)else:content, style = Variable(content, requires_grad=False), Variable(style, requires_grad=False)icontent, istyle = net(content)scontent, sstyle = net(style)input, optimizer = generate_inits(content, device, lr)for epoch in range(epoch_num+1):gcontent, gstyle = net(input)contentlosses, stylelosses, total_loss = compute_loss(gcontent, icontent, gstyle, sstyle, 1, 1e3, 1e2)optimizer.zero_grad()total_loss.backward(retain_graph=True)optimizer.step()print("[epoch: %3d/%3d] content loss: %3f style loss: %3f total loss: %3f" % (epoch, epoch_num, sum(contentlosses).item(), sum(stylelosses).item(), total_loss))if epoch % 100 == 0 and epoch != 0:# plt.imshow(postprocess(input))# plt.axis('off')# plt.show()torch.save(net.state_dict(), "itr_%d_total_loss_%3f.pth" % (epoch, total_loss))if __name__ == "__main__":train(0.01, 10000, './content.jpg', './s.jpg', (500, 700)



内容图像、风格图像和生成图像(第10000次迭代的可视化)分别如上图所示,并且代码实现是Gatys et al.的工作。

图像风格迁移及代码实现相关推荐

  1. 《深度学习之pytorch实战计算机视觉》第8章 图像风格迁移实战(代码可跑通)

    上一章<深度学习之pytorch实战计算机视觉>第7章 迁移学习(代码可跑通)介绍了迁移学习.本章将完成一个有趣的应用,基于卷积神经网络实现图像风格迁移(Style Transfer).和 ...

  2. 将 TensorFlow 移植到 Android手机,实现物体识别、行人检测和图像风格迁移详细教程

    2017/02/23 更新 贴一个TensorFlow 2017开发者大会的Mobile专题演讲 移动和嵌入式TensorFlow 这里面有重点讲到本文介绍的三个例子,以及其他的移动和嵌入式方面的TF ...

  3. 图像迁移风格保存模型_CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介、关键步骤配图、案例应用...

    CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介.过程思路.关键步骤配图.案例应用之详细攻略 目录 图像风格迁移算法简介 图像风格迁移算法过程思路 1.VGG对比NS 图像风 ...

  4. NS之VGG(Keras):基于Keras的VGG16实现之《复仇者联盟3》灭霸图像风格迁移设计(A Neural Algorithm of Artistic Style)

    NS之VGG(Keras):基于Keras的VGG16实现之<复仇者联盟3>灭霸图像风格迁移设计(A Neural Algorithm of Artistic Style) 导读 通过代码 ...

  5. CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介、过程思路、关键步骤配图、案例应用之详细攻略

    CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介.过程思路.关键步骤配图.案例应用之详细攻略 目录 图像风格迁移算法简介 图像风格迁移算法过程思路 1.VGG对比NS 图像风 ...

  6. CVPR 2021 | 澳洲国立大学提出基于模型的图像风格迁移

    ©作者|侯云钟 学校|澳洲国立大学博士生 研究方向|计算机视觉 本文从另外一个角度解读,澳洲国立大学郑良老师实验室 CVPR 2021 新工作.一般而言,我们需要同时利用两张图片完成图像的风格迁移(s ...

  7. 图像风格迁移_【论文解读】图像风格迁移中的Contextual Loss

    [08/04更新]在前几天的Commit中,Contextual Loss已经支持多GPU训练 1.Background 对于图像风格迁移,最常用的做法就是通过GAN网络实现,然而,如果你没有很强大的 ...

  8. 深度有趣 | 30 快速图像风格迁移

    简介 使用TensorFlow实现快速图像风格迁移(Fast Neural Style Transfer) 原理 在之前介绍的图像风格迁移中,我们根据内容图片和风格图片优化输入图片,使得内容损失函数和 ...

  9. 图像风格迁移_图像风格迁移—谷歌大脑团队任意图像风格化迁移论文详解

    点击蓝字关注我们 AI研习图书馆,发现不一样的世界 风格迁移 图像风格化迁移是一个很有意思的研究领域,它可以将一张图的风格迁移到另外一张图像上,由此还诞生了Prisma和Ostagram这样的商业化产 ...

  10. 深度学习实战-图像风格迁移

    图像风格迁移 文章目录 图像风格迁移 简介 画风迁移 图像风格捕捉 图像风格迁移 图像风格内插 补充说明 简介 利用卷积神经网络实现图像风格的迁移. 画风迁移 简单来说就是将另一张图像的绘画风格在不改 ...

最新文章

  1. PingCode与Jira 敏捷开发管理能力的对比
  2. 功能测试工作的一点总结
  3. 推荐:Visual Basic.NET Windows Forms 编程
  4. Spring MVC绑定,无设置器
  5. 工作汇报ppt案例_【赠书】开工大吉!今年一定要干过写PPT的!
  6. 默认帐户生成器帐户来源
  7. python 清空文件夹_别这样直接运行Python命令,否则电脑等于“裸奔”
  8. 微软BI 之SSRS 系列 - 解决Pie Chart 中控制标签外部显示与标签重叠的问题
  9. Julia: 关于下载库时WinRPM的Bug
  10. 卡盟主站搭建_搭建卡盟主站下载|搭建卡盟主站教程 (附带源码)百度云_ - 极光下载站...
  11. Origin作图点太密集处理方法
  12. 阿里云Anolis OS 8.4
  13. SpringBoot替换jar包中引用的jar包(Unable to open nested entry ‘BOOT-INF/lib/**.jar‘. It has been compressed)
  14. ubuntu服务器安装可视化桌面(Gnome)
  15. CSP漫画工作室clipstudiopaint最新版本2022功能介绍
  16. 中英文说明丨质膜H+ATP酶AS07 260介绍
  17. 【分享】性能比肩美拍秒拍的Android视频录制编辑特效解决方案【1】
  18. html注册新浪邮箱代码,JS仿新浪邮箱点击联系人添加Email地址
  19. win11 ENSP AR启动40错误解决方法:
  20. HDU-5965 扫雷(dp / 递推)

热门文章

  1. map字符串转json格式
  2. SECS/GEM协议库开发开源代码
  3. HCIA---华为认证初级网络工程师
  4. Android 动画 Kotlin 教程
  5. Effective Python 中文版
  6. StarRocks不稳定版本(解除AVX2指令集限制)
  7. 热分析(一):什么是热仿真/热分析?
  8. What Music简单的全网音乐播放器
  9. 《Python金融大数据风控建模实战》 第5章 变量编码方法
  10. 混合动力系统的整车经济性开发与能量管理策略高级技术