完整代码: 

from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import torch
import torch.nn as nn
import numpy as npimport matplotlib.pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# -------------------------------------------------加载图片---------------------------------------------------------------
# 加载图片 + 图像预处理为相同的shape,这样vgg提取出来feature vector才是一样的大小,否则不能直接计算L2 loss
def load_image(image_path, transform=None, max_size=None, shape=None):image = Image.open(image_path)if max_size:scale = max_size / max(image.size)size = np.array(image.size) * scaleimage = image.resize(size.astype(int), Image.ANTIALIAS)if shape:image = image.resize(shape, Image.LANCZOS)if transform:image = transform(image).unsqueeze(0)return image.to(device)transform = transforms.Compose([transforms.ToTensor(),# 因为我要用VGG,在ImageNet上做的处理,transforms.Normalize(mean=[0.485, 0.456, 0.406],std= [0.229, 0.224, 0.225])
])  # 来自ImageNet的mean和variance# 此处是经过标准化后的照片
content = load_image("png/content.png", transform, max_size=400)
style = load_image("png/style.png", transform, shape=[content.size(2), content.size(3)]) # 这里是想得到和content大小一样的样式print(content.shape,style.shape) # torch.Size([1, 3, 400, 311]) torch.Size([1, 3, 311, 400])# -------------------------------------------------图片展示给大家看---------------------------------------------------------
unloader = transforms.ToPILImage()  # reconvert into PIL image
plt.ion()
def imshow(tensor, title=None):image = tensor.cpu().clone()  # we clone the tensor to not do changes on itimage = image.squeeze(0)      # remove the fake batch dimensionimage = unloader(image)plt.imshow(image)if title is not None:plt.title(title)plt.pause(0.001) # pause a bit so that plots are updatedplt.figure()
imshow(style[0], title='Image')# 并不是训练这个VGGNet,他只是一个特征提取器,真正要优化的是这一张target图片
class VGGNet(nn.Module):def __init__(self):super(VGGNet, self).__init__()# 有些层取出来当feature,基本上就可以拿到图片的内容和textureself.select = ['0', '5', '10', '19', '28']# 拿到VGG network,此时我们只需要features部分(即只需要拿到这些层就可以了,其余的信息不需要)self.vgg = models.vgg19(pretrained=True).features# 取出 self.select 中的层,组成一个新的featuresdef forward(self, x):features = []for name, layer in self.vgg._modules.items(): # _modules 可以把vgg一层层拿出来x = layer(x)if name in self.select:features.append(x)return features# target拿到的就是和content.png内容上相似,但是风格上更倾向于style.png的图片,内容是会变化的所以requires_grad_
target = content.clone().requires_grad_(True)
# 优化的是target这张图片,
optimizer = torch.optim.Adam([target], lr=0.003, betas=[0.5, 0.999])
vgg = VGGNet().to(device).eval()  # 所以设置为eval(),它是不会被优化的# 打印出每一层拿到的 feature vector
# torch.Size([1, 64, 400, 311])
# torch.Size([1, 128, 200, 155])
# torch.Size([1, 256, 100, 77])
# torch.Size([1, 512, 50, 38])
# torch.Size([1, 512, 25, 19])
feature = vgg(content)  # list,里面包含不同的features
for feat in feature:print(feat.shape)# 开始优化我们的target图片
target_features = vgg(target)total_step = 2000
style_weight = 100.
for step in range(total_step):target_features  = vgg(target)content_features = vgg(content)style_features   = vgg(style)style_loss = 0content_loss = 0for f1, f2, f3 in zip(target_features, content_features, style_features):content_loss += torch.mean((f1 - f2) ** 2)  # 使用L2 loss_, c, h, w = f1.size()f1 = f1.view(c, h * w)f3 = f3.view(c, h * w)# 计算gram matrixf1 = torch.mm(f1, f1.t())  # 此处做了点积运算,(c, h * w) × (h * w, c) = (c,c)f3 = torch.mm(f3, f3.t())style_loss += torch.mean((f1 - f3) ** 2) / (c * h * w) # 使用L2 loss,只是多个了除数而已# 由于损失是不同的,我们将内容和风格损失的和作为总的loss,此处我们给他们各自合适的权重,会让总的loss看起来更符合真实的lossloss = content_loss + style_weight * style_loss# 更新target image 的 Tensoroptimizer.zero_grad()loss.backward()optimizer.step()if step % 10 == 0:print("Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}".format(step, total_step, content_loss.item(), style_loss.item()))# 将图片打印出来
denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
img = target.clone().squeeze()
img = denorm(img).clamp_(0, 1)
plt.figure() # Create a new figure, or activate an existing figure.
imshow(img, title='Target Image')

   

content.png                                                   style.png

target .png

【6.1】图片风格迁移 Neural Style Transfer相关推荐

  1. 吴恩达老师深度学习视频课笔记:神经风格迁移(neural style transfer)

            什么是神经风格迁移(neural style transfer):如下图,Content为原始拍摄的图像,Style为一种风格图像.如果用Style来重新创造Content照片,神经风 ...

  2. 图像风格迁移(Neural Style)简史

     图像风格迁移科技树 什么是图像风格迁移? 先上一组图. 以下每一张图都是一种不同的艺术风格.作为非艺术专业的人,我就不扯艺术风格是什么了,每个人都有每个人的见解,有些东西大概艺术界也没明确的定义.如 ...

  3. Pytorch 风格迁移(Style transfer)

    Pytorch 风格迁移 0. 环境介绍 环境使用 Kaggle 里免费建立的 Notebook 教程使用李沐老师的 动手学深度学习 网站和 视频讲解 小技巧:当遇到函数看不懂的时候可以按 Shift ...

  4. 图像迁移风格保存模型_CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介、关键步骤配图、案例应用...

    CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介.过程思路.关键步骤配图.案例应用之详细攻略 目录 图像风格迁移算法简介 图像风格迁移算法过程思路 1.VGG对比NS 图像风 ...

  5. CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介、过程思路、关键步骤配图、案例应用之详细攻略

    CV之NS:图像风格迁移(Neural Style 图像风格变换)算法简介.过程思路.关键步骤配图.案例应用之详细攻略 目录 图像风格迁移算法简介 图像风格迁移算法过程思路 1.VGG对比NS 图像风 ...

  6. CNN实现图像风格迁移 ---Image Style Transfer Using Convolutional Neural Networks

    目录 1. INTRODUCTION 2. Deep image representations 2.1  内容表示 2.2. Style representation 2.3  风格迁移 3. Re ...

  7. 风格迁移(Style Transfer)首次学习总结

    0.写在前面 最近看了吴恩达老师风格迁移相关的讲解视频,深受启发,于是想着做做总结. 1.主要思想 目的:把一张内容图片(content image)的风格迁移成与另一张图片(style image) ...

  8. 第六节 图片风格迁移和GAN

    第六节 图片风格迁移 - 图片风格迁移 - 用GAN生成MNIST - 用DCGAN生成更复杂的图片## 图片风格迁移 Neural Style Transfer matplotlib inlinef ...

  9. 神经风格迁移(Neural Style Transfer)程序实现(Caffe)

    前言 上次的博客写了神经风格迁移(Neural Style Transfer)程序实现(Keras),使用keras的一个好处就是api简单,能够快速部署模型,使用很方便.出于学习目的,这次又使用ca ...

最新文章

  1. 【CVPR2020 Oral】只需一行代码就可提升迁移性能
  2. J-Link驱动下载和JLINK下载Hex程序
  3. 软件级负载均衡器(LVS/HAProxy/LVS)的特点简介和对比
  4. poj 1018 Communication System
  5. [html] 元素的alt和title有什么区别?
  6. flex gallery / 产品展示
  7. node 修改文件自启动
  8. 苹果手机上网速度慢_手机信号明明满格却上不去网?4招帮你搞定它!
  9. ios客户端快速滚动和回弹效果的实现
  10. react-custom-scrollbars滚动组件
  11. 利用U盘安装win2008r2系统的步骤
  12. 系统管理Lesson 14. Performing Database Backups
  13. UEFI Boot Flow 系列之 SEC Phase
  14. EditPlus使用技巧集
  15. java文件读取报(文件名、目录名或卷标语法不正确。)
  16. Visual Studio Code插件-前端工程师开发必备
  17. 读《金刚经》学心态,读《易经》学生存,读《道德经》学生活
  18. 360浏览器用地址栏搜索怎么更改搜索引擎
  19. [转]一往无前 | 小米十周年,雷军公开演讲全文
  20. java nodelist.item_XPath NodeList顺序(Java)

热门文章

  1. Redis分布式基础主从同步
  2. 单元测试:unittest.TestCase
  3. MIT科学家正在教AI感受电影中的喜怒哀乐
  4. 工信部:加强中欧在5G、物联网等领域合作
  5. 8分钟回顾开源巨头 Facebook 的 2016
  6. mybatis 多对多 处理
  7. DedeCMS四类核心表
  8. Python 字符串操作(string替换、删除、截取、复制、连接、比较、查找、包含、大小写转......
  9. 老旗舰华为能用上鸿蒙吗,华为完全开放鸿蒙,未来所有手机都能用鸿蒙系统?...
  10. 怎么把线稿提取出来_如何快速提取漫画线稿?【漫画技巧】