之前就对GAN这项技术很感兴趣,可是后面一直没有找到时间研究一下,今天找来了一个很不错的例子学习实践了一下,简单来记录一下自己的实践,具体的代码如下:

#!usr/bin/env python
#encoding:utf-8
from __future__ import division'''
__Author__:沂水寒城
功能: 基于GAN的手写数字生成实践
'''import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tensorflow.examples.tutorials.mnist import input_data#设置基本的参数信息
mb_size = 32
X_dim = 784
z_dim = 64
h_dim = 128
lr = 1e-3
m = 5
lam = 1e-3
gamma = 0.5
k_curr = 0
if not os.path.exists('result/'):os.makedirs('result/')
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)def numberPloter(samples):'''数字图像绘制'''figure = plt.figure(figsize=(8, 8))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return figuredef xavier_init(size):'''初始化'''in_dim = size[0]xavier_stddev = 1. / tf.sqrt(in_dim / 2.)return tf.random_normal(shape=size, stddev=xavier_stddev)X = tf.placeholder(tf.float32, shape=[None, X_dim])
z = tf.placeholder(tf.float32, shape=[None, z_dim])
k = tf.placeholder(tf.float32)
D_W1 = tf.Variable(xavier_init([X_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
D_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
D_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
G_W1 = tf.Variable(xavier_init([z_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
theta_G = [G_W1, G_W2, G_b1, G_b2]
theta_D = [D_W1, D_W2, D_b1, D_b2]def sample_z(m, n):'''随机数'''return np.random.uniform(-1., 1., size=[m, n])def G(z):'''定义两个网络'''G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)G_log_prob = tf.matmul(G_h1, G_W2) + G_b2G_prob = tf.nn.sigmoid(G_log_prob)return G_probdef D(X):'''定义两个网络'''D_h1 = tf.nn.relu(tf.matmul(X, D_W1) + D_b1)X_recon = tf.matmul(D_h1, D_W2) + D_b2return tf.reduce_mean(tf.reduce_sum((X - X_recon)**2, 1))# 计算损失
G_sample = G(z)
D_real = D(X)
D_fake = D(G_sample)
D_loss = D_real - k*D_fake
G_loss = D_fake
D_solver=(tf.train.AdamOptimizer(learning_rate=lr).minimize(D_loss, var_list=theta_D))
G_solver=(tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss, var_list=theta_G))
sess = tf.Session()
sess.run(tf.global_variables_initializer())# 迭代计算一百万次,每1000次绘制一张图片
num = 0
for it in range(1000000):X_mb, _ = mnist.train.next_batch(mb_size)_, D_real_curr = sess.run([D_solver, D_real],feed_dict={X: X_mb, z: sample_z(mb_size, z_dim), k: k_curr})_, D_fake_curr = sess.run([G_solver, D_fake],feed_dict={X: X_mb, z: sample_z(mb_size, z_dim)})k_curr = k_curr + lam * (gamma*D_real_curr - D_fake_curr)if it % 1000 == 0:measure = D_real_curr + np.abs(gamma*D_real_curr - D_fake_curr)print('Iter-{}; Convergence measure: {:.4}'.format(it, measure))samples = sess.run(G_sample, feed_dict={z: sample_z(16, z_dim)})fig = plot(samples)plt.savefig('result/{}.png'.format(str(num).zfill(3)), bbox_inches='tight')num += 1plt.close(fig)

这是一个很简单实用的例子,基于GAN来生成手写数字,关于各部分的代码作用,我在具体的代码里面已经加入了相应的注释,下面我们来简单看一下输出的结果:

。。。。。。。。。。。。。。。。。。

上面是展示了1000张图片的前100张,和后面将近100张左右的结果缩略图,这里给出来第一张和最后一张:

第一张:

最后一张:

之后找时间继续学习,欢迎交流!

基于GAN的手写数字生成实践相关推荐

  1. 深度学习之基于GAN实现手写数字生成

    在弄毕设的时候,室友的毕设是基于DCGAN实现音乐的自动生成.那是第一次接触对抗神经网络,当时听室友的描述就是两个CNN,一个生成一个监测,在互相博弈. 最近我关注的一个大神在弄有关于GAN的东西,所 ...

  2. 深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天

    文章目录 一.前期工作 1. 设置GPU 2. 定义训练参数 二.什么是生成对抗网络 1. 简单介绍 2. 应用领域 三.网络结构 四.构建生成器 五.构建鉴别器 六.训练模型 1. 保存样例图片 2 ...

  3. 生成对抗网络(GAN)——MNIST手写数字生成

    前言 正文 一.什么是GAN 二.GAN的应用 三.GAN的网络模型 对抗生成手写数字 一.引入必要的库 一.引入必要的库 二.进行准备工作 三.定义生成器和判别器模型 四.设置损失函数和优化器,以及 ...

  4. 深度学习之基于DCGAN实现手写数字生成

    该篇文章与上篇文章内容相差不多,但是主要的网络结构不同,上篇文章采用的是GAN网络结构,而这篇文章采用的是DCGAN网络结构.两者的差异在于以下几点: (1)使用卷积和去卷积代替池化层. (2)在生成 ...

  5. 机器学习算法(九): 基于线性判别LDA模型的分类(基于LDA手写数字分类实践)

    机器学习算法(九): 基于线性判别模型的分类 1.前言:LDA算法简介和应用 1.1.算法简介 线性判别模型(LDA)在模式识别领域(比如人脸识别等图形图像识别领域)中有非常广泛的应用.LDA是一种监 ...

  6. 基于CNN的手写数字识别

    基于CNN的手写数字识别 文章目录 基于CNN的手写数字识别 零. 写在之前 壹. 聊聊CNN 01. 什么是CNN 02. 为什么要有CNN 03. CNN模型 3.1 卷积层 3.2 池化层 3. ...

  7. DL之RBM:基于RBM实现手写数字图片识别提高准确率

    DL之RBM:基于RBM实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 import numpy as np import matplotlib.pyplot as pl ...

  8. TF之LiR:基于tensorflow实现手写数字图片识别准确率

    TF之LiR:基于tensorflow实现手写数字图片识别准确率 目录 输出结果 代码设计 输出结果 Extracting MNIST_data\train-images-idx3-ubyte.gz ...

  9. 基于tensorflow的手写数字识别

    基于tensorflow的手写数字识别 数据准备 引入包 加载数据 查看数据信息 查看一张图片 数据预处理 搭建网络模型 模型的预测与评价 模型的展示 对一张图片进行预测 准确率 数据准备 引入包 i ...

  10. ML之K-means:基于(完整的)手写数字图片识别数据集利用K-means算法实现图片聚类

    ML之K-means:基于(完整的)手写数字图片识别数据集利用K-means算法实现图片聚类 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 metrics.adjusted_ran ...

最新文章

  1. php 上传乱码_如何解决php文件上传中文乱码问题
  2. sap-通过定义物料组的评估类-设置无物料号的费用采购
  3. oracle ebs技术开发,Oracle EBS应用架构技术方案.pdf
  4. C++对象内存布局--③测试多继承中派生类的虚函数在哪一张虚函数表中
  5. SQL with(nolock)详解
  6. 不同国家的视力表也不一样!| 今日趣图
  7. Java中Integer.parseInt()用法
  8. android文件系统只读,android
  9. 月光宝盒游戏机MAME街机模拟器方案源码项目解析----米饭模拟器(2)
  10. html中写switch,switch语句使用
  11. WPS公式编辑器快捷键
  12. 谷歌学术打不开的解决办法
  13. 树莓派做网络代理_【树莓派】设置代理服务器联网
  14. Java编程:将具有父子关系的数据库表数据转换为树形结构,支持无限层级
  15. vivo市场API事件上报对接
  16. 【Chrome Extensions】实现一个可以下载图片的Chrome插件
  17. 重庆思庄-[Oracle] SYSAUX表空间WRH$表的清理
  18. 专升本高数——第九章 无穷级数【学习笔记】
  19. 地理探测器的下载和使用
  20. OSG( OpenSceneGraphic)

热门文章

  1. MyRocks之备份恢复
  2. Lambda表达式实现有限状态机
  3. Extra Credits: Project Ten Dollar 10
  4. [导入]网络安全工作者的必杀技
  5. SpringCloud学习(SPRINGCLOUD微服务实战)一
  6. 基本数据类型与引用数据类型
  7. 在做简单网页时,遇到的一些js问题
  8. c语言数组的概念和指针的加减使用
  9. 值得推荐的C/C++框架和库(转)
  10. [android开发IDE]adt-bundle-windows-x86的一个bug:无法解析.rs文件--------rs_core.rsh file not found...