今天看到paperweekly上有人分享了一个WGAN-GP的实现,是以MNIST为数据集,代码简洁,结构清晰。我最近也在看GAN的相关内容,就下载下来做个参考。 
代码地址:https://github.com/bojone/gan/

对于这个基于tensorflow实现的代码,我对其进行了简单的注释

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import numpy as np
from scipy import misc,ndimage#读入本地的MNIST数据集,该函数为mnist专用
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)batch_size = 100 #每个batch的大小
width,height = 28,28  #每张图片包含28*28个像素点
mnist_dim = width*height #用一个数字数组表示一张图,那么这个数组展开成向量的长度就是28*28=784
random_dim = 10 #每张图表示一个数字,从0到9
epochs = 1000000  #共100万轮def my_init(size): #从[-0.05,0.05]的均匀分布中采样得到维度是size的输出return tf.random_uniform(size, -0.05, 0.05)#判别器相关参数设定
D_W1 = tf.Variable(my_init([mnist_dim, 128])) #784*128
D_b1 = tf.Variable(tf.zeros([128])) #长度为128的一维张量,值均为0
D_W2 = tf.Variable(my_init([128, 32]))
D_b2 = tf.Variable(tf.zeros([32]))
D_W3 = tf.Variable(my_init([32, 1]))
D_b3 = tf.Variable(tf.zeros([1]))
D_variables = [D_W1, D_b1, D_W2, D_b2, D_W3, D_b3]#生成器相关参数设定
G_W1 = tf.Variable(my_init([random_dim, 32]))
G_b1 = tf.Variable(tf.zeros([32]))
G_W2 = tf.Variable(my_init([32, 128]))
G_b2 = tf.Variable(tf.zeros([128]))
G_W3 = tf.Variable(my_init([128, mnist_dim]))
G_b3 = tf.Variable(tf.zeros([mnist_dim]))
G_variables = [G_W1, G_b1, G_W2, G_b2, G_W3, G_b3]#判别器网络结构
def D(X):X = tf.nn.relu(tf.matmul(X, D_W1) + D_b1) #X的维度是100*784,D_W1维度是784*128,得到结果维度为100*128X = tf.nn.relu(tf.matmul(X, D_W2) + D_b2) #X的维度是100*128,D_W2维度是128*32,得到结果维度为100*32X = tf.matmul(X, D_W3) + D_b3 #X的维度是100*32,D_W3维度是32*1,得到结果维度为100*1return X#生成器网络结构
def G(X):X = tf.nn.relu(tf.matmul(X, G_W1) + G_b1) #X的维度是100*10,G_W1维度是10*32,得到结果维度为100*32X = tf.nn.relu(tf.matmul(X, G_W2) + G_b2) #X的维度是100*32,G_W2维度是32*128,得到结果维度为100*128X = tf.nn.sigmoid(tf.matmul(X, G_W3) + G_b3) #X的维度是100*128,G_W3维度是128*784,得到结果维度为100*784return X#real_X是真实样本,random_X是噪音数据,random_Y是生成器生成的伪样本
real_X = tf.placeholder(tf.float32, shape=[batch_size, mnist_dim])
random_X = tf.placeholder(tf.float32, shape=[batch_size, random_dim])
random_Y = G(random_X)#求惩罚项,这个这个惩罚是“软约束”,最终的结果不一定满足这个约束,但是会在约束上下波动。这里Lipschitz约束的C=1
eps = tf.random_uniform([batch_size, 1], minval=0., maxval=1.) #eps是U[0,1]的随机数
X_inter = eps*real_X + (1. - eps)*random_Y  #在真实样本和生成样本之间随机插值,希望这个约束可以“布满”真实样本和生成样本之间的空间
grad = tf.gradients(D(X_inter), [X_inter])[0] #求梯度
grad_norm = tf.sqrt(tf.reduce_sum((grad)**2, axis=1)) #求梯度的二范数
grad_pen = 10 * tf.reduce_mean(tf.nn.relu(grad_norm - 1.)) #Lipschitz限制是要求判别器的梯度不超过K,这个loss项是希望判别器的梯度离K(此处K设为1)越近越好#判别器和生成器的损失函数
D_loss = tf.reduce_mean(D(real_X)) - tf.reduce_mean(D(random_Y)) + grad_pen
G_loss = tf.reduce_mean(D(random_Y))  #越接近真实样本越好#判别器和生成器的优化函数
D_solver = tf.train.AdamOptimizer(1e-4, 0.5).minimize(D_loss, var_list=D_variables)
G_solver = tf.train.AdamOptimizer(1e-4, 0.5).minimize(G_loss, var_list=G_variables)#创建对话,初始化所有变量
sess = tf.Session()
sess.run(tf.global_variables_initializer())#是否存在“out”文件夹,不存在的话新建一个,存放实验结果
if not os.path.exists('out/'):os.makedirs('out/')for e in range(epochs):for i in range(5): #每轮计算5个batchreal_batch_X,_ = mnist.train.next_batch(batch_size) #随机抓取训练数据中的100个批处理数据点random_batch_X = np.random.uniform(-1, 1, (batch_size, random_dim)) #从均匀分布中采样,输出100*10个样本_,D_loss_ = sess.run([D_solver,D_loss], feed_dict={real_X:real_batch_X, random_X:random_batch_X})random_batch_X = np.random.uniform(-1, 1, (batch_size, random_dim))_,G_loss_ = sess.run([G_solver,G_loss], feed_dict={random_X:random_batch_X})#每1000轮输出一次当前结果if e % 1000 == 0:print 'epoch %s, D_loss: %s, G_loss: %s'%(e, D_loss_, G_loss_)n_rows = 6check_imgs = sess.run(random_Y, feed_dict={random_X:random_batch_X}).reshape((batch_size, width, height))[:n_rows*n_rows] #由生成器得到伪样本,维度为100*784,reshape为100个28*28的矩阵,取6*6个矩阵构成一张图imgs = np.ones((width*n_rows+5*n_rows+5, height*n_rows+5*n_rows+5)) #203*203的值为1的二维矩阵for i in range(n_rows*n_rows):imgs[5+5*(i%n_rows)+width*(i%n_rows):5+5*(i%n_rows)+width+width*(i%n_rows), 5+5*(i/n_rows)+height*(i/n_rows):5+5*(i/n_rows)+height+height*(i/n_rows)] = check_imgs[i]misc.imsave('out/%s.png'%(e/1000), imgs)

转自 https://blog.csdn.net/qq_20943513/article/details/73129308

WGAN-GP 学习笔记相关推荐

  1. C++ STL学习笔记(3) 分配器Allocator,OOP, GP简单介绍

    继续学习侯捷老师的课程! 在前面的博客<C++ STL学习笔记(2) 容器结构与分类>中介绍了STL中常用到的容器以及他们的使用方法,在我们使用容器的时候,背后需要一个东西支持对内存的使用 ...

  2. OP-TEE内核学习笔记(一)(安全存储)—— 安全存储 GP API

    存储文件基础操作 一. 安全存储GP API 1.1 `TEE_CreatePersistentObject` 1.2 `TEE_CloseAndDeletePersistentObject` 1.3 ...

  3. 《繁凡的深度学习笔记》前言、目录大纲 一文让你完全弄懂深度学习所有基础(DL笔记整理系列)

    <繁凡的深度学习笔记>前言.目录大纲 (DL笔记整理系列) 一文弄懂深度学习所有基础 ! 3043331995@qq.com https://fanfansann.blog.csdn.ne ...

  4. 【学习笔记】超简单的快速数论变换(NTT)(FFT的优化)(含全套证明)

    整理的算法模板合集: ACM模板 目录 一.前置知识 二.快速数论变换(NTT) 三.NTT证明(和FFT的关系) 四.NTT模板 数组形式的实现 vector形式的实现 点我看多项式全家桶(●^◡_ ...

  5. main 函数解析(二)—— Linux-0.11 学习笔记(六)

    main函数解析(二)--Linux-0.11 学习笔记(六) 4.6 blk_dev_init函数 void blk_dev_init(void) {int i;for (i=0 ; i<NR ...

  6. 《多元统计分析》学习笔记之聚类分析

    鄙人学习笔记 PS:对不起,原本想简单写写,总结一下,不想截那么多图,但写着写着觉得都挺想写的,就越写越多,越截越多.... 文章目录 聚类分析 聚类分析的基本思想 相似性度量 类和类的特征 系统聚类 ...

  7. RN学习笔记01:概述、特点与环境搭建

    RN学习笔记01:概述.特点与环境搭建 一.RN概述 React Native(简称RN)是Facebook于2015年4月开源的跨平台移动应用开发框架,是Facebook早先开源的JS框架 Reac ...

  8. C++STL学习笔记(4) 分配器(Allocator)

    在前面的博客<C++ STL学习笔记(3) 分配器Allocator,OOP, GP简单介绍>中,简单的介绍了分配器再STL的容器中所担当的角色,这一节对STL六大部件之一的分配器进行详细 ...

  9. cs224w(图机器学习)2021冬季课程学习笔记16 Community Detection in Networks

    诸神缄默不语-个人CSDN博文目录 cs224w(图机器学习)2021冬季课程学习笔记集合 文章目录 1. Community Detection in Networks 2. Network Com ...

  10. 【操作系统】CSAPP学习笔记

    CSAPP学习笔记 前言 在阅读本书前,最好先了解一下书本的结构,然后根据结构,网上查查网评.最好能找到一些最佳阅读技巧.可以给自己定一个大一点的目标,比如,期望读完这本书,可以自己设计一个操作系统. ...

最新文章

  1. linux查看和修改PATH环境变量的方法
  2. λ表达式_Java 8新特性:学习如何使用Lambda表达式,一看必懂
  3. Vue.js响应式原理
  4. Nginx----实现https站点
  5. html 指定对象为块元素,html内联(行内)元素、块级(块状)元素和行内块元素分类...
  6. 怎么引jsp包_电机引接线的制作流程防护等级
  7. python写入mysql数据库_python调用http接口,数据写入mysql数据库并下载录音文件
  8. 计算机英语小短文单词易懂,求计算机英语短文译文。。。。急急急!悬赏10
  9. java对集合的操作_Java中对List集合的常用操作
  10. Linux杀100个进程,在linux bash中杀死一个进程子树
  11. HTML和CSS面试题—整理过的48题,关注收藏,持续更新
  12. 传奇私服网站php源码,传奇h5私服源码+教程
  13. 风口的猪-中国牛市(动态规划)
  14. 杨建:网站加速--实例分析篇
  15. 技嘉b365m小雕驱动工具_【黑苹果】技嘉B365M小雕+i5 9400F+RX590EFI分享
  16. Android手势密码探索
  17. 自定义View-仿QQ运动步数进度效果(完整代码)
  18. python获取每月的最后一天
  19. 医学图像——CT值(Hu值)
  20. 中国有句俗语叫“三天打鱼两天晒网”。某人从2010年1月1日起开始“三天打鱼两天晒网”, 问这个人在以后的某一天中是“打鱼”还是“晒网”。用C或C++语言/java/python实现程序解决问题

热门文章

  1. zoj 3204 Connect them kruskal
  2. Leecode31. 下一个排列——Leecode大厂热题100道系列
  3. 【一起去大厂系列】深入理解MySQL中where 1 = 1的用处
  4. 【已解决】[Error] reference to ‘min‘ is ambiguous
  5. mysql多个on_在多个查询中插入多行的MySQL ON DUPLICATE KEY UPDATE
  6. Linux安装及管理程序——RPM和yum学会装软件超简单
  7. spring cloud 熔断_Spring Cloud 熔断器/断路器 Hystrix
  8. C语言面试题分享(6)
  9. jxl生成表格(合并单元格,字体,样式)
  10. 计算机 留学推荐信,计算机专业留学推荐信范文