《简记GAN loss的理解》

  GAN 是一种思想,刚接触的时候极为震撼,后来通过GAN思想也做过模型的优化,写过一篇专利。最近在用 GAN 生成数据,顺手写一写对GAN loss的理解。

Key Words:GAN、SegAN

Beijing, 2021.01

作者:RaySue

Agile Pioneer  

文章目录

  • GAN loss
  • 通过代码理解 GAN
  • SegAN
  • 参考

GAN loss

GAN 网络一般形式的 loss 如下:

min⁡Gmax⁡DEx∼Pdata[log(D(x))]+Ez∼Pz[log(1−D(G(z)))]\min_{G}\max_{D} \mathbb{E}_{x \sim P_{data}} [log (D(x))] + \mathbb{E}_{z \sim P_{z}} [log (1 - D(G(z)))] Gmin​Dmax​Ex∼Pdata​​[log(D(x))]+Ez∼Pz​​[log(1−D(G(z)))]

其中:D(x) 是判别器,G(z) 是生成器。判别器的目的是最大化 loss,生成器是为了最小化 loss。 GAN 是一种思想,GAN 可以做很多任务,但万变不离其宗,就是最大,最小化损失,以达到生成对抗的目的。

  • 为什么 D(x) 是期望其最大化 ?

D(x) 对应的损失有两项,第一项 Ex∼Pdata[log(D(x))]\mathbb{E}_{x \sim P_{data}} [log (D(x))]Ex∼Pdata​​[log(D(x))],第二项 Ez∼Pz[log(1−D(G(z)))]\mathbb{E}_{z \sim P_{z}} [log (1 - D(G(z)))]Ez∼Pz​​[log(1−D(G(z)))],并且D(x) ∈\in∈ [0, 1],而log(x)log(x)log(x) 在[0, 1] 上单调递增,我们期望真实的数据D(x)D(x)D(x)是趋近于1 的,所以第一项期望增大。第二项中 D(G(z))D(G(z))D(G(z))是生成数据期望其趋向于 0,所以整体也是递增的,所以D(x) 对应的整体期望是要更大的,即 max⁡D\max_{D}maxD​ 。

  • G(x) 最小化的 loss 是什么?

G(x) 把生成数据的标签设置为1,并和输入生成数据输入D(x)D(x)D(x)得到判别的结果,进行训练,来迷惑D(x)D(x)D(x),这样如果生成数据太假就会产生很大的 loss ,注意,这里的loss仅用于优化生成器,所以整个网络才会work。

G(x) 对应的损失为第二项 Ez∼Pz[log(1−D(G(z)))]\mathbb{E}_{z \sim P_{z}} [log (1 - D(G(z)))]Ez∼Pz​​[log(1−D(G(z)))] ,我们想让生成数据更接近 真实数据,所以D(G(z))D(G(z))D(G(z))就趋向于1,整体就期望更小,即min⁡G\min_GminG​。

整体的思想就是:

  1. 判别器和生成器各司其职,判别器通过先验知识,知道哪部分数据来自生成器,哪部分是真实数据,所以尽管生成的再像真实数据,也会当成负样本来学习,让判别器越来越强

  2. 而生成器将其生成的结果送入判别器,并将判别器返回的预测结果利用真实数据标签来训练,以此来迷惑判别器,从而产生更接近于真实数据的结果


通过代码理解 GAN

  • 先训练 D(x) 再训练 G(x) 交替进行

  • 判别器 D(x) 和生成器 G(x) 利用两个优化器,每个优化器单独负责优化对应的网络参数。

  • D(x) 训练的过程:将真实数据的标签定义为 1,将生成数据定义标签为 0

  • G(x) 训练的过程:将生成的数据的标签定义为1,和经过 D(x) 的判别结果计算损失来更新参数。

criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)D.zero_grad()
#  1A: Train D on real
d_real_data = Variable(d_sampler(d_input_size))
d_real_decision = D(preprocess(d_real_data))
d_real_error = criterion(d_real_decision, Variable(torch.ones([1,1])))  # ones = true
d_real_error.backward() # compute/store gradients, but don't change params#  1B: Train D on fake
d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
d_fake_decision = D(preprocess(d_fake_data.t()))
d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = fake
d_fake_error.backward()
d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()G.zero_grad()gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
g_fake_data = G(gen_input)
dg_fake_decision = D(preprocess(g_fake_data.t()))
g_error = criterion(dg_fake_decision, Variable(torch.ones([1,1])))  # Train G to pretend it's genuineg_error.backward()
g_optimizer.step()  # Only optimizes G's parameters

SegAN

投稿于 Neuroinformatics (2018) 的一篇论文

  • 通过GAN加持语义分割,能够让学习到的结果更加精细

  • SegAN 和 GAN 一样的两部分 loss,

  • 通过预测的 01 mask(生成数据)和原图取交然后输入到 D(x) 中得到一个 res_mask

  • 通过真实的 01 mask(真实数据) 和原图取交然后输入到 D(x) 中得到 target_mask

  • 训练D(x) 的时候直接取 -mae (res_mask, target_mask) 让误差向大了学,判别能力越来越强,越来越挑剔

  • 训练G(x) 的时候则 mae(res_mask, target_mask) 让误差减小,让学到的更像target

损失函数:

min⁡θSmax⁡θCL(θS,θC)=1N∑n=1Nlmae(fC(xn⋅S(xn)),fC(xn⋅yn))\min_{\theta_S}\max_{\theta_C} L(\theta_S, \theta_C) = \frac{1}{N}\sum_{n=1}^{N}l_{mae}(f_C(x_n \cdot S(x_n)), f_C(x_n \cdot y_n))θS​min​θC​max​L(θS​,θC​)=N1​n=1∑N​lmae​(fC​(xn​⋅S(xn​)),fC​(xn​⋅yn​))

  • xnx_nxn​ 原图
  • yny_nyn​ 预测结果

参考

https://www.cnblogs.com/walter-xh/p/10051634.html

https://zhuanlan.zhihu.com/p/78822561?from_voters_page=true

https://blog.csdn.net/sunyao_123/article/details/80288398

简记GAN网络的loss相关推荐

  1. GAN网络概述及LOSS函数详解

    Generative Adversarial Nets 上周周报已经写了这篇论文,本周相对GAN网络的LOSS进行进一步的学习和研究. GAN网络: 条件:G依照真实图像生成大量的类似图像,D是辨别输 ...

  2. GAN背后的理论依据,以及为什么只使用GAN网络容易产生

    花了一下午研究的文章,解答了我关于GAN网络的很多疑问,内容的理论水平很高,只能尽量理解,但真的是一篇非常好的文章转自http://www.dataguru.cn/article-10570-1.ht ...

  3. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  4. 不服就GAN:GAN网络生成 cifar10 的图片实例(keras 详细实现步骤),GAN 的训练的各种技巧总结,GAN的注意事项和大坑汇总

    GAN 的调参技巧总结 生成器的最后一层不使用 sigmoid,使用 tanh 代替 使用噪声作为生成器的输入时,生成噪声的步骤使用 正态分布 的采样来产生,而不使用均匀分布 训练 discrimin ...

  5. GAN网络详解(从零入门)

    从一个小白的方式理解GAN网络(生成对抗网络),可以认为是一个造假机器,造出来的东西跟真的一样,下面开始讲如何造假:(主要讲解GAN代码,代码很简单) 我们首先以造小狗的假图片为例. 首先需要一个生成 ...

  6. 论文阅读——TR-GAN: Topology Ranking GAN with Triplet Loss for Retinal Artery/Vein Classification

    论文阅读--TR-GAN: Topology Ranking GAN with Triplet Loss for Retinal Artery/Vein Classification 基于对抗神经网络 ...

  7. GAN网络的模型坍塌和不稳定的分析

    众所周知,GAN异常强大,同时也非常难以训练.主要有以下亮点原因: 模型坍塌(mode collapse) 难以收敛和训练不稳定(convergence and instability) GAN网络的 ...

  8. GAN网络生成手写体数字图片

    Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的. 目前感觉keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接 ...

  9. 机器学习——从0开始构建自己的GAN网络

    目录 一 前言 二 生成式对抗网络GAN 三 GAN的训练思路 四 数据集--Chinese MNIST 五 代码--python 1.文件展示 2.代码(一) --数据预处理 3.代码(二) --生 ...

最新文章

  1. Types of intraclass correlation coefficience (ICC)
  2. 【腾讯圣诞晚会TEG节目】这里的黎明静悄悄
  3. 传输层的TCP和UDP
  4. 是不是一个东西_迷你世界:一个金币就能买到稀有武器?这么良心的售货机在哪领...
  5. 卓金武——从数学建模到MATLAB
  6. powerpoint文字教程
  7. MUI 拍照和从系统相册选择图片上传
  8. Java—这把线程池ThreadPoolExecutor操作,你学会了吗?
  9. 将字符串转换为全角或半角
  10. C++11 处理时间和日期的处理,以及chrono库介绍
  11. Linux 安装Redis单机版(使用Mac远程访问)
  12. 4*4矩阵键盘原理分析以及代码展示
  13. md5加密算法~Java语言实现
  14. 黑客进行攻击中最重要的环节“信息收集”
  15. T156基于51单片机LCD12864指针时钟Proteus设计、keil程序、c语言、源码、ds1302,电子时钟,62256
  16. 北京等保测评机构项目测评收费价格标准参考
  17. HTML+CSS网页设计期末课程大作——XXXXX (X页) HTML5网页设计成品_学生DW静态网页设计_web课程设计网页制作
  18. emoji表情符号编码大全
  19. 尝试创建windows XP最长的路径名
  20. 免费短信九成暗藏陷阱

热门文章

  1. 字符串大写字符串转小写js_C ++字符串大写和小写
  2. 熊猫分发_熊猫cut()函数示例
  3. jsf tree组件_JSF UI组件标签属性示例教程
  4. 熊猫read_csv()–将CSV文件读取到DataFrame
  5. C语言和C++的区别是什么?到底学哪种好
  6. 认识安全测试之SQL注入
  7. 【iOS开发】Alamofire框架的使用二 高级用法
  8. You (root) are not allowed to access to (crontab) because of pam configuration
  9. Android无线安全测试工具-WiFinSpect
  10. MOQL—筛选器(Selector)(二)