源代码地址:https://github.com/NVlabs/stylegan2-ada-pytorch

  • 这是一篇代码阅读笔记,顾名思义是对代码进行阅读,讲解的笔记。对象是styleGAN2的pytorch版本的代码,在github上有一个开源库。一边笔记方便我回顾,一边也对深度学习初学者有一些阅读理解代码的示例作用吧。代码毕竟是基本功,看到我发的一些代码,评论区问的一些问题实在是代码基础不行。
  • 读一个代码的目的很多样,有些只是想看它的预处理,有些只是想看他的模型,不同目的阅读方式和详略程度也不同。我这里是为了全部读懂给女朋友讲解的,所以会以全部读懂为目的来进行阅读和笔记。

train.py

  • 起始点有很多种,我习惯从主函数开始读。train.py是主程序文件,直接拉到最后可以看到,主函数是从main函数开始的,那就找到main函数,438行开始

train.py/main()

  • 482行看名字就知道是输出日志用的,不重要,可以先跳过
  • 486行,可以看到一些主要参数是在这里设置的,后续要找再来看,这里先跳过
  • 491-498行,设置了一些输出路径,也是跳过,有需要再来看
  • 520-524行,将训练设置保存到一个json文件中
  • 526行到533行是程序的主体部分,在这里启用了多线程,运行subprocess_fn函数,因此下一步就看这个函数。这里稍微展开说一下这个多线程是怎么回事,就是利用了torch.multiprocessing实现了每个GPU分配一个线程,并且多线程之间是用spawn方式创建的。也就是说,你有多少个GPU,就会同时运行多少个subprocess_fn函数,并且spawn方式意味着这些线程都有独立的python解释器程序,资源是复制的,有自己的独立内存而非全部共享内存,而529行是指定了一个临时路径用来给这些线程进行交流,在这个路径下实现需要共享的部分变量。

train.py/subprocess_fn()

  • 367行到380行是torch分布式训练的一些初始化设置
  • 主程序在training_loop模块的training_loop方法中,接下来就跳到这里

training/training_loop.py

training/training_loop.py/training_loop()

  • 这里终于遇到第一个关键点,136行,数据集,调用了construct_class_by_name函数。后面会讲解construct_class_by_name这个函数,这里只需要知道它是个根据输入的参数,返回一个根据参数确定的类的方法即可。最近越来越多的深度学习代码使用这种包装方式,本质上就算想用字符串来调用类,又为了代码统一和简洁,包装得一层接一层,读起来是真的麻烦,而且类名隐藏起来了,甚至无法用vscode的智能追踪来找这里用到的到底是什么类。
  • 这里直接说,training_set根据train.py的107行,是training.dataset.ImageFolderDataset类的对象
  • 以及150行和151行,GD根据train.py的176 177行分别是training.networks.Generator类和training.networks.Discriminator类的对象。而G_ema是G的一个指数移动平均版本,在训练过程中,G的参数会随着step而更新,而G_emaG的迭代过程中各个时期的参数的指数移动平均版本,相比GG_ema的变化更加柔和,这是个常用的技巧。
  • 154到159行的代码加载了模型的参数,可能是预训练的也可能是训练中止接着跑的。
  • 175行 augment_pipe 是train.py 287行 training.augment.AugmentPipe类的对象
  • 180到190进行了多线程训练的模型包装
  • 192到214行定义了训练的几个阶段。这里展开解释一下,GAN的训练策略相比普通模型稍微有些复杂,训练是分阶段的,每个iteration通常要分别训练G和D,并且在训练G的时候,D的参数要固定,训练D的时候,G的参数要固定。这段代码定义了4个阶段:Gmain Greg Dmain Dreg。
  • 195行 loss 根据train.py 187行,是training.loss.StyleGAN2Loss类的对象
  • 199行 opt 根据train.py 185行,是torch.optim.Adam类的对象
  • 216-227行从训练集中采样了一些图片进行可视化(可做debug用),同时也将还没训练的G的输出也做了可视化(可用于检查resume是否加载或者pretrain模型的初始性能)
  • 259行开始训练
  • 260行获取真实图像和对应的label,261-262行归一化图像并划分图像和label到各个GPU
  • 263行-264行生成随机向量作为Generator的输入,并划分到各个GPU
  • 265-267行并从训练集中随机采样label作为条件label,并划分到各个GPU
  • 270行,依次迭代前面提到的4个阶段
  • 278行,把这个阶段需要训练的module设为计算梯度(如训练G的时候,设G的requires_grad为True,而D的为Flase)
  • 281-284行,根据当前所处阶段,为每个GPU分别计算损失。每个阶段的损失介绍损失的时候会展开。
  • 287-294行,根据当前所处阶段,更新待训练的参数(如Gmain和Greg阶段就只更新G的参数),并且把之前设为True的requires_grad改回去,然后进入下一阶段,直到4个阶段全部完成。
  • 296-305行,为G计算指数移动平均,从而更新G_ema的参数
  • 311-315行,根据训练过程的损失,调整数据增强策略的参数,具体在下面介绍数据增强的时候会展开。
  • 318-320行,这里是设置了continue条件,使得每迭代4000(kimg_per_tick*1000)张图片才会运行322行以后的内容一次。实现方式是cur_nimg会一直增加,而tick_start_nimg只有在下面的代码会被设置为cur_nimg,这样一旦运行了一次下面的代码,下次判断小于号就会成立,直到cur_nimg增加了4000使得小于号不成立,然后又会运行一次下面的代码。而done条件是因为,break出循环之前需要运行一次下面的代码,所以设置了当迭代图像数满足图像总数的1000倍的时候,就要退出了,这时候不管是不是每4000次的间隔到了,我都要往下走。
  • 341行设置了另一种退出的方法(代码里似乎没有设abort_fn所以应该这一段代码是没有用到),可以为training_loop传一个有效的abort_fn,使得如果准确率等满足条件返回True,从而不需要跑满1000epoch可以退出。
  • 348-350行为当前iteration生成的图片保存到本地,因为这段代码在322行之后,所以每迭代4000张图片才会生成一次。
  • 353-367行保存了模型参数,同理也是4000张图片才保存一次。
  • 370-379行计算了指标,后续会展开介绍
  • 381-389行和322-338行都说计算运行时间和存储消耗的,就跳过了
  • 391-406行都是输出日志的,跳过
  • 414行是while True循环的唯一退出点。如果运行完成了,就从这里退出循环,结束训练。

dnnlib/util.py/construct_class_by_name()

  • 这个函数只有两行,调用了call_func_by_name函数并以其返回值作为自身的返回值。call_func_by_name函数定义在279行,调用了get_obj_by_name函数,并进一步调用得到的func_obj,以func_obj的返回值作为call_func_by_name的返回值。所以这里其实就是调用了get_obj_by_name函数得到了类,func_obj保存的就是得到的类,然后实例化并返回,所以返回的是类的实例化对象。
  • get_obj_by_name函数在273行,调用了get_module_from_obj_nameget_obj_from_module。有点绕,其实是因为,name是xx.yy.zz的格式,zz才是类名,xx.yy是模块名,所以先调用222行的get_module_from_obj_name从xx.yy.zz中提取出xx.yy和zz,然后再借助get_obj_from_module函数从xx.yy模块中调用zz类。
  • get_module_from_obj_name函数的核心就在231-239行,231-232行其实就是给出根据“.”的位置对字符串划分成两部分的全部可能,所以如果是xx.yy.zz就会被拆成xx和yy.zz或者xx.yy和zz。然后在235到239行,对每种可能性都进行尝试,尝试从xx.yy中import zz,尝试从xx中import yy.zz,因为用的是try,试不出来可以继续,直到试出来,就知道正确的划分方法是什么。
  • get_obj_from_module函数是通过269行的getattr函数来获取模块中的类的。

training/dataset.py

  • ImageFolderDataset类定义在training/dataset.py的154行,是同文件24行Dataset类的子类,一般看__getitem__函数即可。返回值有imagelabelimage根据210-220行的重写,是一个CHW的unit8(0-255)的np array。label是onehot的float32的np array

training/networks.py

  • Generator类定义在training/networks.py的477行。476行是一个装饰器,意思是调用training.networks.Generator的时候,实质上返回的是persistence.persistent_class(Generator),这个装饰器只是为这个类添加了一些辅助功能,不影响接下来的理解,所以先跳过,后续会解释这个装饰器,先接着看模型
  • 模型由两个子模块组成:MappingNetworkSynthesisNetwork

training/networks.py/MappingNetwork()

  • MappingNetwork定义在174行,从初始化函数看起,200行前面定义了一些变量的维度,201行定义了中间全连接层的维度。
  • 204行定义了第一个全连接层,当使用condition label的时候,对这个one hot的condition label进行embed,embed后的特征将和z连在一起作为后续网络的输入。
  • 205-209行定义了网络的主体全连接层
  • 211-212行定义了一个名为w_avg的变量,它不会随着step更新值,但会在一些特殊的时刻进行值的更新和被使用。
  • 这里的FullyConnectedLayer(89行)相比普通的全连接层的区别在于,当lr_multiplier不为1时(208行定义的就不为1,是0.01),这些层的参数的学习率和其它参数的学习率相比会乘以一个lr_multiplier(具体实现其实就是把参数直接乘以一个lr_multiplier再去用,实际效果就等同于学习率乘了一个倍数,因为计算这些参数的梯度的时候也是会因此乘以一个lr_multiplier导致step的时候步长会乘以一个lr_multiplier的)
  • 接下来看forward函数。219和222行都仅仅是检查向量的shape。normalize_2nd_moment函数看21行,其实就是先统计这些特征值的标准差(每个样本单独统计),接着除以标准差进行归一化。其实这么说不太准确,因为没有减去均值,仅仅是先平方,然后平均,然后开根,然后除(rsqrt是1/sqrt)。而20行的装饰器仅仅是使得torch.autograd.profiler.record_function能跟踪到这个函数而已。至于torch.autograd.profiler后续会介绍是个什么东西。
  • 然后在223-224行,c向量送进一个全连接层编码,归一化,然后和归一化后的z向量被concatenate到一起,作为后面全连接层的输入
  • 226-229行就是主体的mapping network,对合并的编码和z向量前向传播经过几层全连接层
  • 231-234行保存全连接网络的输出的移动平均(lerp是根据w_avg_betaw_avgx进行插值的函数)到w_avg变量中
  • 236-239行重复了num_wsx,放在dimension 1上,也就是说现在shape是(B,num_ws,w_dim),具体num_ws是什么下面介绍SynthesisNetwork时会展开说明
  • 242-248行,查完整份代码没有看到哪里有把truncation_psi设为非1的值,所以理论上正常情况这部分代码是不会运行到的。看意思应该是利用w_avgx进行进一步移动平均,这里的移动平均就是对x做了,影响的是x的值,前面的移动平均只是存下来而已,对实际训练过程不会有什么影响。之所以说是截断,是因为当x在训练过程中突然出现异常大或者异常小的值时,这段代码可以通过移动平均限制这些值不要偏离正常范围太远。

training/networks.py/SynthesisNetwork()

  • SynthesisNetwork定义在424行。首先看init函数,440行根据要生成的图片的分辨率,定义了各个block的resolution,依次是2的2,3,4,。。n次方,使得2的n次方最接近要生成的图片的分辨率。441行则定义了各个block的通道数为32768除以block的resolution,但最小是512。
  • 442定义了一个称为fp16_resolution的变量。FP16是一个降低运算量和内存占用的技巧,将32位浮点运算用半精度运算来近似。模型对分辨率最高的num_fp16_res个block进行FP16计算,所以这里是在算开始进行FP16计算的block的resolution。在448行当block的resolution大于等于这里算出来的fp16_resolution时,意味着这个block要进行FP16计算而非全精度的计算。
  • 445-455行定义了SynthesisNetwork的主体由堆叠的几个SynthesisBlock组成。这里还统计了num_ws,后续会解释这个是什么。这里展开解释一下455行的setattr函数,是一种通过字符串变量定义类成员名的方法,比如setattr(a,'hah',1),那么当调用a.hah的时候,返回值会是1,也可以用464行的getattr函数实现调用。
  • SynthesisBlock后面解释,先接着看forward函数,forward函数的输入是MappingNetwork的输出,即是全连接并repeat了num_ws遍后的编码特征,shape为(B,num_ws,w_dim),也就是说对于每个batch,有num_ws个重复的w_dim维的特征。为什么要重复,我的理解是这些副本在后续会被各个模块分别使用,可能是为了避免相互影响?
  • 463-466行将输入的ws特征在dimension 1上拆成多份,每一份分给一个block。也就是说现在每个block的输入是(B,num_conv,w_dim),其实就是每个block分到了num_conv个重复的特征向量。最后一个block会得到 num_conv+num_torgb份(因为只有最后一个block的num_torgb不为0)
  • 468-471行则开始前向传播,每个block的输入是分得的ws和上一个block的输出(x和img),第一个block输入的(x和img)为None。最后一个block输出的img为SynthesisNetwork最终的输出

training/networks.py/SynthesisBlock()

  • SynthesisBlock定义在329行,首先还是看init。354行定义了一个变量,和之前说的一样这个变量是不会被step更新值的。这里用到了upfirdn2d.setup_filter这个方法,参数是resample_filter,为[1,3,3,1],其实是定义了一个torch的tensor,是输入的外积归一化后的结果。也就是[1,3,3,1]和自身的外积,得到一个(4,4)的tensor,并且进行归一化使得元素和为1。所以得到的是一个2D的低通滤波器。
  • 359行定义了一个随机变量,这个变量仅在SynthesisNetwork的第一个block被定义,作为起始的随机变量用来生成图片。
  • 361-364行定义了第一个SynthesisLayer,这个类后续会介绍。这里定义的这一层,第一个block是没有的。
  • 366-368行定义了第二层SynthesisLayer,这一层每个block都有
  • 370-373行根据现有代码是一定会运行的(architecture就是’skip’,除了cfg定义为’cifar’时,cfg默认是’auto’)所以每个block都定义了一个ToRGBLayer,后续会介绍。所以到这里可以看出,除了第一个SynthesisBlock为一个SynthesisLayer加一个ToRGBLayer外,其它的SynthesisBlock为两个SynthesisLayer加一个ToRGBLayer
  • 375-377行代码根据现有architecture的是不会运行的,不介绍。
  • 接着看forward,381行,根据前面,ws是重复的特征向量,所以这里unbind就是把重复的那个维度解出来,再套上iter变成迭代器,那么每次next(w_iter)都会生成一个特征向量,并且每次next生成的特征向量不是同一个,但是内容相同。这个特征向量其实就是MappingNetwork生成的编码特征向量。
  • 382行这里,和前面提到的FP16呼应了。因为styleGAN随着tensor往后传递,生成的tensor是越来越大的,分辨率越来越高,为了节约显存同时提高运算速度,可以把后面几个block的数据类型从float32改成float16,这样节约了一半的显存。
  • 383行定义了向量在内存中的存储格式,连续的存储可以提高某些运算的速度,同时有些运算要求连续存储的tensor才合法,但将存储整理为连续存储会消耗时间。
  • 384-386行定义了fused_modconv bool变量,后续用到的时候再展开,这里只需要知道只有测试的时候才有可能是true。
  • 389-391行,因为第一个block的x是None,所以这里就根据init中自己生成的随机变量来定义x。从这个实现方式可以看出来一个关键信息,即只要模型定义了,随着训练过程,每次迭代,第一个block的输入x都是固定的,不会改变。并且由于const是torch.nn.Parameter,所以加载resume和load 已训练好的参数的时候也是生成和之前训练的时候同一个x。
  • 396-406行即是主体SynthesisLayer的运行,可以看到SynthesisLayer是用来生成x的,而img则是ToRGBLayer生成的。
  • 408-419行是生成img的过程。首先上采样上一个block生成的img(具体方式后续会介绍)。然后ToRGBLayer根据本block的x和ws,生成了残差图y,加到上采样了的img上形成本block输出的img
  • 可以看出,其实网络中x的传递和img是无关的,每个x都只需要根据前一个block的x(第一个block则根据一个固定的x)和MappingNetwork生成的ws,即可生成本block的输出x。而img则是根据前一个block的img和本block的x以及MappingNetwork生成的ws来生成的。

training/networks.py/SynthesisLayer()

  • 这个类是SynthesisLayer的核心组成部分之一,定义在254行,先看init函数,274-284行定义了几个参数,分别是resample_filteraffineweightnoise_const(默认不会用到),noise_strength(定义为0)和bias。注意use_noise在这份代码中默认是True的,所以if是一定会执行的。这几个参数的类型前面都介绍过,而具体作用看forward
  • 首先是290行,所以affine就是个把输入的w变成styles的全连接层。这里的w其实就是SynthesisBlock分给这个SynthesisLayerws的其中一条特征向量。
  • 然后是293行,代码默认noise_mode就是’random‘而且use_noise是True,所以296不会运行,就是294行,重新生成了一个随机的噪声,满足0均值noise_strength标准差的正态分布。值得注意的是,虽然noise_strength在init中定义为0了,所以这里乘出来的noise最初是一个全0的tensor,但noise_strength是可训练参数,会随着训练过程变化,导致noise后面还是会不为0的。
  • 然后298行,up变量在每个SynthesisBlock(除了第一个block)的第一个SynthesisLayer被设为了2,其余都是1,所以flip_weight在每个SynthesisBlock的最后一个SynthesisLayer是True的,其余时候都是False的。
  • 299行是主体,但是包装在了modulated_conv2d函数中,可以先到下面看完再回来这里看。
  • 欢迎回来,302行的gain就是1,self.act_gain根据276行是bias_act.activation_funcs['lrelu'].def_gain,再看torch_utils/ops/bias_act.py的26行,是0.2,所以act_gain就是0.2
  • 303行self.conv_clamp根据train.py182行是256
  • 304行可以先当作是x先加上self.bias再进一个lrelu的激活函数,具体后面会展开,到此SynthesisLayer介绍完成,可以到ToRGBLayer

training/networks.py/modulated_conv2d()

  • 这个函数定义在了同文件的第27行。47-49行为了防止FP16计算溢出,对wstyles向量都除以了其各自的无穷范数(元素最大值)进行归一化,同时w还除以了维度。demodulateSynthesisLayer中都没有设定,所以SynthesisLayermodulated_conv2ddemodulate参数都是True。也就是说54到60行都会运行。
  • 55-56行用到了weightstyles两个tensor来构建w。weight是在SynthesisLayer定义的shape为[out_channels, in_channels, kernel_size, kernel_size]的卷积核,stylesaffine这个全连接层输出的shape为[batch_size, in_channels]的tensor。所以55行把weight扩展了batch size那一维,变成了[1, out_channels, in_channels, kernel_size, kernel_size]的tensor,然后styles reshape成[batch_size, 1, in_channels, 1, 1]的tensor,这两个tensor相乘,得到的是[batch_size, out_channels, in_channels, kernel_size, kernel_size],广播操作在这里的作用其实就相当于,把两个tensor都repeat成[batch_size, out_channels, in_channels, kernel_size, kernel_size],然后element-wise地相乘。直观上理解,这个相乘的作用是两个,一个是为同一个batch的不同sample分配不同的kernel,另一个是根据styles对每个channels的kernel进行rescale。
  • 58行计算得到一个dcoefs,是w的二范数的倒数,对第2 3 4维度分别算的,所以dcoefs的维度是[batch_size, out_channels]
  • 60行又用dcoefs来归一化w,但这段代码和62-72行的代码只有一个会运行。当fused_modconv为True时就运行60行的代码,否则运行62-72行的代码。这里用到了前面跳过的fused_modconv,在386行,这个bool值只有在测试的时候,并且是在前面的FP32层才是True,在FP16层则必须当batch size为1时才为True。所以如果单看训练阶段,60行的代码是不会运行的,只有62-72行会运行。
  • 64行,x[batch_size, in_channels, H, W]的,乘以一个[batch_size, in_channels]的向量,会自动广播,其实就是把stylesrepeat成x的形状,再乘以x
  • 65行调用了conv2d_resample.conv2d_resample来实现weightx的卷积,具体后面会展开说,这里就先暂且当作普通的卷积。
  • 再然后是66-71行的三个判断,SynthesisLayer中调用的modulated_conv2dnoise都是tensor,不是None,所以只会运行67行,fma.fma是自定义的一个函数,其实就是x乘以dcoefsnoise。之所以这么写,而没有直接x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) + noise.to(x.dtype),是因为可以利用torch.addcmul来加速。
  • 74到84行的代码也是在训练阶段是不会运行的,只有fused_modconv为True才会运行。相比63-72行差别在于卷积核从weight改成了w,并且因为乘以了styles,每个样本有一个单独的卷积核,卷积结果也不需要乘以dcoefs(应该是因为w本身就是weight乘以dcoefs的缘故)
  • 所以其实到这里,fused_modconv的作用就很明显了。如果fused_modconv是True,那么进行卷积的xwx不做操作,wweight乘以styles再乘以dcoefs的结果,并且对每个样本有一个卷积核;如果fused_modconv是False,那么进行的是普通卷积,卷积的双方是xweightx要乘以stylesweight不做操作,但卷积结果要乘以dcoefs再输出。
  • 看完这个可以回到SynthesisLayer继续下去

training/networks.py/ToRGBLayer

  • 这个类和SynthesisLayer很类似,而且大多数函数都介绍过了,就不再一一介绍了。只需要注意一点的是,ToRGBLayer没有Noise,所以它调用的modulated_conv2d中不需要加上noise

training/networks.py/Discriminator

  • 这个类定义在673行。709行定义了Discriminator的主体由数个DiscriminatorBlock组成,同时还有一个MappingNetwork和一个DiscriminatorEpilogue。注意,711行定义的变量名,最小是b8(从693行可知),所以和715行的b4并不会冲突。还有就是,和SynthesisBlock相反,前面的DiscriminatorBlock分辨率更大,也就是说变量名b后面的数字更大,但排得更前。
  • 接着看forward函数。可以看到Discriminator的输入由两部分组成,一个是img,一个是条件向量c,根据719-721行,首先是按顺序调用堆叠的DiscriminatorBlock对输入的img进行处理,然后调用一个MappingNetwork对输入的c进行处理,最后用一个DiscriminatorEpilogueDiscriminatorBlock的输出和MappingNetwork的输出作为输入,产生Discriminator最后的输出。

training/networks.py/DiscriminatorBlock

  • 这个类定义在505行。和Genenrator不同,Discriminator的architecture'resnet'(只有cfg'cifar'的时候为'orig')。而其它的,如FP16的设定和resolution都和Generator类似,就不细说了。
  • 533-540行定义了一个迭代器,每次next(trainable_iter)都会返回一个bool值,用来判断当前层是否freeze,freeze多少层由freeze_layers决定,默认参数下是0,也就是没有层会被freeze。如果要设定freeze多少层,可以通过train.py的freezed参数设定,是一个int,指向从Discriminator的第一个block的第一个conv开始的全局层序数,也就是说如果freezed设为5,那么从Discriminator的第一个block的第一个conv开始数,前5个conv都要freeze,后面的全都可训练。
  • 542行的in_channels会在第一个DiscriminatorBlock为0,所以第一个block(分辨率最大的那个)是会运行543-544行的,而后面的block则不会。所以第一个block有一个额外的Conv2dLayer
  • 546-554行定义了3个Conv2dLayer,所以除了第一个block有4个Conv2dLayer外,其它的block都由3个Conv2dLayer组成
  • 接着看forward,566-571行是第一个block独有的代码,其它的block不会运行这一段。这时x就是None,所以这段代码就是把输入的img经过一个Conv2dLayer产生x,同时把img设为None,后续的block再也不需要用到img
  • 575到578行就是把x分了两个支路,一个经过一层Conv2dLayer(在这个类内部会进行一次下采样使得分辨率变为原来的二分之一),一个经过两层Conv2dLayer(在第二层内部会进行一次下采样使得分辨率变为原来的二分之一),得到的两个支路的结果相加可以得到DiscriminatorBlock的输出。

training/networks.py/Conv2dLayer

  • 这个类定义在123行,可以看作是一个卷积+上/下采样+激活函数。核心代码在conv2d_resample.conv2d_resample函数中,其它没什么难点,就偷懒一下不展开了。

training/networks.py/DiscriminatorEpilogue

  • 这个类定义在615行,它的作用是以MappingNetwork的输出cmap和最后一个DiscriminatorBlock的输出x作为输入,产生最终的分类值。637-640定义了4个layer,具体作用看forward函数
  • x依次经过一个MinibatchStdLayer和一个Conv2dLayer,然后展平,送进两个全连接层,得到的结果和cmap进行element-wise的相乘,然后全部求和并归一化得到最终的输出。这个输出同时也是Discriminator的输出,是一个[batch_size, 1]的tensor,就是对图片进行二分类的逻辑值,越高表示越real,越低表示越fake。

training/networks.py/MinibatchStdLayer

  • 这个类定义在589行,不包含任何参数,仅仅是为x增加一些通道,这里初始化的group_size是4,num_channels是1
  • 602行把xreshape成[4, N/4, 1, c, h, w],即把样本分成了N/4组,每组4个样本,然后对特征的各个维度分别计算组内标准差,再对每个位置每个通道的标准差取平均,得到[N/4, 1]的向量,代表了每组的标准差,然后repeat到各个空间位置和组内样本上成为[N,1,H,W]的向量,concatenate到x上成为新的一个通道,所以x变成了[N,C+1,H,W]输出出去。

training/loss.py

training/loss.py/StyleGAN2Loss

  • 整个py文件就这一个类的定义。24行,init函数传进来的参数,根据training_loop.py的184-190行,G_mappingGeneratorMappingNetwork成员,G_synthesisGeneratorSynthesisNetwork成员,DDiscriminator。注意,这个Loss函数是有成员的,有一个初始化为0的pl_mean变量,在后续的accumulate_gradients函数中会对这个向量进行移动平均。
  • 接着看run_G函数,这个函数在accumulate_gradients中被调用。41行的style_mixing_prob是0.9(只有cfg'cifar'时才为0)。
  • 43行声明了一个满足均匀分布的随机整数,范围为1到ws.shape[1].这里的wsGeneratorMappingNetwork的输出。
  • 44行有点复杂,先看where函数内的第一个参数,这是个随机的bool值,有0.9的几率为True,0.1的几率为Flase;第二个参数是刚才提到的1到ws.shape[1]之间的随机整数(均匀分布);第三个参数就是ws.shape[1]。然后看where函数,where函数的三个参数按顺序依次是condition、input、other,意思是,如果condition是True,那么where函数的输出就是input,如果condition是False,那么where函数的输出就是other。所以这一行的意思是,cutoff有0.9的几率会被置为1到ws.shape[1]之间的随机整数,有0.1的几率会被置为ws.shape[1]
  • 45行就是把wscutoff后的那些向量替换成另一个根据随机z和同一c重新生成的向量。这里要回想一下ws的生成过程,其实是一堆重复的特征向量堆叠而成的,也就是说原本ws[:, i]ws[:, j]都是相同的特征向量。
  • 所以40-45行的意思是,根据z生成的ws,有0.9的几率会把其中随机数量的w替换成另一个随机向量z2生成的w,所以此时ws内就有两种w,一种是z生成的w,一种是z2生成的w,并且比例是随机的。
  • 47行根据调整后的ws调用GeneratorSynthesisNetwork生成图片并返回。
  • run_D很简单,就是先对图片augment,然后调用Discriminator判断图像的真假,产生逻辑值。
  • 接着看accumulate_gradients函数,这个函数在每个阶段都会调用一次,所以进来的时候可能是四个阶段的其中一个。
  • 如果是Gmain阶段,则损失函数是71行,对Discriminator预测的logit值取反算损失,softplus是个单调增函数,是平滑的relu,具体见(https://pytorch.org/docs/master/generated/torch.nn.Softplus.html#torch.nn.Softplus),所以backward会使得logit值增加,从而Generator生成的图像更真实,但此时Discriminator的参数要禁用require_grad,不参与更新,这一点在train_loop.py中进行了。
  • 74行gain在这一阶段是1
  • 如果是Greg阶段,79-80行,pl_batch_shrink是2,所以是把batch_size变成了原来的二分之一。这段代码的意义在于它能够使得Greg阶段的batchsize比Gmain阶段的batchsize小。
  • 81行生成了一个很小的正态分布随机噪声,这个噪声是用来产生83行梯度计算的扰动的。83行计算了SynthesisNetwork生成的图像对的MappingNetwork生成的编码的梯度,并乘以了pl_noise作为随机扰动。由于create_graph设为了True,所以这个算出来的梯度项也是可以用来计算损失并backward的,会根据二阶导来更新SynthesisNetwork的参数。这里得到的pl_grads的shape和gen_ws的shape是一样的,都是[batch_size, num_ws, dim_w]
  • 84计算了pl_grads每个sample对不同w的平均向量二范数,得到的shape是[batch_size, ]
  • 85行利用求得的pl_lengthspl_mean进行移动平均,pl_decay是0.01,即是pl_mean = pl_mean + 0.01 * (pl_lengths - pl_mean),这里pl_mean初始值是0,所以随着训练的迭代,pl_mean会是一个保存了历次迭代的pl_lengths的移动平均。
  • 86行把更新后的pl_mean从梯度图中分离出来,以防止被梯度更新改变值,这个变量只是用来保存pl_lengths的移动平均的,不应该被其它过程更新参数。
  • 87行计算了一个pl_penalty变量,这个变量就是Greg阶段的损失了。所以可以看出,Greg阶段主要是惩罚pl_grads的变化。89行pl_weight是2,92行gain在这一阶段是4
  • 95-104行Dmain阶段和Gmain阶段的区别仅在于,logit的符号不再取反,此时训练的是
    Discriminator,需要它输出正确的结果,gain在这一阶段是1
  • 109-131行在Dmain阶段和Dreg阶段都会运行。109行写了个很绕的表达式,其实就是,Dmain阶段name'Dreal',Dreg阶段name'Dr1'
  • 110-114行为Discriminator送入了真实图片,所以118-119行(仅在Dmain阶段运行)计算了对真实图片的discriminator loss,加负号是表示需要Discriminator预测的值越高越好,因为高代表real,低代表fake。
  • 123-128行仅在Dreg阶段运行。124行同样计算了Discriminator的输出对输入的真实图片的梯度,但这里则直接以梯度值的平方和作为损失
  • 131行gain是16
  • 损失函数到此介绍完,值得注意的是Greg阶段和Dreg阶段都取了网络的输出对输入的梯度来计算损失,这看起来与机器学习中以参数W的范数作为损失项的一种正则化方法有点类似,因为网络的输出对输入的梯度可以大体可以视作是网络的参数。

torch_utils/ops

torch_utils/ops/persistence.py/persistent_class

  • 这个函数在torch_utils/persistence.py文件的35行被定义。可以看到99行和130行,这个函数返回的是输入类的一个子类,这个子类为这个输入的类添加了一些功能,包括:保存类的初始化参数,为类添加打包函数(__reduce__方法的功能即为当代码被pickle打包时能输出正确的字符串等,展开说有点离题,具体可以自己去查,这里是暂时不需要了解的细节)

如何对这类代码进行修改以实现自己的idea

  • 这类编程范式常见于微软和谷歌等大厂的完备API库,如mmdet等,他们提供了许多API接口和现有模型,也有API文档,代码层层包装,在接口外对代码进行修改是很困难的事情。但这些代码一般都提供了方便的自定义接口,只要找对方式,在这些库上实现自己的idea是很轻松的事情。
  • 虽然个人对编程范式研究不是很多,但代码看得比较多,这类库个人感觉目的就是最大程度地避免对已有的函数和类的修改。也就是说,你如果想要在这上面实现自己的idea,你应该通过新建的方式而非修改的方式。无论你是想要实现新模型、新训练过程、新损失函数、新数据集、新增强手段等,在深度学习的全环节,这些过程都被包装成一个个的类,而在主函数中通过字符串来索引这些类。因此想要对某个环节进行替换,只需要两个步骤:
    • 在对应的文件里新增该类的定义和实现。如你想把模型换成自己的模型,那就在network.py里面声明和定义自己的类。
    • 在config或者命令行参数中,通过字符串形式指定某环节使用的类的名字。比如你定义了一个叫mynewmodel的类,那么你可能是通过将命令行调用改写为python train.py --model mynewmodel这样的方式实现对自己模型的训练。亦或者,有些代码将参数全部从jsonyaml文件读取,那么你就应该新建自己的json/yaml文件,在其中将模型名改为自己的模型,然后通过python train.py --config myconfig.yaml等方式来运行代码。
  • 总而言之,对这类代码的常规修改方式就应该是通过添加自己的类,然后修改config指向自己的类,来实现。当然自己的类的定义一般不是随心所欲的,通常需要遵循某些规则,比如要定义某些方法,或是在某些方法内提供特定格式的返回值,等等。
  • 通常来说,常规的idea都能通过代码提供的这些常规接口实现,当然也会出现常规接口无法实现的时候,一般就是你创新性地需要利用到某些信息,而这份库恰好没有把这些信息从包装里面传出来,这时候无可避免就得对代码进行修改而非新建。不过通常这些情况非常少,所以建议能不修改代码还是不修改代码,否则会带来很多意想不到的bug和麻烦。

TODO:

torch_utils/ops/conv2d_resample.py/conv2d_resample

torch_utils/ops/bias_act.py/bias_act

training.augment.AugmentPipe

metric_main

misc.assert_shape

misc.profiled_function

upfirdn2d.upsample2d

upfirdn2d.downsample2d

StyleGAN2代码阅读笔记相关推荐

  1. [置顶] Linux协议栈代码阅读笔记(一)

    Linux协议栈代码阅读笔记(一) (基于linux-2.6.21.7) (一)用户态通过诸如下面的C库函数访问协议栈服务 int socket(int domain, int type, int p ...

  2. linux 协议栈 位置,[置顶] Linux协议栈代码阅读笔记(一)

    Linux协议栈代码阅读笔记(一) (基于linux-2.6.21.7) (一)用户态通过诸如下面的C库函数访问协议栈服务 int socket(int domain, int type, int p ...

  3. BNN Pytorch代码阅读笔记

    BNN Pytorch代码阅读笔记 这篇博客来写一下我对BNN(二值化神经网络)pytorch代码的理解,我是第一次阅读项目代码,所以想仔细的自己写一遍,把细节理解透彻,希望也能帮到大家! 论文链接: ...

  4. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(八)—— 模型训练-训练

    系列目录: 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(一)--数据 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(二)-- 介绍及分词 菜鸟笔记-DuReader阅读理解基线模 ...

  5. leveldb代码阅读笔记(一)

    leveldb代码阅读笔记 above all leveldb是一个单机的键值存储的内存数据库,其内部使用了 LSM tree 作为底层存储结构,支持多版本数据控制,代码设计巧妙且简洁高效,十分值得作 ...

  6. C++ Primer Plus 6th代码阅读笔记

    C++ Primer Plus 6th代码阅读笔记 第一章没什么代码 第二章代码 carrots.cpp : cout 可以拼接输出,cin.get()接受输入 convert.cpp 函数原型放在主 ...

  7. [原创]fetchmail代码阅读笔记---ESMTP的认证方式

    fetchmail代码阅读笔记---ESMTP的认证方式 作者: 默难 ( monnand@gmail.com ) 0    引言 fetchmail是Eric S. Raymond组织编写的一款全功 ...

  8. CNN去马赛克代码阅读笔记

    有的博客链接是之前几周写好的草稿,最近整理的时候才发布的 CNN去马赛克论文及代码下载地址 有torch,minimal torch和caffe三种版本 关于minimal torch版所做的努力,以 ...

  9. ORB-SLAM2代码阅读笔记(五):Tracking线程3——Track函数中单目相机初始化

    Table of Contents 1.特征点匹配相关理论简介 2.ORB-SLAM2中特征匹配代码分析 (1)Tracking线程中的状态机 (2)单目相机初始化函数MonocularInitial ...

最新文章

  1. EOS Chain/Wallet RPC API的PHP开发包
  2. java speex回声消除_speex 回声消除的用法
  3. c#判断输入textbox是否为数字
  4. vu项目中按F5刷新element菜单没有根据路由匹配菜单解决办法
  5. Silverlight与WCF之间的通信(4)silverlight以net.tcp方式调用console上寄宿的wcf服务
  6. php 主进程子进程,PHP中的子进程的任何等价物?
  7. switch语句判断范围_MQL5从入门到精通【第四章】(一)条件判断语句
  8. Windows操作系统安全配置缺陷自动检测技术
  9. matlab 二元函数的极限,利用MATLAB软件求解一元和二元函数的极值
  10. 系统动力学模型_049,系统动力学模型,医生是怎么确定用药剂量的
  11. Photoshop栅格化图层到底什么意思,什么时候该用栅格化涂层
  12. 用python来开发webgame服务端系列
  13. html设置表格列宽百 分比,WPS解决实现单页显示 高分辨率显示器百分之百比例下双页改单页方法...
  14. centos7安装netspeeder教程
  15. 删掉wps后台烂进程
  16. IPFS何时落地应用?FIL价值破千?
  17. android. 长图加载
  18. 分享一下 软件测试面试历程和套路,真的很实在
  19. [c++简单小游戏]东搞西搞第二弹——谷歌chrome小恐龙升级版(啊哈)
  20. 全志XR806芯片 getsockopt、setsockopt失败如何解决?

热门文章

  1. 共享wifi项目怎么样
  2. 《童梦奇缘-梦幻般的羁绊》第八章-殇变
  3. Thymeleaf contextPath
  4. DAY2:使用 docker-compose 搭建 wordpress
  5. leetcode每日一题 521. 最长特殊序列 Ⅰ 脑筋急转弯
  6. 自动化运维工具-----Ansible playbook详解
  7. 将PPT中的文本提取到word文档
  8. 赛马网基本算法之一 (股神问题)
  9. mysql 时间相减取秒_MySQL两个日期字段相减得到秒的方法
  10. 打印的时候计算机出现蓝屏,在电脑打印时电脑总是会蓝屏或重启怎么处理