代码来源于Github中点赞最多的BigGAN复现

作者个人学习记录

BigGAN的生成器代码内部引用了代码人员编写的谱正则化(SN)以及批正则化(BN),关于这部分的解读地址在这里:
批正则化
谱正则化

关于生成器,对于生成图片的不同分辨率,代码人员提供了不同的参数,代码如下:

# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):# 字典 不同的分辨率对应不同的结构arch = {}arch[512] = {'in_channels' :  [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],'out_channels' : [ch * item for item in [16,  8, 8, 4, 2, 1, 1]],'upsample' : [True] * 7,'resolution' : [8, 16, 32, 64, 128, 256, 512],'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])for i in range(3,10)}}arch[256] = {'in_channels' :  [ch * item for item in [16, 16, 8, 8, 4, 2]],'out_channels' : [ch * item for item in [16,  8, 8, 4, 2, 1]],'upsample' : [True] * 6,'resolution' : [8, 16, 32, 64, 128, 256],'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])for i in range(3,9)}}arch[128] = {'in_channels' :  [ch * item for item in [16, 16, 8, 4, 2]],'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],'upsample' : [True] * 5,'resolution' : [8, 16, 32, 64, 128],'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])for i in range(3,8)}}arch[64]  = {'in_channels' :  [ch * item for item in [16, 16, 8, 4]],'out_channels' : [ch * item for item in [16, 8, 4, 2]],'upsample' : [True] * 4,'resolution' : [8, 16, 32, 64],'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])for i in range(3,7)}}arch[32]  = {'in_channels' :  [ch * item for item in [4, 4, 4]],'out_channels' : [ch * item for item in [4, 4, 4]],'upsample' : [True] * 3,'resolution' : [8, 16, 32],'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])for i in range(3,6)}}return arch

这段代码定义了一个函数,该函数返回一个字典,字典中的key代表生成图像的分辨率,value则代表不同分辨率所对应的参数,其中upsample代表是否进行上采样,attention代表在哪层使用注意力机制。这个字典将输入到生成器的类中,作为参数。
接下来为生成器部分:
在参数设置阶段代码如下

class Generator(nn.Module):def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,G_kernel_size=3, G_attn='64', n_classes=1000,num_G_SVs=1, num_G_SV_itrs=1,G_shared=True, shared_dim=0, hier=False,cross_replica=False, mybn=False,G_activation=nn.ReLU(inplace=False),G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,G_init='ortho', skip_init=False, no_optim=False,G_param='SN', norm_style='bn',**kwargs):super(Generator, self).__init__()# Channel width mulitplierself.ch = G_ch# Dimensionality of the latent spaceself.dim_z = dim_z# The initial spatial dimensionsself.bottom_width = bottom_width# Resolution of the outputself.resolution = resolution# Kernel size?self.kernel_size = G_kernel_size# Attention?self.attention = G_attn# number of classes, for use in categorical conditional generationself.n_classes = n_classes# Use shared embeddings?self.G_shared = G_shared# Dimensionality of the shared embedding? Unused if not using G_sharedself.shared_dim = shared_dim if shared_dim > 0 else dim_z# Hierarchical latent space?self.hier = hier# Cross replica batchnorm?self.cross_replica = cross_replica# Use my batchnorm?self.mybn = mybn# nonlinearity for residual blocksself.activation = G_activation# Initialization styleself.init = G_init# Parameterization styleself.G_param = G_param# Normalization styleself.norm_style = norm_style# Epsilon for BatchNorm?self.BN_eps = BN_eps# Epsilon for Spectral Norm?self.SN_eps = SN_eps# fp16?self.fp16 = G_fp16# Architecture dictself.arch = G_arch(self.ch, self.attention)[resolution]

在这里分别定义了再上一阶段引入的字典,决定生成图像的分辨率,卷积核的大小,类别数量、时候进行分层操作(有潜在空间),是否使用注意力机制,使用何种bn层、激活函数、标准化方法等等。
之后如果有潜在空间,则对噪声信号进行处理,因为在论文中提到,输入到生成器的噪声被分割然后与类别信息相组合,输入到不同的层中。代码如下:

    # If using hierarchical latents, adjust zif self.hier:# Number of places z slots into# 有潜在空间的情况下,这个应该就是指多层需要chunk的意思吧# 因为要将噪声z分割成多个chunk,这部分用于计算chunk的num与size,以及对应的噪声z维度self.num_slots = len(self.arch['in_channels']) + 1 # 层的数量加一self.z_chunk_size = (self.dim_z // self.num_slots)# Recalculate latent dimensionality for even splitting into chunksself.dim_z = self.z_chunk_size *  self.num_slotselse:# 只有一层需要chunk加入,就不需要分割z了self.num_slots = 1self.z_chunk_size = 0

如果有潜在空间,可以看到这里用到了字典中的value中的’in_channels’,代表每层使用的通道数,因为每层都要有被分割的噪声块(chunk)输入,并且在网络最开始也要输入噪声,所以输入的噪声被分割成为层数+1个chunk,确定数量之后就是确定维度。
如果没有潜在空间。就是在else之后,就有一块,也没有chunk了。

下一步为选择不同的程序组件,bn层、线性层、谱正则化、embedding这些,代码如下:

    # Which convs, batchnorms, and linear layers to use# 选择是否使用具有谱正则化的层if self.G_param == 'SN':self.which_conv = functools.partial(layers.SNConv2d,kernel_size=3, padding=1,num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,eps=self.SN_eps)self.which_linear = functools.partial(layers.SNLinear,num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,eps=self.SN_eps)else:self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)self.which_linear = nn.Linear# We use a non-spectral-normed embedding here regardless;# For some reason applying SN to G's embedding seems to randomly cripple(削弱) G# 具有SN的Embedding会使G变弱,所以使用无谱正则化的Embeddingself.which_embedding = nn.Embeddingbn_linear = (functools.partial(self.which_linear, bias=False) if self.G_sharedelse self.which_embedding)# 选择bn层self.which_bn = functools.partial(layers.ccbn,which_linear=bn_linear,cross_replica=self.cross_replica,mybn=self.mybn,input_size=(self.shared_dim + self.z_chunk_size if self.G_sharedelse self.n_classes),norm_style=self.norm_style,eps=self.BN_eps)# Prepare model# If not using shared embeddings, self.shared is just a passthroughself.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared else layers.identity())# First linear layer# 第一个线性层,第一个参数为随机噪声被分割成的大小,第二个参数为第一层通道数*最开始的特征图大小self.linear = self.which_linear(self.dim_z // self.num_slots,self.arch['in_channels'][0] * (self.bottom_width **2))

在这部分中,首先选择是否使用具有谱正则化(SN)的线性层、卷积层、embedding层,在这里面从性能角度出发,代码人员对前两者使用了谱正则化操作,对embedding则不使用。接着就是定义了一个线性层和embedding层。
接着就是定义网络结构,代码如下:

    # self.blocks is a doubly-nested list of modules, the outer loop intended# to be over blocks at a given resolution (resblocks and/or self-attention)# while the inner loop is over a given block# 这段注释解释了self.blocks这个变量的结构和用途。self.blocks是一个双层嵌套的模块列表,# 外层循环用于处理给定分辨率的块(例如resblocks和自注意力模块),# 而内层循环则用于处理给定块的模块。# 因此,外层循环按照分辨率逐渐减小,内层循环按照块的顺序逐个处理模块。self.blocks = []for index in range(len(self.arch['out_channels'])):self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],out_channels=self.arch['out_channels'][index],which_conv=self.which_conv,which_bn=self.which_bn,activation=self.activation,upsample=(functools.partial(F.interpolate, scale_factor=2)if self.arch['upsample'][index] else None))]]# If attention on this block, attach it to the endif self.arch['attention'][self.arch['resolution'][index]]:# G_arch的参数‘attention’标志了是否在该block添加注意力print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]# Turn self.blocks into a ModuleList so that it's all properly registered.self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])# output layer: batchnorm-relu-conv.# Consider using a non-spectral conv hereself.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],cross_replica=self.cross_replica,mybn=self.mybn),self.activation,self.which_conv(self.arch['out_channels'][-1], 3))

这部分就是通过循环,将每一层添加到列表中,GBlock是代码人员提前定义好的网络结构,之后将列表转换为ModuleLIst,后面的部分为参数初始化以及优化器部分,不进行赘述。
在forward函数中,对噪声信号进行了处理,处理的维度在之前定义过,代码如下:

  # Note on this forward function: we pass in a y vector which has# already been passed through G.shared to enable easy class-wise# interpolation later. If we passed in the one-hot and then ran it through# G.shared in this forward function, it would be harder to handle.def forward(self, z, y):# If hierarchical, concatenate zs and ysif self.hier:# 如果分层的情况下,把z分成层数+1的数量,第一个保留,剩下的和类别拼接在一起。zs = torch.split(z, self.z_chunk_size, 1)z = zs[0]ys = [torch.cat([y, item], 1) for item in zs[1:]]else:ys = [y] * len(self.blocks)# First linear layer# 第一个z的chunk直接给到线性层h = self.linear(z)# Reshapeh = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)# Loop over blocks# "Loop over blocks" 意味着对生成器(Generator)的每个分辨率分块进行迭代。# 在BigGAN的生成器中,分辨率分块是通过将特征图的空间大小逐渐放大来实现的,每个分辨率级别有一个块。# 因此,在代码实现中,这些块存储在 self.blocks 这个属性中,并通过 for 循环在每个块上进行迭代。# 每个块本身由多个生成器模块(G-Block)组成,这些模块也是通过 for 循环进行迭代。# 通过这种方式,可以在每个分辨率级别逐步生成越来越复杂的图像。for index, blocklist in enumerate(self.blocks):# Second inner loop in case block has multiple layersfor block in blocklist:h = block(h, ys[index])# Apply batchnorm-relu-conv-tanh at outputreturn torch.tanh(self.output_layer(h))

将分割好的噪声信号,除了第一块,剩余块都与类别信息的one-hout向量组合起来,然后输入到刚刚定义好的模型中,产生生成图像。

BigGAN代码解读(gpt3.5帮助)——生成器部分相关推荐

  1. BigGAN代码解读(gpt3.5的帮助)——谱正则化部分

    BigGAN代码解读(gpt4.0的帮助)--谱正则化部分 作者个人记录学习 BigGAN中使用谱归一化对训练过程进行优化,在github中的代码中,使用了自己编写的谱归一化对卷积层.线性层以及Emb ...

  2. python代码解读器_python文章生成器(附源码+讲解)

    移动端建议收藏后在pc端查看 最近在看到网上的[营销号生成器]和[狗屁不通生成器].前者确实是营销号的口吻,但是竟然连模板都是写死的:后者也如其名,的确"狗屁不通".于是结合当前对 ...

  3. 编译原理语义分析代码_Pix2Pix原理分析与代码解读

    原理分析: 图像.视觉中很多问题都涉及到将一副图像转换为另一幅图像(Image-to-Image Translation Problem),这些问题通常都使用特定的方法来解决,不存在一个通用的方法.但 ...

  4. STM32学习心得十八:通用定时器基本原理及相关实验代码解读

    记录一下,方便以后翻阅~ 主要内容: 1) 三种定时器分类及区别: 2) 通用定时器特点: 3) 通用定时器工作过程: 4) 实验一:定时器中断实验补充知识及部代码解读: 6) 实验二:定时器PWM输 ...

  5. deap遗传算法 tirads代码解读

    deap遗传算法 tirads代码解读 写在最前面 Overview 程序概览 参考 deap框架介绍 creator模块 创建适应度类Types 定义适应度策略 创建个体类 Toolbox类 创建种 ...

  6. 类ChatGPT逐行代码解读(2/2):从零起步实现ChatLLaMA和ColossalChat

    本文为<类ChatGPT逐行代码解读>系列的第二篇,上一篇是:如何从零起步实现Transformer.ChatGLM 本文两个模型的特点是加了RLHF 第六部分 LLaMA的RLHF版:C ...

  7. BEGAN-边界均衡生成对抗网络-代码解读

    当前论文代码 首先注意: 不同点: 该论文的输入是噪音,鉴别器和生成器都是哑铃型结构, 相同点: 输出是一张图片,D都是用真实图像去比对. 已知信息 可见,是从main.py开始训练的.测试的时候,只 ...

  8. 代码解读——Retinex低光照图像增强(Deep Retinex Decomposition for Low-Light Enhancement)

    今天带来一篇代码解读的文章,是2018年BMVC上的一篇暗光增强文章.个人觉得网络比较轻量并且能够取得还不错的效果.废话不多说,直接贴传送门: 文章地址:http://arxiv.org/abs/18 ...

  9. 200行代码解读TDEngine背后的定时器

    作者 | beyondma来源 | CSDN博客 导读:最近几周,本文作者几篇有关陶建辉老师最新的创业项目-TdEngine代码解读文章出人意料地引起了巨大的反响,原以为C语言已经是昨日黄花,不过从读 ...

最新文章

  1. chrome出现adobe flash playe 不是最新版本
  2. android cmd adb命令安装和删除apk应用
  3. SDNU 1093.DNA排序(水题)
  4. jvm垃圾回收机制_JVM的垃圾回收机制总结
  5. 月薪2万程序员面试,被HR直面吐槽:毕业生能值这个数?
  6. Vlan与VTP的介绍及工作原理
  7. AngularJs学习笔记--directive
  8. 资料下载地址和我加入的论坛
  9. gan网络损失函数_生成对抗网络的最新研究进展
  10. 电脑tf卡检测不到_tf卡 插入电脑没盘符,但数据恢复软件能检测到异常
  11. h3c交换机重启_华三交换机重启命令(范文篇).doc
  12. vs2013编译ffmpeg之三十一 vidstab
  13. java发布帖子_第一篇发在javaeye的帖子
  14. python修改文件的某一行_简单文件操作python 修改文件指定行的方法
  15. python读取桌面上的文件夹怎么加密_python给文件夹加密 怎么样给python文件加密...
  16. 【贪心 题解】 HDU 5773 The All-purpose Zero
  17. LED点阵屏中“鬼影”现象的分析与解决
  18. Simulink:车辆换挡逻辑回顾_Demo
  19. Octapharma Group公布强劲的2018年业绩
  20. impala 看表结构

热门文章

  1. 鸿蒙窍做何解释,终朝睡在鸿蒙窍 一任时人牛马呼
  2. 黑盒测试用例设计方法详解
  3. Python基础详解(十三):(视频符号化)将视频转换成ASCII符号形式展示出来
  4. 计算方法(五)函数插值
  5. 图片+文案(在图片上)
  6. 30个后台管理系统模板
  7. power 相关:(二)功耗的分析 —— power compiler
  8. 微信群活码的原理及其作用,以及活码怎么使用
  9. 如何使用CubeMX创建STM32F105的程序
  10. 作为软件测试人员,这些常用的性能测试工具你一定要知道