简记GAN网络的loss
《简记GAN loss的理解》
GAN 是一种思想,刚接触的时候极为震撼,后来通过GAN思想也做过模型的优化,写过一篇专利。最近在用 GAN 生成数据,顺手写一写对GAN loss的理解。
Key Words:GAN、SegAN
Beijing, 2021.01
作者:RaySue
Agile Pioneer
文章目录
- GAN loss
- 通过代码理解 GAN
- SegAN
- 参考
GAN loss
GAN 网络一般形式的 loss 如下:
minGmaxDEx∼Pdata[log(D(x))]+Ez∼Pz[log(1−D(G(z)))]\min_{G}\max_{D} \mathbb{E}_{x \sim P_{data}} [log (D(x))] + \mathbb{E}_{z \sim P_{z}} [log (1 - D(G(z)))] GminDmaxEx∼Pdata[log(D(x))]+Ez∼Pz[log(1−D(G(z)))]
其中:D(x) 是判别器,G(z) 是生成器。判别器的目的是最大化 loss,生成器是为了最小化 loss。 GAN 是一种思想,GAN 可以做很多任务,但万变不离其宗,就是最大,最小化损失,以达到生成对抗的目的。
- 为什么 D(x) 是期望其最大化 ?
D(x) 对应的损失有两项,第一项 Ex∼Pdata[log(D(x))]\mathbb{E}_{x \sim P_{data}} [log (D(x))]Ex∼Pdata[log(D(x))],第二项 Ez∼Pz[log(1−D(G(z)))]\mathbb{E}_{z \sim P_{z}} [log (1 - D(G(z)))]Ez∼Pz[log(1−D(G(z)))],并且D(x) ∈\in∈ [0, 1],而log(x)log(x)log(x) 在[0, 1] 上单调递增,我们期望真实的数据D(x)D(x)D(x)是趋近于1 的,所以第一项期望增大。第二项中 D(G(z))D(G(z))D(G(z))是生成数据期望其趋向于 0,所以整体也是递增的,所以D(x) 对应的整体期望是要更大的,即 maxD\max_{D}maxD 。
- G(x) 最小化的 loss 是什么?
G(x) 把生成数据的标签设置为1,并和输入生成数据输入D(x)D(x)D(x)得到判别的结果,进行训练,来迷惑D(x)D(x)D(x),这样如果生成数据太假就会产生很大的 loss ,注意,这里的loss仅用于优化生成器,所以整个网络才会work。
G(x) 对应的损失为第二项 Ez∼Pz[log(1−D(G(z)))]\mathbb{E}_{z \sim P_{z}} [log (1 - D(G(z)))]Ez∼Pz[log(1−D(G(z)))] ,我们想让生成数据更接近 真实数据,所以D(G(z))D(G(z))D(G(z))就趋向于1,整体就期望更小,即minG\min_GminG。
整体的思想就是:
判别器和生成器各司其职,判别器通过先验知识,知道哪部分数据来自生成器,哪部分是真实数据,所以尽管生成的再像真实数据,也会当成负样本来学习,让判别器越来越强
而生成器将其生成的结果送入判别器,并将判别器返回的预测结果利用真实数据标签来训练,以此来迷惑判别器,从而产生更接近于真实数据的结果
通过代码理解 GAN
先训练 D(x) 再训练 G(x) 交替进行
判别器 D(x) 和生成器 G(x) 利用两个优化器,每个优化器单独负责优化对应的网络参数。
D(x) 训练的过程:将真实数据的标签定义为 1,将生成数据定义标签为 0
G(x) 训练的过程:将生成的数据的标签定义为1,和经过 D(x) 的判别结果计算损失来更新参数。
criterion = nn.BCELoss() # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)D.zero_grad()
# 1A: Train D on real
d_real_data = Variable(d_sampler(d_input_size))
d_real_decision = D(preprocess(d_real_data))
d_real_error = criterion(d_real_decision, Variable(torch.ones([1,1]))) # ones = true
d_real_error.backward() # compute/store gradients, but don't change params# 1B: Train D on fake
d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
d_fake_data = G(d_gen_input).detach() # detach to avoid training G on these labels
d_fake_decision = D(preprocess(d_fake_data.t()))
d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1]))) # zeros = fake
d_fake_error.backward()
d_optimizer.step() # Only optimizes D's parameters; changes based on stored gradients from backward()G.zero_grad()gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
g_fake_data = G(gen_input)
dg_fake_decision = D(preprocess(g_fake_data.t()))
g_error = criterion(dg_fake_decision, Variable(torch.ones([1,1]))) # Train G to pretend it's genuineg_error.backward()
g_optimizer.step() # Only optimizes G's parameters
SegAN
投稿于 Neuroinformatics (2018) 的一篇论文
通过GAN加持语义分割,能够让学习到的结果更加精细
SegAN 和 GAN 一样的两部分 loss,
通过预测的 01 mask(生成数据)和原图取交然后输入到 D(x) 中得到一个 res_mask
通过真实的 01 mask(真实数据) 和原图取交然后输入到 D(x) 中得到 target_mask
训练D(x) 的时候直接取 -mae (res_mask, target_mask) 让误差向大了学,判别能力越来越强,越来越挑剔
训练G(x) 的时候则 mae(res_mask, target_mask) 让误差减小,让学到的更像target
损失函数:
minθSmaxθCL(θS,θC)=1N∑n=1Nlmae(fC(xn⋅S(xn)),fC(xn⋅yn))\min_{\theta_S}\max_{\theta_C} L(\theta_S, \theta_C) = \frac{1}{N}\sum_{n=1}^{N}l_{mae}(f_C(x_n \cdot S(x_n)), f_C(x_n \cdot y_n))θSminθCmaxL(θS,θC)=N1n=1∑Nlmae(fC(xn⋅S(xn)),fC(xn⋅yn))
- xnx_nxn 原图
- yny_nyn 预测结果
参考
https://www.cnblogs.com/walter-xh/p/10051634.html
https://zhuanlan.zhihu.com/p/78822561?from_voters_page=true
https://blog.csdn.net/sunyao_123/article/details/80288398
简记GAN网络的loss相关推荐
- GAN网络概述及LOSS函数详解
Generative Adversarial Nets 上周周报已经写了这篇论文,本周相对GAN网络的LOSS进行进一步的学习和研究. GAN网络: 条件:G依照真实图像生成大量的类似图像,D是辨别输 ...
- GAN背后的理论依据,以及为什么只使用GAN网络容易产生
花了一下午研究的文章,解答了我关于GAN网络的很多疑问,内容的理论水平很高,只能尽量理解,但真的是一篇非常好的文章转自http://www.dataguru.cn/article-10570-1.ht ...
- 『TensorFlow』通过代码理解gan网络_中
『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...
- 不服就GAN:GAN网络生成 cifar10 的图片实例(keras 详细实现步骤),GAN 的训练的各种技巧总结,GAN的注意事项和大坑汇总
GAN 的调参技巧总结 生成器的最后一层不使用 sigmoid,使用 tanh 代替 使用噪声作为生成器的输入时,生成噪声的步骤使用 正态分布 的采样来产生,而不使用均匀分布 训练 discrimin ...
- GAN网络详解(从零入门)
从一个小白的方式理解GAN网络(生成对抗网络),可以认为是一个造假机器,造出来的东西跟真的一样,下面开始讲如何造假:(主要讲解GAN代码,代码很简单) 我们首先以造小狗的假图片为例. 首先需要一个生成 ...
- 论文阅读——TR-GAN: Topology Ranking GAN with Triplet Loss for Retinal Artery/Vein Classification
论文阅读--TR-GAN: Topology Ranking GAN with Triplet Loss for Retinal Artery/Vein Classification 基于对抗神经网络 ...
- GAN网络的模型坍塌和不稳定的分析
众所周知,GAN异常强大,同时也非常难以训练.主要有以下亮点原因: 模型坍塌(mode collapse) 难以收敛和训练不稳定(convergence and instability) GAN网络的 ...
- GAN网络生成手写体数字图片
Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的. 目前感觉keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接 ...
- 机器学习——从0开始构建自己的GAN网络
目录 一 前言 二 生成式对抗网络GAN 三 GAN的训练思路 四 数据集--Chinese MNIST 五 代码--python 1.文件展示 2.代码(一) --数据预处理 3.代码(二) --生 ...
最新文章
- Types of intraclass correlation coefficience (ICC)
- 【腾讯圣诞晚会TEG节目】这里的黎明静悄悄
- 传输层的TCP和UDP
- 是不是一个东西_迷你世界:一个金币就能买到稀有武器?这么良心的售货机在哪领...
- 卓金武——从数学建模到MATLAB
- powerpoint文字教程
- MUI 拍照和从系统相册选择图片上传
- Java—这把线程池ThreadPoolExecutor操作,你学会了吗?
- 将字符串转换为全角或半角
- C++11 处理时间和日期的处理,以及chrono库介绍
- Linux 安装Redis单机版(使用Mac远程访问)
- 4*4矩阵键盘原理分析以及代码展示
- md5加密算法~Java语言实现
- 黑客进行攻击中最重要的环节“信息收集”
- T156基于51单片机LCD12864指针时钟Proteus设计、keil程序、c语言、源码、ds1302,电子时钟,62256
- 北京等保测评机构项目测评收费价格标准参考
- HTML+CSS网页设计期末课程大作——XXXXX (X页) HTML5网页设计成品_学生DW静态网页设计_web课程设计网页制作
- emoji表情符号编码大全
- 尝试创建windows XP最长的路径名
- 免费短信九成暗藏陷阱
热门文章
- 字符串大写字符串转小写js_C ++字符串大写和小写
- 熊猫分发_熊猫cut()函数示例
- jsf tree组件_JSF UI组件标签属性示例教程
- 熊猫read_csv()–将CSV文件读取到DataFrame
- C语言和C++的区别是什么?到底学哪种好
- 认识安全测试之SQL注入
- 【iOS开发】Alamofire框架的使用二 高级用法
- You (root) are not allowed to access to (crontab) because of pam configuration
- Android无线安全测试工具-WiFinSpect
- MOQL—筛选器(Selector)(二)