文章目录

  • GAN的原理图:
  • GAN的原版算法描述:
  • pytorch实现
    • 构建generator和discriminator:
    • 生成fake data:
    • 生成real data:
    • 定义训练D的loss,定义训练G的loss, 实际就是forward pass:
    • 优化过程, 实际就是backward pass:
      • 为了实现fixedGtrainD,fixedDtrainG,我们设计优化器更新指定区域的参数:
      • fixedG trainD:
      • fixedDtrainG:
    • 完整版训练过程:

GAN的原理图:

GAN的原版算法描述:

pytorch实现

构建generator和discriminator:

G = nn.Sequential(                      # Generatornn.Linear(N_IDEAS, 128),            # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)D = nn.Sequential(                      # Discriminatornn.Linear(ART_COMPONENTS, 128),     # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(),                       # tell the probability that the art work is made by artist
)

生成fake data:

    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideasG_paintings = G(G_ideas)                    # fake painting from G (random ideas)

生成real data:


def artist_works():     # painting from the famous artist (real target)a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]paintings = a * np.power(PAINT_POINTS, 2) + (a-1)paintings = torch.from_numpy(paintings).float()return paintingsartist_paintings = artist_works()           # real painting from artist

定义训练D的loss,定义训练G的loss, 实际就是forward pass:

这个loss就相当于把G和D连接起来了,形成通路了,这里实际上体现了pytorch动态图的思想。

    prob_artist0 = D(artist_paintings)          # D try to increase this probprob_artist1 = D(G_paintings)               # D try to reduce this probD_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))

优化过程, 实际就是backward pass:

为了实现fixedGtrainD,fixedDtrainG,我们设计优化器更新指定区域的参数:

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

fixedG trainD:

opt_D.zero_grad():需要先初始化opt_D,避免前面的数据影响当前更新。
D_loss.backward(retain_graph=True) :计算整个graph梯度,retain_graph=True,需要保持计算图,啥意思?pytorch默认计算一次backward就释放当前graph,释放了就是你必须从头开始走forward pass ,而这里我们需要重新走一遍原图的D部分。
opt_D.step():根据梯度更新指定区域的参数

    opt_D.zero_grad()D_loss.backward(retain_graph=True)      # reusing computational graphopt_D.step()

fixedDtrainG:

G_loss.backward():G_loss = torch.mean(torch.log(1. - prob_artist1)),prob_artist1 = D(G_paintings) 可以看出我们需要重新走一遍D(G不需要走),这个是在原来graph上操作的,这就是为什么需要retain graph的原因

    opt_G.zero_grad()G_loss.backward()opt_G.step()

完整版训练过程:

def artist_works():     # painting from the famous artist (real target)a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]paintings = a * np.power(PAINT_POINTS, 2) + (a-1)paintings = torch.from_numpy(paintings).float()return paintingsG = nn.Sequential(                      # Generatornn.Linear(N_IDEAS, 128),            # random ideas (could from normal distribution)nn.ReLU(),nn.Linear(128, ART_COMPONENTS),     # making a painting from these random ideas
)D = nn.Sequential(                      # Discriminatornn.Linear(ART_COMPONENTS, 128),     # receive art work either from the famous artist or a newbie like Gnn.ReLU(),nn.Linear(128, 1),nn.Sigmoid(),                       # tell the probability that the art work is made by artist
)opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)plt.ion()   # something about continuous plottingfor step in range(10000):artist_paintings = artist_works()           # real painting from artistG_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideasG_paintings = G(G_ideas)                    # fake painting from G (random ideas)prob_artist0 = D(artist_paintings)          # D try to increase this probprob_artist1 = D(G_paintings)               # D try to reduce this probD_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))G_loss = torch.mean(torch.log(1. - prob_artist1))opt_D.zero_grad()D_loss.backward(retain_graph=True)      # reusing computational graphopt_D.step()opt_G.zero_grad()G_loss.backward()opt_G.step()

深度学习总结:GAN,原理,算法描述,pytoch实现相关推荐

  1. 深度学习Anchor Boxes原理与实战技术

    深度学习Anchor Boxes原理与实战技术 目标检测算法通常对输入图像中的大量区域进行采样,判断这些区域是否包含感兴趣的目标,并调整这些区域的边缘,以便更准确地预测目标的地面真实边界框.不同的模型 ...

  2. 深度学习attention原理_深度学习Anchor Boxes原理与实战技术

    深度学习Anchor Boxes原理与实战技术 目标检测算法通常对输入图像中的大量区域进行采样,判断这些区域是否包含感兴趣的目标,并调整这些区域的边缘,以便更准确地预测目标的地面真实边界框.不同的模型 ...

  3. 深度学习word2vec笔记之算法篇

    本文转载自<深度学习word2vec笔记之算法篇>对排版和内容作了部分调整,感谢大佬分享. PDF版本关注微信公众号:[终南樵],回复:[word2vec基础]获取 1. 声明 该博文是G ...

  4. 学习笔记之——基于深度学习的目标检测算法

    国庆假期闲来无事~又正好打算入门基于深度学习的视觉检测领域,就利用这个时间来写一份学习的博文~本博文主要是本人的学习笔记与调研报告(不涉及商业用途),博文的部分来自我团队的几位成员的调研报告(由于隐私 ...

  5. 基于深度学习的人脸识别算法

    基于深度学习的人脸识别算法 简介 Contrastive Loss Triplet Loss Center Loss A-Softmax Loss 参考文献: 简介 我们经常能从电影中看到各种神奇的人 ...

  6. 病虫害模型算法_基于深度学习的目标检测算法综述

    sigai 基于深度学习的目标检测算法综述 导言 目标检测的任务是找出图像中所有感兴趣的目标(物体),确定它们的位置和大小,是机器视觉领域的核心问题之一.由于各类物体有不同的外观,形状,姿态,加上成像 ...

  7. 综述 | 基于深度学习的目标检测算法

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:计算机视觉life 导读:目标检测(Object Det ...

  8. 深度学习之反向传播算法

    深度学习之反向传播算法 直观理解反向传播 反向传播算法是用来求那个复杂到爆的梯度的. 上一集中提到一点,13000维的梯度向量是难以想象的.换个思路,梯度向量每一项的大小,是在说代价函数对每个参数有多 ...

  9. 2.2)深度学习笔记:优化算法

    目录 1)Mini-batch gradient descent(重点) 2)Understanding mini-batch gradient descent 3)Exponentially wei ...

  10. 基于深度学习的场景分割算法研究综述

    基于深度学习的场景分割算法研究综述 人工智能技术与咨询 来自<计算机研究与发展> ,作者张 蕊等 摘 要 场景分割的目标是判断场景图像中每个像素的类别.场景分割是计算机视觉领域重要的基本问 ...

最新文章

  1. Powershell管理系列(八)Exchange 2013通讯组管理
  2. 搭建SVN版本控制服务器
  3. DevExpress控件之GridControl、GridView
  4. 计算机蠕虫的存在形式,计算机蠕虫
  5. HDU-1540 Tunnel Warfare 线段树最大连续区间 或 STL巧解
  6. python中依次输出字符_Python如何输出某关键字符并输出完整字符串
  7. Android6.0权限适配及兼容库的实现
  8. ubuntu 双击打不开软件或者创建的快捷方式
  9. 软件项目测试报价单,某软件项目报价单
  10. 利用计算机进行频数分布表制作,实验三 利用Excel软件作频数分布表和统计图表...
  11. Cisco ❀ QinQ技术与VXLAN技术的区别
  12. win10user文件夹迁移_Win10纯净版下迁移用户文件的技巧
  13. 记一次从 git pull 出现 Untracked FilesPervent Merge
  14. 《C Primer Plus第六版》第六章复习题目和编程练习题的答案
  15. 智慧社区+物联网解决方案
  16. 计算机培训考试内容,计算机等级考试的科目和内容解析
  17. sja1000 CAN驱动学习、调试记录(基于PeliCan Mode)
  18. 》技术应用:大数据产品体系
  19. 长调用与短调用 调用门
  20. Electron 使用Pepper Flash插件

热门文章

  1. oracle 11g 清除 trc后缀文件,请教一个跟踪文件的问题。11g 很多trc文件。。
  2. java8 collect 类型转换_Java 8 新特性 Stream类的collect方法
  3. Windows环境配置Anaconda+cuda+cuDNN+pytorch+jupyter notebook
  4. python中map函数字典映射_python Chainmap函数(19)
  5. 泛型方法的定义和使用_泛型( Generic )
  6. 怎样才能让Android平板不卡,如何让你的安卓平板从获新生
  7. php的优化模块,php memcache模块优化配置详解
  8. bagging和时间序列预测_时间序列的LSTM模型预测——基于Keras
  9. php js 正则表达式,【PHP】用正则表达式过滤js代码(注意这个分析过程)
  10. java assetmanager_AssetManager asset的使用