WGAN-GP代码注释

本文链接:https://blog.csdn.net/qq_20943513/article/details/73129308
代码地址:https://github.com/bojone/gan/

对于这个基于tensorflow实现的代码,我对其进行了简单的注释。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import numpy as np
from scipy import misc,ndimage#读入本地的MNIST数据集,该函数为mnist专用
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)batch_size = 100 #每个batch的大小
width,height = 28,28  #每张图片包含28*28个像素点
mnist_dim = width*height #用一个数字数组表示一张图,那么这个数组展开成向量的长度就是28*28=784
random_dim = 10 #每张图表示一个数字,从0到9
epochs = 1000000  #共100万轮def my_init(size): #从[-0.05,0.05]的均匀分布中采样得到维度是size的输出return tf.random_uniform(size, -0.05, 0.05)#判别器相关参数设定
D_W1 = tf.Variable(my_init([mnist_dim, 128])) #784*128
D_b1 = tf.Variable(tf.zeros([128])) #长度为128的一维张量,值均为0
D_W2 = tf.Variable(my_init([128, 32]))
D_b2 = tf.Variable(tf.zeros([32]))
D_W3 = tf.Variable(my_init([32, 1]))
D_b3 = tf.Variable(tf.zeros([1]))
D_variables = [D_W1, D_b1, D_W2, D_b2, D_W3, D_b3]#生成器相关参数设定
G_W1 = tf.Variable(my_init([random_dim, 32]))
G_b1 = tf.Variable(tf.zeros([32]))
G_W2 = tf.Variable(my_init([32, 128]))
G_b2 = tf.Variable(tf.zeros([128]))
G_W3 = tf.Variable(my_init([128, mnist_dim]))
G_b3 = tf.Variable(tf.zeros([mnist_dim]))
G_variables = [G_W1, G_b1, G_W2, G_b2, G_W3, G_b3]#判别器网络结构
def D(X):X = tf.nn.relu(tf.matmul(X, D_W1) + D_b1) #X的维度是100*784,D_W1维度是784*128,得到结果维度为100*128X = tf.nn.relu(tf.matmul(X, D_W2) + D_b2) #X的维度是100*128,D_W2维度是128*32,得到结果维度为100*32X = tf.matmul(X, D_W3) + D_b3 #X的维度是100*32,D_W3维度是32*1,得到结果维度为100*1return X#生成器网络结构
def G(X):X = tf.nn.relu(tf.matmul(X, G_W1) + G_b1) #X的维度是100*10,G_W1维度是10*32,得到结果维度为100*32X = tf.nn.relu(tf.matmul(X, G_W2) + G_b2) #X的维度是100*32,G_W2维度是32*128,得到结果维度为100*128X = tf.nn.sigmoid(tf.matmul(X, G_W3) + G_b3) #X的维度是100*128,G_W3维度是128*784,得到结果维度为100*784return X#real_X是真实样本,random_X是噪音数据,random_Y是生成器生成的伪样本
real_X = tf.placeholder(tf.float32, shape=[batch_size, mnist_dim])
random_X = tf.placeholder(tf.float32, shape=[batch_size, random_dim])
random_Y = G(random_X)#求惩罚项,这个这个惩罚是“软约束”,最终的结果不一定满足这个约束,但是会在约束上下波动。这里Lipschitz约束的C=1
eps = tf.random_uniform([batch_size, 1], minval=0., maxval=1.) #eps是U[0,1]的随机数
X_inter = eps*real_X + (1. - eps)*random_Y  #在真实样本和生成样本之间随机插值,希望这个约束可以“布满”真实样本和生成样本之间的空间
grad = tf.gradients(D(X_inter), [X_inter])[0] #求梯度
grad_norm = tf.sqrt(tf.reduce_sum((grad)**2, axis=1)) #求梯度的二范数
grad_pen = 10 * tf.reduce_mean(tf.nn.relu(grad_norm - 1.)) #Lipschitz限制是要求判别器的梯度不超过K,这个loss项是希望判别器的梯度离K(此处K设为1)越近越好#判别器和生成器的损失函数
D_loss = tf.reduce_mean(D(real_X)) - tf.reduce_mean(D(random_Y)) + grad_pen
G_loss = tf.reduce_mean(D(random_Y))  #越接近真实样本越好#判别器和生成器的优化函数
D_solver = tf.train.AdamOptimizer(1e-4, 0.5).minimize(D_loss, var_list=D_variables)
G_solver = tf.train.AdamOptimizer(1e-4, 0.5).minimize(G_loss, var_list=G_variables)#创建对话,初始化所有变量
sess = tf.Session()
sess.run(tf.global_variables_initializer())#是否存在“out”文件夹,不存在的话新建一个,存放实验结果
if not os.path.exists('out/'):os.makedirs('out/')for e in range(epochs):for i in range(5): #每轮计算5个batchreal_batch_X,_ = mnist.train.next_batch(batch_size) #随机抓取训练数据中的100个批处理数据点random_batch_X = np.random.uniform(-1, 1, (batch_size, random_dim)) #从均匀分布中采样,输出100*10个样本_,D_loss_ = sess.run([D_solver,D_loss], feed_dict={real_X:real_batch_X, random_X:random_batch_X})random_batch_X = np.random.uniform(-1, 1, (batch_size, random_dim))_,G_loss_ = sess.run([G_solver,G_loss], feed_dict={random_X:random_batch_X})#每1000轮输出一次当前结果if e % 1000 == 0:print 'epoch %s, D_loss: %s, G_loss: %s'%(e, D_loss_, G_loss_)n_rows = 6check_imgs = sess.run(random_Y, feed_dict={random_X:random_batch_X}).reshape((batch_size, width, height))[:n_rows*n_rows] #由生成器得到伪样本,维度为100*784,reshape为100个28*28的矩阵,取6*6个矩阵构成一张图imgs = np.ones((width*n_rows+5*n_rows+5, height*n_rows+5*n_rows+5)) #203*203的值为1的二维矩阵for i in range(n_rows*n_rows):imgs[5+5*(i%n_rows)+width*(i%n_rows):5+5*(i%n_rows)+width+width*(i%n_rows), 5+5*(i/n_rows)+height*(i/n_rows):5+5*(i/n_rows)+height+height*(i/n_rows)] = check_imgs[i]misc.imsave('out/%s.png'%(e/1000), imgs)

WGAN-GP代码注释相关推荐

  1. FFA-Net:文章理解与代码注释

    FFA-Net: Feature Fusion Attention Network for Single Image Dehazing (AAAI 2020) Pytorch代码(GitHub) 本文 ...

  2. FFA-Net:文章理解于代码注释

    转载自:https://blog.csdn.net/weixin_46773169/article/details/105462644,本文只做个人记录学习使用,版权归原作者所有. github链接: ...

  3. 归并排序(代码注释超详细)

    归并排序: (复制粘贴百度百科没什么意思),简单来说,就是对数组进行分组,然后分组进行排序,排序完最后再整合起来排序! 我看了很多博客,都是写的8个数据呀什么的(2^4,分组方便),我就想着,要是10 ...

  4. 代码注释//_您应该停止编写//的五个代码注释,并且//应该开始的一个注释

    代码注释// 提供来自您最喜欢和最受欢迎的开源项目的示例-React,Angular,PHP,Pandas等! (With examples from your favorite and most p ...

  5. tensorflow笔记:流程,概念和简单代码注释

    tensorflow是google在2015年开源的深度学习框架,可以很方便的检验算法效果.这两天看了看官方的tutorial,极客学院的文档,以及综合tensorflow的源码,把自己的心得整理了一 ...

  6. yolov3网络结构图_目标检测——YOLO V3简介及代码注释(附github代码——已跑通)...

    GitHub: liuyuemaicha/PyTorch-YOLOv3​github.com 注:该代码fork自eriklindernoren/PyTorch-YOLOv3,该代码相比master分 ...

  7. Kotlin------函数和代码注释

    定义函数 Kotlin定义一个函数的风格大致如下 访问控制符 fun 方法名(参数,参数,参数) : 返回值类型{...... } 访问控制符:与Java有点差异,Kotlin的访问范围从大到小分别是 ...

  8. php代码注释处理类库,php代码注释

    代码注释在多人开发的时候非常重要,现象一下,一段代码没有任何主要你去结合运行的效果去看实现的逻辑,那是非常费劲的事. 如果让别人看懂你写的代码,代码注释启动非常重要的作用.一个不会写代码注释的不是一个 ...

  9. 竟有如此沙雕的代码注释!

    点击上方蓝色"程序猿DD",选择"设为星标" 回复"资源"获取独家整理的学习资料! 某站后端代码被"开源",同时刷遍全网 ...

  10. java的注释规范_Java代码注释规范

    1,单行(单行)-简短说明: ///... 单行注释: 代码中的单行注释. 最好在注释前有一个空行,并在其后加上与代码相同的缩进级别. 如果无法完成一行,则应使用块注释. 评论格式: 在行首注释: 在 ...

最新文章

  1. 灯三段调光原理_球泡灯中国能效标识怎么做,GB30255中国能效报告办理要求
  2. httpinvoker远程调用超时_RPC远程过程调用协议工作原理分析
  3. python 重置索引_Pandas的reset_index()重置索引列
  4. 他复读才考上三本,如今让华为开出201万年薪(其实还拒绝了360万offer)
  5. 王成录华为鸿蒙系统,华为手机销量仍在增长!华为王成录:手机会是鸿蒙OS系统的中心...
  6. 2014年自动化的个人感想
  7. JavaScript基本语法2
  8. vos3000 2.1.1.5 安装包及注册机【电销电话机器人源码私有云部署 www.ruikesoft.com 正版授权 抵制盗版】
  9. 计算机系统结构试卷填空,计算机系统结构试卷
  10. Radasm 配置goasm
  11. 2020年大学生编程比赛---ACM、蓝桥杯、天梯赛
  12. 这款软件有多“硬” ——从国内首款基于云架构的三维CAD平台CrownCAD说起
  13. nginx代理百度地图,实现内网展示百度地图
  14. SyntaxError: Non-ASCII character ‘\xe7‘ in file F:/python_code/test/venv/Shan.py on line 7,
  15. 单片机c语言课后题答案,单片机原理及应用(C语言版)习题答案.doc
  16. java对word、Excel、PPT、PDF文件加密
  17. 【流媒体性能测试常用指标】
  18. 小白普及:云主机与传统服务器的区别
  19. java 内联_Java内联类初探
  20. Photoshop如何修改图片的颜色

热门文章

  1. 【0ms优化】剑指 Offer 18. 删除链表的节点
  2. 27行代码AC_迷宫 2017年第八届蓝桥杯A组第一题(暴力、仿迷宫)
  3. java枚举新特性_java回顾之枚举和新特性
  4. gmod的css模块放哪,gmod模式怎么更换?gmod模块安装步骤教程
  5. pythonalert弹窗_python+selenium八:Alert弹窗
  6. python插件使用教程_Python常用扩展插件使用教程解析
  7. oracle收集统计信息sql,Oracle自动统计信息的收集原理及实验
  8. python 多继承 __new___Python3中的__new__方法以及继承不可变类型类的问题
  9. java threadstatus_Thread之一:线程生命周期及六种状态
  10. python调用函数示例_python 动态调用函数实例解析