一. 新的东西
p.s :很多架构都和之前一样,就举些不同的

1. ReplayBuffer()
# Buffers of previously generated samples
 
fake_A_buffer = ReplayBuffer()
 
fake_B_buffer = ReplayBuffer()
这是什么??看看utils.py中的

class ReplayBuffer():
 
    def __init__(self, max_size=50):
 
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
 
        self.max_size = max_size
 
        self.data = []
 
 
    def push_and_pop(self, data):
 
        to_return = []
 
        for element in data.data:
 
            element = torch.unsqueeze(element, 0)
 
            if len(self.data) < self.max_size:
 
                self.data.append(element)
 
                to_return.append(element)
 
            else:
 
                if random.uniform(0,1) > 0.5:
 
                    i = random.randint(0, self.max_size-1)
 
                    to_return.append(self.data[i].clone())
 
                    self.data[i] = element
 
                else:
 
                    to_return.append(element)
 
        return Variable(torch.cat(to_return))
push and pop这是buffer的进栈和入栈?先理解为为了训练稳定把。以后遇到再补充,网上资源好少。

补充:

通过对训练过程的学习发现,生成器生成的 fake 图片还要经过另一生成器,生成 cycle 图片

所以通过该buffer函数寄存 fake 图片,用于判别器更新?算有了点感性认识把。

2. patchGAN
# Calculate output of image discriminator (PatchGAN)
 
patch = (1, opt.img_height // 2**4, opt.img_width // 2**4)
p.s: patch大小为16x16,前一篇也用到了,网上资源也好少,举下大神们的见解

(pix2pix : pix2pix和SRGAN的一个异曲同工的地方是都有用重建解决低频成分,用GAN解决高频成分的想法。在pix2pix中,这个思想主要体现在两个地方。一个是loss函数,加入了L1 loss用来让生成的图片和训练的目标图片尽量相似,而图像中高频的细节部分则交由GAN来处理:

还有一个就是PatchGAN,也就是具体的GAN中用来判别是否生成图的方法。PatchGAN的思想是,既然GAN只负责处理低频成分,那么判别器就没必要以一整张图作为输入,只需要对NxN的一个图像patch去进行判别就可以了。这也是为什么叫Markovian discriminator,因为在patch以外的部分认为和本patch互相独立。

具体实现的时候,作者使用的是一个NxN输入的全卷积小网络,最后一层每个像素过sigmoid输出为真的概率,然后用BCEloss计算得到最终loss。这样做的好处是因为输入的维度大大降低,所以参数量少,运算速度也比直接输入一张快,并且可以计算任意大小的图。作者对比了不同大小patch的结果,对于256x256的输入,patch大小在70x70的时候,从视觉上看结果就和直接把整张图片作为判别器输入没什么区别了)

(字体MC-GAN: Discriminator 引用了 PatchGAN [1]的思想,即在公共网络加了 3 层卷积层采用了 21 × 21 Local Discriminator 去衡量局部真假,然后又在公共网络上平行加了 2 层作为 Global Discriminator 去衡量整个图片的真假。)

数字图像处理?低频部分L1loss, 高频部分用patchGAN。

以下是本篇关于patch部分的代码

# Adversarial ground truths
 
valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
 
fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)
#...
 
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
对比一下没用patch的某篇文章的代码

valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
#...
 
d_real_loss = adversarial_loss(validity_real, valid)
嗯哼,原来是将valid改成16x16的去计算D的loss

二. Models
1. GeneratorResNet
生成器采用ResNet

##############################
 
#           RESNET
 
##############################
 
 
class ResidualBlock(nn.Module):
 
    def __init__(self, in_features):
 
        super(ResidualBlock, self).__init__()
 
 
        conv_block = [  nn.ReflectionPad2d(1),
 
                        nn.Conv2d(in_features, in_features, 3),
 
                        nn.InstanceNorm2d(in_features),
 
                        nn.ReLU(inplace=True),
 
                        nn.ReflectionPad2d(1),
 
                        nn.Conv2d(in_features, in_features, 3),
 
                        nn.InstanceNorm2d(in_features)  ]
 
 
        self.conv_block = nn.Sequential(*conv_block)
 
    def forward(self, x):
 
        return x + self.conv_block(x)
 
 
class GeneratorResNet(nn.Module):
 
    def __init__(self, in_channels=3, out_channels=3, res_blocks=9):
 
        super(GeneratorResNet, self).__init__()
 
 
        # Initial convolution block
 
        model = [   nn.ReflectionPad2d(3),
 
                    nn.Conv2d(in_channels, 64, 7),
 
                    nn.InstanceNorm2d(64),
 
                    nn.ReLU(inplace=True) ]
 
 
        # Downsampling
 
        in_features = 64
 
        out_features = in_features*2
 
        for _ in range(2):
 
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
 
                        nn.InstanceNorm2d(out_features),
 
                        nn.ReLU(inplace=True) ]
 
            in_features = out_features
 
            out_features = in_features*2
 
 
        # Residual blocks
 
        for _ in range(res_blocks):
 
            model += [ResidualBlock(in_features)]
 
 
        # Upsampling
 
        out_features = in_features//2
 
        for _ in range(2):
 
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
 
                        nn.InstanceNorm2d(out_features),
 
                        nn.ReLU(inplace=True) ]
 
            in_features = out_features
 
            out_features = in_features//2
 
 
        # Output layer
 
        model += [  nn.ReflectionPad2d(3),
 
                    nn.Conv2d(64, out_channels, 7),
 
                    nn.Tanh() ]
 
 
        self.model = nn.Sequential(*model)
 
 
    def forward(self, x):
 
        return self.model(x)
可视化:

2. Discriminator
##############################
#        Discriminator
##############################
 
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
 
        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
 
        self.model = nn.Sequential(
            *discriminator_block(in_channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )
 
    def forward(self, img):
        return self.model(img)
源代码网址:https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/cyclegan
--------------------- 
作者:眉间细雪 
来源:CSDN 
原文:https://blog.csdn.net/weixin_42445501/article/details/81234281 
版权声明:本文为博主原创文章,转载请附上博文链接!

pytorch cycleGAN代码学习1相关推荐

  1. Pytorch:CycleGAN代码中nn.Sequential(*module)处错误:list is not a Module subclass

    问题: pytorch训练模型时,将模型的层layer都放到module列表中: module = [layer1,layer2,layer3....] nn.Sequential(*module) ...

  2. 【深度学习】PyTorch常用代码段合集

    来源 | 极市平台,机器学习算法与自然语言处理 本文是PyTorch常用代码段合集,涵盖基本配置.张量处理.模型定义与操作.数据处理.模型训练与测试等5个方面,还给出了多个值得注意的Tips,内容非常 ...

  3. PyTorch官方教程中文版:入门强化教程代码学习

    PyTorch之数据加载和处理 from __future__ import print_function, division import os import torch import pandas ...

  4. (Pytorch)环境配置与代码学习1—边缘检测:更丰富的卷积特征 Richer Convolutional Features for Edge Detection

    (Pytorch)环境配置与代码学习1 - 边缘检测:更丰富的卷积特征 Richer Convolutional Features for Edge Detection Source code and ...

  5. [pytorch] PyTorch Metric Learning库代码学习二 Inference

    PyTorch Metric Learning库代码学习二 Inference Install the packages Import the packages Create helper funct ...

  6. pytorch常用代码

    20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...

  7. 收藏!PyTorch常用代码段合集

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:Jack Stark,来源:极市平台 来源丨https://zhu ...

  8. PyTorch核心贡献者开源书:《使用PyTorch进行深度学习》完整版现已发布!

    来源|新智元 [导读]<使用PyTorch进行深度学习>一书的完整版现已发布!教你如何使用PyTorch创建神经网络和深度学习系统,内含图解与代码,操作易上手. 由Luca Antiga. ...

  9. PyTorch常用代码段整理合集,建议收藏!

    点击上方,选择星标或置顶,每天给你送干! 阅读大概需要12分钟 跟随小博主,每天进步一丢丢 张皓:南京大学计算机系机器学习与数据挖掘所(LAMDA)硕士生,研究方向为计算机视觉和机器学习,特别是视觉识 ...

最新文章

  1. MySQL查询出错提示 --secure-file-priv解决方法
  2. STM32 电机教程 12 - BLDC 闭环电流控制
  3. MySQL Cluster 配置详细介绍
  4. 我的控制反转,依赖注入和面向切面编程的理解
  5. 编写下载服务器。 第三部分:标头:内容长度和范围
  6. Mybatis四种分页方式
  7. 上学与不上学的区别_这是我在全球最大的React会议上学到的
  8. 《数字图像处理》冈萨雷斯版 读书笔记(一)
  9. 拼多多商品按关键词采集爆款商品
  10. 学会充分利用你的零碎时间
  11. 如何防范动态调试(Anti-Debug)(SoftICE篇)
  12. 【网络安全系列】之新型勒索病毒WannaRen疑在国内大规模传播,威力不亚于新冠
  13. 错误处理(包括日志记录)
  14. 机器学习数据预处理之离群值/异常值:标准差法
  15. DPDK:UDP 协议栈的实现
  16. Java es should_@Es问题--should和must同时使用
  17. 黑马培训教学SSM整合中Security遇到的问题org.springframework.security.access.AccessDeniedException: Access is denied
  18. java 读取dwg_jdwglib java dwg文件的读取,写入开发包. dwg使用当前 常方便,测试代码和jar都有 CAD 247万源代码下载- www.pudn.com...
  19. k8s使用(kubernetes)
  20. 【有利可图网】PS实战教程49:PS滤镜调色方法之调出古韵金色花朵图片

热门文章

  1. 计算机科学领域最高荣誉,骄傲!这位毕业于嘉兴一中的数学家,荣获华人数学领域的最高荣誉...
  2. java集合 stack_Java集合之Stack
  3. c++ class struct同名_第二课C到C++的关系
  4. 利用循环,使得10 * 10的二维数组具有以下值,并按以下结构输出在屏幕上
  5. Powershell(3)
  6. 一机玩转docker之十:创建及使用ssh镜像
  7. Java Spring的IoC和AOP的知识点速记
  8. LUA C 交互 cocos
  9. fatal error: Python.h: No such file or directory 解决
  10. NGINX由入门到精通:Nginx介绍