【6.1】图片风格迁移 Neural Style Transfer
完整代码:
from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import torch
import torch.nn as nn
import numpy as npimport matplotlib.pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# -------------------------------------------------加载图片---------------------------------------------------------------
# 加载图片 + 图像预处理为相同的shape,这样vgg提取出来feature vector才是一样的大小,否则不能直接计算L2 loss
def load_image(image_path, transform=None, max_size=None, shape=None):image = Image.open(image_path)if max_size:scale = max_size / max(image.size)size = np.array(image.size) * scaleimage = image.resize(size.astype(int), Image.ANTIALIAS)if shape:image = image.resize(shape, Image.LANCZOS)if transform:image = transform(image).unsqueeze(0)return image.to(device)transform = transforms.Compose([transforms.ToTensor(),# 因为我要用VGG,在ImageNet上做的处理,transforms.Normalize(mean=[0.485, 0.456, 0.406],std= [0.229, 0.224, 0.225])
]) # 来自ImageNet的mean和variance# 此处是经过标准化后的照片
content = load_image("png/content.png", transform, max_size=400)
style = load_image("png/style.png", transform, shape=[content.size(2), content.size(3)]) # 这里是想得到和content大小一样的样式print(content.shape,style.shape) # torch.Size([1, 3, 400, 311]) torch.Size([1, 3, 311, 400])# -------------------------------------------------图片展示给大家看---------------------------------------------------------
unloader = transforms.ToPILImage() # reconvert into PIL image
plt.ion()
def imshow(tensor, title=None):image = tensor.cpu().clone() # we clone the tensor to not do changes on itimage = image.squeeze(0) # remove the fake batch dimensionimage = unloader(image)plt.imshow(image)if title is not None:plt.title(title)plt.pause(0.001) # pause a bit so that plots are updatedplt.figure()
imshow(style[0], title='Image')# 并不是训练这个VGGNet,他只是一个特征提取器,真正要优化的是这一张target图片
class VGGNet(nn.Module):def __init__(self):super(VGGNet, self).__init__()# 有些层取出来当feature,基本上就可以拿到图片的内容和textureself.select = ['0', '5', '10', '19', '28']# 拿到VGG network,此时我们只需要features部分(即只需要拿到这些层就可以了,其余的信息不需要)self.vgg = models.vgg19(pretrained=True).features# 取出 self.select 中的层,组成一个新的featuresdef forward(self, x):features = []for name, layer in self.vgg._modules.items(): # _modules 可以把vgg一层层拿出来x = layer(x)if name in self.select:features.append(x)return features# target拿到的就是和content.png内容上相似,但是风格上更倾向于style.png的图片,内容是会变化的所以requires_grad_
target = content.clone().requires_grad_(True)
# 优化的是target这张图片,
optimizer = torch.optim.Adam([target], lr=0.003, betas=[0.5, 0.999])
vgg = VGGNet().to(device).eval() # 所以设置为eval(),它是不会被优化的# 打印出每一层拿到的 feature vector
# torch.Size([1, 64, 400, 311])
# torch.Size([1, 128, 200, 155])
# torch.Size([1, 256, 100, 77])
# torch.Size([1, 512, 50, 38])
# torch.Size([1, 512, 25, 19])
feature = vgg(content) # list,里面包含不同的features
for feat in feature:print(feat.shape)# 开始优化我们的target图片
target_features = vgg(target)total_step = 2000
style_weight = 100.
for step in range(total_step):target_features = vgg(target)content_features = vgg(content)style_features = vgg(style)style_loss = 0content_loss = 0for f1, f2, f3 in zip(target_features, content_features, style_features):content_loss += torch.mean((f1 - f2) ** 2) # 使用L2 loss_, c, h, w = f1.size()f1 = f1.view(c, h * w)f3 = f3.view(c, h * w)# 计算gram matrixf1 = torch.mm(f1, f1.t()) # 此处做了点积运算,(c, h * w) × (h * w, c) = (c,c)f3 = torch.mm(f3, f3.t())style_loss += torch.mean((f1 - f3) ** 2) / (c * h * w) # 使用L2 loss,只是多个了除数而已# 由于损失是不同的,我们将内容和风格损失的和作为总的loss,此处我们给他们各自合适的权重,会让总的loss看起来更符合真实的lossloss = content_loss + style_weight * style_loss# 更新target image 的 Tensoroptimizer.zero_grad()loss.backward()optimizer.step()if step % 10 == 0:print("Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}".format(step, total_step, content_loss.item(), style_loss.item()))# 将图片打印出来
denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
img = target.clone().squeeze()
img = denorm(img).clamp_(0, 1)
plt.figure() # Create a new figure, or activate an existing figure.
imshow(img, title='Target Image')
content.png style.png
target .png
【6.1】图片风格迁移 Neural Style Transfer相关推荐
- 吴恩达老师深度学习视频课笔记:神经风格迁移(neural style transfer)
什么是神经风格迁移(neural style transfer):如下图,Content为原始拍摄的图像,Style为一种风格图像.如果用Style来重新创造Content照片,神经风 ...
- 图像风格迁移(Neural Style)简史
图像风格迁移科技树 什么是图像风格迁移? 先上一组图. 以下每一张图都是一种不同的艺术风格.作为非艺术专业的人,我就不扯艺术风格是什么了,每个人都有每个人的见解,有些东西大概艺术界也没明确的定义.如 ...
- Pytorch 风格迁移(Style transfer)
Pytorch 风格迁移 0. 环境介绍 环境使用 Kaggle 里免费建立的 Notebook 教程使用李沐老师的 动手学深度学习 网站和 视频讲解 小技巧:当遇到函数看不懂的时候可以按 Shift ...
- 图像迁移风格保存模型_CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介、关键步骤配图、案例应用...
CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介.过程思路.关键步骤配图.案例应用之详细攻略 目录 图像风格迁移算法简介 图像风格迁移算法过程思路 1.VGG对比NS 图像风 ...
- CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介、过程思路、关键步骤配图、案例应用之详细攻略
CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介.过程思路.关键步骤配图.案例应用之详细攻略 目录 图像风格迁移算法简介 图像风格迁移算法过程思路 1.VGG对比NS 图像风 ...
- CNN实现图像风格迁移 ---Image Style Transfer Using Convolutional Neural Networks
目录 1. INTRODUCTION 2. Deep image representations 2.1 内容表示 2.2. Style representation 2.3 风格迁移 3. Re ...
- 风格迁移(Style Transfer)首次学习总结
0.写在前面 最近看了吴恩达老师风格迁移相关的讲解视频,深受启发,于是想着做做总结. 1.主要思想 目的:把一张内容图片(content image)的风格迁移成与另一张图片(style image) ...
- 第六节 图片风格迁移和GAN
第六节 图片风格迁移 - 图片风格迁移 - 用GAN生成MNIST - 用DCGAN生成更复杂的图片## 图片风格迁移 Neural Style Transfer matplotlib inlinef ...
- 神经风格迁移(Neural Style Transfer)程序实现(Caffe)
前言 上次的博客写了神经风格迁移(Neural Style Transfer)程序实现(Keras),使用keras的一个好处就是api简单,能够快速部署模型,使用很方便.出于学习目的,这次又使用ca ...
最新文章
- 【CVPR2020 Oral】只需一行代码就可提升迁移性能
- J-Link驱动下载和JLINK下载Hex程序
- 软件级负载均衡器(LVS/HAProxy/LVS)的特点简介和对比
- poj 1018 Communication System
- [html] 元素的alt和title有什么区别?
- flex gallery / 产品展示
- node 修改文件自启动
- 苹果手机上网速度慢_手机信号明明满格却上不去网?4招帮你搞定它!
- ios客户端快速滚动和回弹效果的实现
- react-custom-scrollbars滚动组件
- 利用U盘安装win2008r2系统的步骤
- 系统管理Lesson 14. Performing Database Backups
- UEFI Boot Flow 系列之 SEC Phase
- EditPlus使用技巧集
- java文件读取报(文件名、目录名或卷标语法不正确。)
- Visual Studio Code插件-前端工程师开发必备
- 读《金刚经》学心态,读《易经》学生存,读《道德经》学生活
- 360浏览器用地址栏搜索怎么更改搜索引擎
- [转]一往无前 | 小米十周年,雷军公开演讲全文
- java nodelist.item_XPath NodeList顺序(Java)
热门文章
- Redis分布式基础主从同步
- 单元测试:unittest.TestCase
- MIT科学家正在教AI感受电影中的喜怒哀乐
- 工信部:加强中欧在5G、物联网等领域合作
- 8分钟回顾开源巨头 Facebook 的 2016
- mybatis 多对多 处理
- DedeCMS四类核心表
- Python 字符串操作(string替换、删除、截取、复制、连接、比较、查找、包含、大小写转......
- 老旗舰华为能用上鸿蒙吗,华为完全开放鸿蒙,未来所有手机都能用鸿蒙系统?...
- 怎么把线稿提取出来_如何快速提取漫画线稿?【漫画技巧】