【注1】代码的原文来自以下网址,修改部分及增添注释(基本上都注释了)。修改版整体见最后,原版下方链接,均可以跑通,有问题欢迎交流。生成对抗网络GAN---生成mnist手写数字图像示例(
附代码)_陶将的博客-CSDN博客_gan生成手写数字

【注2】环境要求:≥python3.8,Windows10,pycharm2019,tensorflow2.70

如果是tensorflow版本问题可以考虑升级或使用原代码

一.调用库

import tensorflow as tf
tf.compat.v1.disable_eager_execution()import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot as plt

这里由于2.x版本的语法进行了修改,导致原文中部分代码无法运行,compat.v1使得可以让2.x直接运行1.x版本的代码。

【注3】正文的很多部分例如layers,dense等都需要在tf和包中间加入compat.v1

【注4】examples包我是自己下载的,csdn上已经有大佬上传了,记得找一个日期前一点的。

二.初始化准备

# 初始化准备
BATCH_SIZE = 64   # 每一轮训练的数量
UNITS_SIZE = 128  # 生成器隐藏层的参数
LEARNING_RATE = 0.001  # 学习速率
EPOCH = 300            # 训练迭代轮数
SMOOTH = 0.1           # 标签平滑
# 读入mnist数据,理论上1.9版本之后数据集不会自动下载,但是这个代码运行的时候是会下载出数据集的。
mnist = input_data.read_data_sets('/mnist_data/', one_hot=True)

【注5】参数可以随自己修改,本地cpu的要是显卡太垃圾(和我一样的话)建议找云服务器

三.生成器代码详解

# 生成器
def generatorModel(noise_img, units_size, out_size, alpha=0.01):# 参数解析# noise_img:生成器生成噪声图片# units_size: 隐藏层单元数# out_size:生成器输出图片大小# alpha:激活函数的系数with tf.compat.v1.variable_scope('generator'):# 创建一个空间generator,使得在这个空间当中,变量可以重复使用# 全连接,连接输入和隐藏层FC = tf.compat.v1.layers.dense(noise_img, units_size)# 隐藏层的激活函数,之后的dropout方法是为了防止发生过拟合的现象reLu = tf.nn.leaky_relu(FC, alpha)drop = tf.compat.v1.layers.dropout(reLu, rate=0.2)# 全连接,连接隐藏层和输出层,输出层的激活函数选择tanhlogits = tf.compat.v1.layers.dense(drop, out_size)outputs = tf.tanh(logits)return logits, outputs

四.判别器代码详解

# 判别模型
def discriminatorModel(images, units_size, alpha=0.01, reuse=False):# 参数详解# images:真实图片# reuse:是否重复占用空间with tf.compat.v1.variable_scope('discriminator', reuse=reuse):# 全连接FC = tf.compat.v1.layers.dense(images, units_size)# 隐藏层激活函数reLu = tf.nn.leaky_relu(FC, alpha)# 全连接,这里输出层的激活函数改为sigmoidlogits = tf.compat.v1.layers.dense(reLu, 1)outputs = tf.sigmoid(logits)return logits, outputs

【注6】这里可以看出,判别器和生成器的主要差别在于输出层的激活函数

五.损失函数代码详解

def loss_function(real_logits, fake_logits, smooth):# 生成器希望判别器判别出来的标签为1; tf.ones_like()创建一个将所有元素都设置为1的张量G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,labels=tf.ones_like(fake_logits) * (1 - smooth)))# 判别器识别生成器产出的图片,希望识别出来的标签为0fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,labels=tf.zeros_like(fake_logits)))# 判别器判别真实图片,希望判别出来的标签为1real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits,labels=tf.ones_like(real_logits) * (1 - smooth)))# 判别器总lossD_loss = tf.add(fake_loss, real_loss)return G_loss, fake_loss, real_loss, D_loss

【注7】

tf.nn.sigmoid_cross_entropy_with_logits

这个方法对传入的参数先使用sigmoid进行计算,然后在计算他们的交叉熵损失,使得结果不会溢出。

六.优化器代码详解

# 优化器
def optimizer(G_loss, D_loss, learning_rate):# 首先获取网络结构中的参数,也就是判别器和生成器的变量,在后面的最小化损失时修改train_var = tf.compat.v1.trainable_variables()G_var = [var for var in train_var if var.name.startswith('generator')]D_var = [var for var in train_var if var.name.startswith('discriminator')]# 因为GAN中一共训练了两个网络,所以分别对G和D进行优化# 这里使用AdamOptimizer方法来减少损失(娘希匹的2.x这玩意怎么直接用),动态调整每个参数的学习速率。G_optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate).minimize(G_loss, var_list=G_var)D_optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate).minimize(D_loss, var_list=D_var)return G_optimizer, D_optimizer

七.训练代码详解

【注8】以下两个部分均定义在一个def下!

def train(mnist):# 前期准备,该流程和上面的逻辑顺序相同# 真实图片的大小image_size = mnist.train.images[0].shape[0]# 定义接收输入的方法,占位符placeholder来获得输入的数据real_images = tf.compat.v1.placeholder(tf.float32, [None, image_size])fake_images = tf.compat.v1.placeholder(tf.float32, [None, image_size])# 生成器参数解释# 将噪声,生成器隐藏层节点数,真实图片大小传入生成器(这样搞可以生成大小一样的图片)G_logits, G_output = generatorModel(fake_images, UNITS_SIZE, image_size)# 判别器:先传入参数,给真实图片打分,再给生成图片打分。# D对真实图像的判别real_logits, real_output = discriminatorModel(real_images, UNITS_SIZE)# D对G生成图像的判别,为其打分fake_logits, fake_output = discriminatorModel(G_output, UNITS_SIZE, reuse=True)# 计算损失函数G_loss, real_loss, fake_loss, D_loss = loss_function(real_logits, fake_logits, SMOOTH)# 优化G_optimizer, D_optimizer = optimizer(G_loss, D_loss, LEARNING_RATE)# 保存生成器变量saver = tf.compat.v1.train.Saver()step = 0
   with tf.compat.v1.Session() as session:# 初始化模型的参数session.run(tf.compat.v1.global_variables_initializer())for epoch in range(EPOCH):for batch_i in range(mnist.train.num_examples // BATCH_SIZE):batch_image, _ = mnist.train.next_batch(BATCH_SIZE)# 对图像像素进行scale,tanh的输出结果为(-1,1),real和fake图片共享参数batch_image = batch_image * 2 - 1# 生成模型的输入噪声(图片)noise_image = np.random.uniform(-1, 1, size=(BATCH_SIZE, image_size))# 先训练生成器,在训练判别器session.run(G_optimizer, feed_dict={fake_images: noise_image})session.run(D_optimizer, feed_dict={real_images: batch_image, fake_images: noise_image})step = step + 1# 判别器D的损失(每一轮训练之后)loss_D = session.run(D_loss, feed_dict={real_images: batch_image, fake_images: noise_image})# D对真实图片(训练时)loss_real = session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image})# D对生成图片(训练时)loss_fake = session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image})# 生成器的损失loss_G = session.run(G_loss, feed_dict={fake_images: noise_image})print('epoch:', epoch, 'loss_D:', loss_D, ' loss_real', loss_real, ' loss_fake', loss_fake, ' loss_G',loss_G)model_path = os.getcwd() + os.sep + "mnist.model"# 存储saver.save(session, model_path, global_step=step)

八.训练模型运行结果

下面是代码成功运行的图片,300轮的化大概24分钟左右(我是垃圾显卡2g)

从这里可以看出,还是比较模糊的,在不大的改变代码的情况下,只能增加迭代次数。

九.完整代码直接运行版

import tensorflow as tftf.compat.v1.disable_eager_execution()import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot as plt# 初始化准备
BATCH_SIZE = 64  # 每一轮训练的数量
UNITS_SIZE = 128  # 生成器隐藏层的参数
LEARNING_RATE = 0.001  # 学习速率
EPOCH = 300  # 训练迭代轮数
SMOOTH = 0.1  # 标签平滑
# 读入mnist数据,理论上1.9版本之后数据集不会自动下载,但是这个代码运行的时候是会下载出数据集的。
mnist = input_data.read_data_sets('/mnist_data/', one_hot=True)# 生成器
def generatorModel(noise_img, units_size, out_size, alpha=0.01):# 参数解析# noise_img:生成器生成噪声图片# units_size: 隐藏层单元数# out_size:生成器输出图片大小# alpha:激活函数的系数with tf.compat.v1.variable_scope('generator'):# 创建一个空间generator,使得在这个空间当中,变量可以重复使用# 全连接,连接输入和隐藏层FC = tf.compat.v1.layers.dense(noise_img, units_size)# 隐藏层的激活函数,之后的dropout方法是为了防止发生过拟合的现象reLu = tf.nn.leaky_relu(FC, alpha)drop = tf.compat.v1.layers.dropout(reLu, rate=0.2)# 全连接,连接隐藏层和输出层,输出层的激活函数选择tanhlogits = tf.compat.v1.layers.dense(drop, out_size)outputs = tf.tanh(logits)return logits, outputs# 判别模型
def discriminatorModel(images, units_size, alpha=0.01, reuse=False):# 参数详解# images:真实图片# reuse:是否重复占用空间with tf.compat.v1.variable_scope('discriminator', reuse=reuse):# 全连接FC = tf.compat.v1.layers.dense(images, units_size)# 隐藏层激活函数reLu = tf.nn.leaky_relu(FC, alpha)# 全连接,这里输出层的激活函数改为sigmoidlogits = tf.compat.v1.layers.dense(reLu, 1)outputs = tf.sigmoid(logits)return logits, outputs# 损失函数
"""
判别器的目的是:
1. 对于真实图片,D要为其打上标签1
2. 对于生成图片,D要为其打上标签0
生成器的目的是:对于生成的图片,G希望D打上标签1
"""def loss_function(real_logits, fake_logits, smooth):# 生成器希望判别器判别出来的标签为1; tf.ones_like()创建一个将所有元素都设置为1的张量G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,labels=tf.ones_like(fake_logits) * (1 - smooth)))# 判别器识别生成器产出的图片,希望识别出来的标签为0fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,labels=tf.zeros_like(fake_logits)))# 判别器判别真实图片,希望判别出来的标签为1real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits,labels=tf.ones_like(real_logits) * (1 - smooth)))# 判别器总lossD_loss = tf.add(fake_loss, real_loss)return G_loss, fake_loss, real_loss, D_loss# 优化器
def optimizer(G_loss, D_loss, learning_rate):# 首先获取网络结构中的参数,也就是判别器和生成器的变量,在后面的最小化损失时修改train_var = tf.compat.v1.trainable_variables()G_var = [var for var in train_var if var.name.startswith('generator')]D_var = [var for var in train_var if var.name.startswith('discriminator')]# 因为GAN中一共训练了两个网络,所以分别对G和D进行优化# 这里使用AdamOptimizer方法来减少损失(娘希匹的2.x这玩意怎么直接用),动态调整每个参数的学习速率。G_optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate).minimize(G_loss, var_list=G_var)D_optimizer = tf.compat.v1.train.AdadeltaOptimizer(learning_rate).minimize(D_loss, var_list=D_var)return G_optimizer, D_optimizer# 训练代码
def train(mnist):# 前期准备,该流程和上面的逻辑顺序相同# 真实图片的大小image_size = mnist.train.images[0].shape[0]# 定义接收输入的方法,占位符placeholder来获得输入的数据real_images = tf.compat.v1.placeholder(tf.float32, [None, image_size])fake_images = tf.compat.v1.placeholder(tf.float32, [None, image_size])# 生成器参数解释# 将噪声,生成器隐藏层节点数,真实图片大小传入生成器(这样搞可以生成大小一样的图片)G_logits, G_output = generatorModel(fake_images, UNITS_SIZE, image_size)# 判别器:先传入参数,给真实图片打分,再给生成图片打分。# D对真实图像的判别real_logits, real_output = discriminatorModel(real_images, UNITS_SIZE)# D对G生成图像的判别,为其打分fake_logits, fake_output = discriminatorModel(G_output, UNITS_SIZE, reuse=True)# 计算损失函数G_loss, real_loss, fake_loss, D_loss = loss_function(real_logits, fake_logits, SMOOTH)# 优化G_optimizer, D_optimizer = optimizer(G_loss, D_loss, LEARNING_RATE)# 保存生成器变量saver = tf.compat.v1.train.Saver()step = 0with tf.compat.v1.Session() as session:# 初始化模型的参数session.run(tf.compat.v1.global_variables_initializer())for epoch in range(EPOCH):for batch_i in range(mnist.train.num_examples // BATCH_SIZE):batch_image, _ = mnist.train.next_batch(BATCH_SIZE)# 对图像像素进行scale,tanh的输出结果为(-1,1),real和fake图片共享参数batch_image = batch_image * 2 - 1# 生成模型的输入噪声(图片)noise_image = np.random.uniform(-1, 1, size=(BATCH_SIZE, image_size))# 先训练生成器,在训练判别器session.run(G_optimizer, feed_dict={fake_images: noise_image})session.run(D_optimizer, feed_dict={real_images: batch_image, fake_images: noise_image})step = step + 1# 判别器D的损失(每一轮训练之后)loss_D = session.run(D_loss, feed_dict={real_images: batch_image, fake_images: noise_image})# D对真实图片(训练时)loss_real = session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image})# D对生成图片(训练时)loss_fake = session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image})# 生成器的损失loss_G = session.run(G_loss, feed_dict={fake_images: noise_image})print('epoch:', epoch, 'loss_D:', loss_D, ' loss_real', loss_real, ' loss_fake', loss_fake, ' loss_G',loss_G)model_path = os.getcwd() + os.sep + "mnist.model"# 存储saver.save(session, model_path, global_step=step)def main(argv=None):train(mnist)if __name__ == '__main__':tf.compat.v1.app.run()

生成式对抗网络实战(一)——手写数字生成(CPU本地版)完整代码加详解相关推荐

  1. 生成对抗网络(GAN)——MNIST手写数字生成

    前言 正文 一.什么是GAN 二.GAN的应用 三.GAN的网络模型 对抗生成手写数字 一.引入必要的库 一.引入必要的库 二.进行准备工作 三.定义生成器和判别器模型 四.设置损失函数和优化器,以及 ...

  2. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

  3. 完整代码及解析!!手写数字识别系统(手写数字测试识别 + pytoch实现 + 完整代码及解析)

    基于深度学习的手写数字识别系统 一.实验目的 ​ 1.任选实验环境及深度学习框架,实现手写数字识别系统: ​ 2.掌握所采用的深度血迹框架构建方式. 二.实验理论基础 1.MNIST数据集 ​ MNI ...

  4. 从手写数字识别入门深度学习丨MNIST数据集详解

    就像无数人从敲下"Hello World"开始代码之旅一样,许多研究员从"MNIST数据集"开启了人工智能的探索之路. MNIST数据集(Mixed Natio ...

  5. Pytorch:GAN生成对抗网络实现MNIST手写数字的生成

    github:https://github.com/SPECTRELWF/pytorch-GAN-study 个人主页:liuweifeng.top:8090 网络结构 最近在疯狂补深度学习一些基本架 ...

  6. 使用Keras训练Lenet网络来进行手写数字识别

    使用Keras训练Lenet网络来进行手写数字识别 这篇博客将介绍如何使用Keras训练Lenet网络来进行手写数字识别. LeNet架构是深度学习中的一项开创性工作,演示了如何训练神经网络以端到端的 ...

  7. DL之NN/Average_Darkness/SVM:手写数字图片识别(本地数据集50000训练集+数据集加4倍)比较3种算法Average_Darkness、SVM、NN各自的准确率

    DL之NN/Average_Darkness/SVM:手写数字图片识别(本地数据集50000训练集+数据集加4倍)比较3种算法Average_Darkness.SVM.NN各自的准确率 目录 数据集下 ...

  8. AI常用框架和工具丨11. 基于TensorFlow(Keras)+Flask部署MNIST手写数字识别至本地web

    代码实例,基于TensorFlow+Flask部署MNIST手写数字识别至本地web,希望对您有所帮助. 文章目录 环境说明 文件结构 模型训练 本地web创建 实现效果 环境说明 操作系统:Wind ...

  9. 深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天

    文章目录 一.前期工作 1. 设置GPU 2. 定义训练参数 二.什么是生成对抗网络 1. 简单介绍 2. 应用领域 三.网络结构 四.构建生成器 五.构建鉴别器 六.训练模型 1. 保存样例图片 2 ...

最新文章

  1. 强大的JQuery(三)--操作html与遍历
  2. python管理系统-基于Python实现用户管理系统
  3. 关于 tomcat 集群中 session 共享的三种方法
  4. 计算机游戏的英语怎么写,电脑游戏英语怎么写
  5. navicat不同数据库数据传输
  6. kettle查询mysql获取uuid_java中调用kettle转换文件
  7. 设计模式(Design Patterns)
  8. Java 异常类层次结构
  9. 【开发/调试工具】【串口工具】不同串口软件如何生成带时间戳的日志
  10. 软件测试必须知道的精华总结
  11. python实例013--定义一个矩形类
  12. 面试问反射 你能跟面试官聊多少呢
  13. 宝塔同时安装苹果cms海洋cms_苹果cms和海洋cms通用的百度主动推送工具
  14. 什么东西可以改善睡眠,可以试试这些助眠好物改善睡眠
  15. 功能强大的黑科技APP,各种免费资源一应俱全!
  16. 微信小程序预览 word、excel、ppt、pdf 等文件
  17. MII、 RMII、 GMII、 RGMII 接口介绍
  18. Spring Boot (Filter)过滤器的实现以及使用场景
  19. AutoJs学习-投币小游戏
  20. 为什么溺水事故无法“清零”?

热门文章

  1. 数字图像处理:图像分割 人工智能算法在图像处理中的应用
  2. 游戏设计小议:一 游戏的娱乐性与电脑游戏的特点
  3. Kubernete(k8s)—资源清单
  4. Vue + Spring Boot 项目实战(十七):后台角色、权限与菜单分配
  5. android7.0 比较特别的功能,安卓7.0有什么新功能 Android7.0新功能全面一览
  6. 微信小程序开发纪实-菜鸟新手入门
  7. apache服务器的配置
  8. 啤酒企业应用套件为Android最流行的
  9. H - Weekend(folyd+全排列)
  10. php发帖功能源代码,discuz关于发帖数据保存功能