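1. Abstract

This article explains SinGAN: from a single photo it can generate new photos that closely resemble it (simplified code included).
Main idea:

  1. First, a noise map Z_N is fed into the generator G_N to produce an image. (This coarsest step generates purely from noise; every other generator takes as input a random noise map z_n together with the image produced by the previous level, upsampled to the current generator's size.)
  2. Next, patches of the generated image (the effective patch size differs per level, going from coarse to fine and from large to small) and patches of the current level's real image (obtained by downsampling the training image) are fed to the discriminator, and training continues until the discriminator can no longer tell the two apart.
  3. Repeating this level by level, from the bottom up, yields the final result.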
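
The three steps above can be sketched as a coarse-to-fine loop. The snippet below is only a toy illustration on 1-D lists, not the article's implementation: `upsample`, `toy_generator`, and `generate` are hypothetical names, and the "generator" simply adds scaled noise to the upsampled previous output instead of running a trained network.

```python
import random


def upsample(signal, new_len):
    """Nearest-neighbour resize of a 1-D signal to new_len samples."""
    old_len = len(signal)
    return [signal[i * old_len // new_len] for i in range(new_len)]


def toy_generator(noise, prev, noise_amp):
    """Stand-in for G_n (assumption): adds scaled noise as a residual on prev."""
    return [p + noise_amp * z for p, z in zip(prev, noise)]


def generate(sizes, noise_amps, seed=0):
    """Run the coarse-to-fine pass; sizes[0] is the coarsest scale."""
    rng = random.Random(seed)
    image = [0.0] * sizes[0]  # the coarsest scale starts from noise alone
    for size, amp in zip(sizes, noise_amps):
        prev = upsample(image, size)  # previous output resized to this scale
        noise = [rng.gauss(0.0, 1.0) for _ in range(size)]
        image = toy_generator(noise, prev, amp)
    return image


sample = generate(sizes=[4, 6, 8], noise_amps=[1.0, 0.5, 0.1])
print(len(sample))
```

Fixing the seed makes the pass repeatable, which loosely mirrors how a fixed noise map can be reused to reproduce one particular output.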

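2. Related techniques

SinGAN architecture
SinGAN is a hierarchical patch-GAN model (Markovian discriminator). As the architecture figure shows, each part of the model captures the patch distribution of the input image at a different scale. Each GAN in the hierarchy has a small receptive field and limited capacity, which keeps the network from memorizing the whole image. Similar multi-scale architectures have been used before, but this is the first time one has been applied to internal learning from a single image.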

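The model consists of a pyramid of generators {G_0, ..., G_N}, trained against a matching pyramid of images {x_0, ..., x_N}, where x_n is the training image x downsampled by a factor r^n for some r > 1. Each generator G_n is responsible for producing realistic image samples with respect to the patch distribution of the corresponding x_n. Through adversarial learning, the discriminator D_n judges patches of the image produced by G_n (that is, local regions of the generated image) until the pair reaches a reasonably good state (in practice the final Nash equilibrium is not reached), which completes training.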
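
To make the image pyramid concrete, here is a minimal sketch that lists the size of each level when level n is the original image shrunk by r^n. The function name `pyramid_sizes` and the rounding rule are assumptions made for illustration; the posted code works with `opt.scale_factor`, which appears to play the role of 1/r.

```python
def pyramid_sizes(height, width, r, num_scales):
    """Return (h, w) per level, coarsest first; level n is the image shrunk by r**n."""
    sizes = []
    for n in range(num_scales - 1, -1, -1):
        factor = r ** n
        # Rounding to the nearest pixel is an illustrative choice
        sizes.append((max(1, round(height / factor)), max(1, round(width / factor))))
    return sizes


print(pyramid_sizes(256, 256, r=2, num_scales=4))
# → [(32, 32), (64, 64), (128, 128), (256, 256)]
```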

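As the figure shows, noise is injected at every scale. Generation starts at the coarsest scale, and the result is passed in order through the corresponding generators up to the finest scale. All generators and discriminators share the same receptive field, so as generation proceeds from the bottom up they capture structures of decreasing size.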
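
A short calculation shows why a fixed receptive field captures smaller and smaller structures. The formula below is the one used for `opt.receptive_field` in training.py; the values `ker_size=3`, `num_layer=5`, and the image side lengths are assumed here for illustration.

```python
def receptive_field(ker_size, num_layer, stride=1):
    """Receptive field of a stack of num_layer convolutions (formula from training.py)."""
    return ker_size + (ker_size - 1) * (num_layer - 1) * stride


rf = receptive_field(ker_size=3, num_layer=5)  # 3 + 2 * 4 = 11 pixels

# Fraction of the image side covered by one patch, coarsest scale first
sides = [25, 50, 100, 200]
coverage = [min(1.0, rf / side) for side in sides]
for side, frac in zip(sides, coverage):
    print(f"side={side:4d}  patch covers {frac:.1%} of the side")
```

At the coarsest scale one patch spans a large share of the image, so that level settles the global layout; at finer scales the same 11 pixels cover only fine texture.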

3. Complete code and steps

The video below shows the algorithm during training:

SinGAN training process

Main program entry point

import os
from SinGAN.run_train import functions
from SinGAN.run_train.manipulate import SinGAN_generate
from SinGAN.run_train.training import train
from SinGAN.run_train.config import get_arguments

if __name__ == '__main__':
    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='../Input/Images')
    parser.add_argument('--input_name', help='input image name', default='food.jpg')
    parser.add_argument('--mode', help='task to be done', default='train')
    opt = parser.parse_args()
    # opt = functions.post_config(opt)
    Gs = []
    Zs = []
    reals = []
    NoiseAmp = []
    dir2save = functions.generate_dir2save(opt)
    if (os.path.exists(dir2save)):
        print('trained model already exist')
    else:
        try:
            os.makedirs(dir2save)
        except OSError:
            pass
        # Read the image as a torch tensor
        real = functions.read_image(opt)
        # Adjust the scale settings to the image size
        functions.adjust_scales2image(real, opt)
        # Train the model; opt holds the command-line arguments
        train(opt, Gs, Zs, reals, NoiseAmp)
        # Use the trained model to generate new images of arbitrary size and aspect ratio
        SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt)

training.py

import os
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import math
import matplotlib.pyplot as plt

from SinGAN.run_train import functions, models
from SinGAN.run_train.imresize import imresize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(opt, Gs, Zs, reals, NoiseAmp):
    real_ = functions.read_image(opt)
    in_s = 0
    scale_num = 0
    # Resize the input image (computes the local interpolation weights)
    real = imresize(real_, opt.scale1, opt)
    # Build the pyramid of real images
    reals = functions.creat_reals_pyramid(real, reals, opt)
    nfc_prev = 0
    # Pyramid of fully convolutional GANs, trained one scale at a time
    while scale_num < opt.stop_scale + 1:
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass
        plt.imsave('%s/in.png' % (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
        plt.imsave('%s/original.png' % (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
        plt.imsave('%s/real_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
        D_curr, G_curr = init_models(opt)
        if (nfc_prev == opt.nfc):
            # Initialize from the models trained at the previous scale
            G_curr.load_state_dict(torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
            D_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
        z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs, in_s, NoiseAmp, opt)
        # Freeze the parameters of the networks trained at this scale
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)
        Zs.append(z_curr)
        NoiseAmp.append(opt.noise_amp)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr
    return


def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    # Zero-pad the borders of the noise and image tensors
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    alpha = opt.alpha
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=device)
    # Tensor of the same shape filled with fill_value (here 0)
    z_opt = torch.full(fixed_noise.shape, 0, device=device)
    z_opt = m_noise(z_opt)
    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    # Decay the learning rate at the given milestone
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []
    # Training loop: images are generated from noise
    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) - D(G(z))
        ###########################
        # Dsteps: 'Discriminator inner steps', default=3
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(device)
            # D_real_map = output.detach()
            errD_real = -output.mean()  # -a
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            D_fake_map = output.detach()
            errG = -output.mean()
            # Note: the backward call below is commented out as posted, so errG is
            # only logged and the generator step uses the reconstruction loss alone.
            # errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # plt.imsave('%s/D_fake.png'   % (opt.outf), functions.convert_image_np(D_fake_map))
            # plt.imsave('%s/D_real.png'   % (opt.outf), functions.convert_image_np(D_real_map))
            # plt.imsave('%s/z_opt.png'    % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            # plt.imsave('%s/prev.png'     %  (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            # plt.imsave('%s/noise.png'    %  (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            # plt.imsave('%s/z_prev.png'   % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG


def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    G_z = in_s
    if len(Gs) > 0:
        if mode == 'rand':
            count = 0
            pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
            if opt.mode == 'animation_train':
                pad_noise = 0
            for G, Z_opt, real_curr, real_next, noise_amp in zip(Gs, Zs, reals, reals[1:], NoiseAmp):
                if count == 0:
                    z = functions.generate_noise([1, Z_opt.shape[2] - 2 * pad_noise, Z_opt.shape[3] - 2 * pad_noise],
                                                 device=device)
                    z = z.expand(1, 3, z.shape[2], z.shape[3])
                else:
                    z = functions.generate_noise(
                        [opt.nc_z, Z_opt.shape[2] - 2 * pad_noise, Z_opt.shape[3] - 2 * pad_noise], device=device)
                z = m_noise(z)
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                z_in = noise_amp * z + G_z
                G_z = G(z_in.detach(), G_z)
                G_z = imresize(G_z, 1 / opt.scale_factor, opt)
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                count += 1
        if mode == 'rec':
            count = 0
            for G, Z_opt, real_curr, real_next, noise_amp in zip(Gs, Zs, reals, reals[1:], NoiseAmp):
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                z_in = noise_amp * Z_opt + G_z
                G_z = G(z_in.detach(), G_z)
                G_z = imresize(G_z, 1 / opt.scale_factor, opt)
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                # if count != (len(Gs)-1):
                #    G_z = m_image(G_z)
                count += 1
    return G_z


def train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, paint_inject_scale):
    in_s = torch.full(reals[0].shape, 0, device=device)
    scale_num = 0
    nfc_prev = 0
    while scale_num < opt.stop_scale + 1:
        if scale_num != paint_inject_scale:
            scale_num += 1
            nfc_prev = opt.nfc
            continue
        else:
            opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)), 128)
            opt.min_nfc = min(opt.min_nfc_init * pow(2, math.floor(scale_num / 4)), 128)
            opt.out_ = functions.generate_dir2save(opt)
            opt.outf = '%s/%d' % (opt.out_, scale_num)
            try:
                os.makedirs(opt.outf)
            except OSError:
                pass
            # plt.imsave('%s/in.png' %  (opt.out_), functions.convert_image_np(real), vmin=0, vmax=1)
            # plt.imsave('%s/original.png' %  (opt.out_), functions.convert_image_np(real_), vmin=0, vmax=1)
            plt.imsave('%s/in_scale.png' % (opt.outf), functions.convert_image_np(reals[scale_num]), vmin=0, vmax=1)
            D_curr, G_curr = init_models(opt)
            z_curr, in_s, G_curr = train_single_scale(D_curr, G_curr, reals[:scale_num + 1], Gs[:scale_num],
                                                      Zs[:scale_num], in_s, NoiseAmp[:scale_num], opt, centers=centers)
            G_curr = functions.reset_grads(G_curr, False)
            G_curr.eval()
            D_curr = functions.reset_grads(D_curr, False)
            D_curr.eval()
            Gs[scale_num] = G_curr
            Zs[scale_num] = z_curr
            NoiseAmp[scale_num] = opt.noise_amp
            torch.save(Zs, '%s/Zs.pth' % (opt.out_))
            torch.save(Gs, '%s/Gs.pth' % (opt.out_))
            torch.save(reals, '%s/reals.pth' % (opt.out_))
            torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
            scale_num += 1
            nfc_prev = opt.nfc
        del D_curr, G_curr
    return


def init_models(opt):
    # generator initialization:
    netG = models.GeneratorConcatSkip2CleanAdd(opt).to(device)
    netG.apply(models.weights_init)
    if opt.netG != '':
        # Load saved generator weights
        netG.load_state_dict(torch.load(opt.netG))
    print(netG)

    # discriminator initialization:
    netD = models.WDiscriminator(opt).to(device)
    netD.apply(models.weights_init)
    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD))
    print(netD)

    return netD, netG
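
Two per-scale settings in training.py are easy to overlook: the filter count `opt.nfc`, which doubles every four scales up to 128, and the noise amplitude `opt.noise_amp`, which is `opt.noise_amp_init` times the RMSE between the real image and the reconstruction passed up from the coarser scales. The sketch below reproduces both rules in plain Python on flat lists. The helper names and the default values `nfc_init=32` and `noise_amp_init=0.1` are assumptions for illustration and are not taken from the article's config.

```python
import math


def channels_at_scale(scale_num, nfc_init=32, cap=128):
    """Filter count at a scale: doubles every 4 scales, capped at `cap`."""
    return min(nfc_init * 2 ** (scale_num // 4), cap)


def noise_amplitude(real, reconstruction, noise_amp_init=0.1):
    """noise_amp_init times the RMSE between two equally sized pixel lists."""
    mse = sum((a - b) ** 2 for a, b in zip(real, reconstruction)) / len(real)
    return noise_amp_init * math.sqrt(mse)


print([channels_at_scale(n) for n in range(10)])
# → [32, 32, 32, 32, 64, 64, 64, 64, 128, 128]
print(noise_amplitude([0.0, 1.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]))
```

Tying the noise amplitude to the reconstruction error means a scale whose upsampled input is already close to the real image receives only a small amount of fresh noise.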

Due to space limits, only part of the code is shown in this article.

Complete simplified code

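4. Learning links

What kind of wizardry is SinGAN, a model learned from a single landscape photo? | ICCV 2019 Best Paper

A GAN can be trained on a single image! Adobe's improved image generation method (ConSinGAN) | Open-sourced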

github.com/tamarott/SinGAN
