先放结果

这是通过GAN迭代训练30W次,耗时3小时生成的手写字图片效果,大部分的还是能看出来是数字的。

实现原理

简单说下原理,生成对抗网络需要训练两个任务,一个叫生成器,一个叫判别器,如字面意思,一个负责生成图片,一个负责判别图片,生成器不断生成新的图片,然后判别器去判断哪儿哪儿不行,生成器再不断去改进,不断的像真实的图片靠近。

这就如同一个造假团伙一样,A负责生产,B负责就鉴定,刚开始的时候,两个人都是菜鸟,A随便画了一幅画拿给B看,B说你这不行,然后A再改进,当然需要改进的不止A,随着A的改进,B也得不断提升,B需要发现更细微的差异,直至他们觉得已经没什么差异了(实际肯定还存在差异),他们便决定停止"训练",开始卖吧。

实现代码

# -*- coding: utf-8 -*-

# @author: Awesome_Tang

# @date: 2019-02-22

# @version: python2.7

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

from datetime import datetime

import numpy as np

import os

import matplotlib.pyplot as plt

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

class Config:

alpha = 1e-2

drop_rate = 0.5 # 保留比例

steps = 300000 # 迭代次数

batch_size = 128 # 每批次训练样本数

epochs = 100 # 训练轮次

num_units = 128

size = 784

noise_size = 100

smooth = 0.01

learning_rate = 1e-4

print_per_step = 1000

class Gan:

def __init__(self):

print('Loading data......')

# 读取MNIST数据集

self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# 定义占位符,真实图片和生成的图片

self.real_images = tf.placeholder(tf.float32, [None, Config.size], name='real_images')

self.noise = tf.placeholder(tf.float32, [None, Config.noise_size], name='noise')

self.drop_rate = tf.placeholder('float')

self.train_step()

def generator_graph(self, noise, n_units, out_dim, alpha, reuse=False):

with tf.variable_scope('generator', reuse=reuse):

# Hidden layer

h1 = tf.layers.dense(noise, n_units, activation=None)

# Leaky ReLU

h1 = tf.maximum(alpha * h1, h1)

h1 = tf.layers.dropout(h1, rate=self.drop_rate)

# Logits and tanh output

logits = tf.layers.dense(h1, out_dim, activation=None)

out = tf.tanh(logits)

return out

@staticmethod

def discriminator_graph(image, n_units, alpha, reuse=False):

with tf.variable_scope('discriminator', reuse=reuse):

# Hidden layer

h1 = tf.layers.dense(image, n_units, activation=None)

# Leaky ReLU

h1 = tf.maximum(alpha * h1, h1)

logits = tf.layers.dense(h1, 1, activation=None)

# out = tf.sigmoid(logits)

return logits

def net(self):

# generator

fake_image = self.generator_graph(self.noise, Config.num_units, Config.size, Config.alpha)

# discriminator

real_logits = self.discriminator_graph(self.real_images, Config.num_units, Config.alpha)

fake_logits = self.discriminator_graph(fake_image, Config.num_units, Config.alpha, reuse=True)

# discriminator的loss

# 识别真实图片

d_loss_real = tf.reduce_mean(

tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits)) * (

1 - Config.smooth))

# 识别生成的图片

d_loss_fake = tf.reduce_mean(

tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))

# 总体loss

d_loss = tf.add(d_loss_real, d_loss_fake)

# generator的loss

g_loss = tf.reduce_mean(

tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_logits)) * (

1 - Config.smooth))

net_vars = tf.trainable_variables()

# generator中的tensor

g_vars = [var for var in net_vars if var.name.startswith("generator")]

# discriminator中的tensor

d_vars = [var for var in net_vars if var.name.startswith("discriminator")]

# optimizer

dis_optimizer = tf.train.AdamOptimizer(Config.learning_rate).minimize(d_loss, var_list=d_vars)

gen_optimizer = tf.train.AdamOptimizer(Config.learning_rate).minimize(g_loss, var_list=g_vars)

return dis_optimizer, gen_optimizer, d_loss, g_loss

def train_step(self):

dis_optimizer, gen_optimizer, d_loss, g_loss = self.net()

print('Training & Evaluating......')

start_time = datetime.now()

sess = tf.Session()

sess.run(tf.global_variables_initializer())

for step in range(Config.steps):

real_image, _ = self.mnist.train.next_batch(Config.batch_size)

real_image = real_image * 2 - 1

# generator的输入噪声

batch_noise = np.random.uniform(-1, 1, size=(Config.batch_size, Config.noise_size))

sess.run(gen_optimizer, feed_dict={self.noise: batch_noise, self.drop_rate: Config.drop_rate})

sess.run(dis_optimizer, feed_dict={self.noise: batch_noise, self.real_images: real_image})

if step % Config.print_per_step == 0:

dis_loss = sess.run(d_loss, feed_dict={self.noise: batch_noise, self.real_images: real_image})

gen_loss = sess.run(g_loss, feed_dict={self.noise: batch_noise, self.drop_rate: 1.})

end_time = datetime.now()

time_diff = (end_time - start_time).seconds

msg = 'Step {:3}k Dis_Loss:{:6.2f}, Gen_Loss:{:6.2f}, Time_Usage:{:6.2f} mins.'

print(msg.format(int(step / 1000), dis_loss, gen_loss, time_diff / 60.))

self.gen_image(sess)

def gen_image(self, sess):

sample_noise = np.random.uniform(-1, 1, size=(25, Config.noise_size))

samples = sess.run(

self.generator_graph(self.noise, Config.num_units, Config.size, Config.alpha, reuse=True),

feed_dict={self.noise: sample_noise})

plt.figure(figsize=(8, 8), dpi=80)

for i in range(25):

img = samples[i]

plt.subplot(5, 5, i + 1)

plt.imshow(img.reshape((28, 28)), cmap='Greys_r')

plt.axis('off')

plt.show()

if __name__ == "__main__":

Gan()

Peace~~

python生成手写文字图片_使用生成对抗网络(GAN)生成手写字相关推荐

  1. 手写文字图片识别怎么弄?这几款软件安利快收好

    手写文字一直以来都是一种独特而个人化的表达方式.然而,在数字化时代,我们常常需要将手写文字转化为可编辑的文本,以便进一步编辑.分享或存档. 那么,图片的手写文字如何识别呢?现代技术的进步使得这一任务变 ...

  2. 哪些手写文字图片识别软件好用?分享这三款好用的软件

    在大学毕业季中,我们需要完成一份重要的论文著作,常常需要查阅大量文献资料.有些历史性资料是手写的图片资料,这给查阅和文献引用造成了一定的障碍.这时候,我们可以使用软件将手写图片资料识别成电子档,那你知 ...

  3. 以下这些识别手写文字图片的软件你都知道吗

    最近我身体不舒服去看中医了,但是医生开的药方居然是手写的,我想保存下来,下次再开同样的药都没办法,因为我看不懂他写的字.要是有个软件可以识别手写文字图片就好了. 于是我在网上找了找,果然被我找到了三个 ...

  4. 2020-4-22 深度学习笔记20 - 深度生成模型 5 (有向生成网络--sigmoid信念网络/可微生成器网络/变分自编码器VAE/生产对抗网络GAN/生成矩匹配网络)

    第二十章 深度生成模型 Deep Generative Models 中文 英文 2020-4-17 深度学习笔记20 - 深度生成模型 1 (玻尔兹曼机,受限玻尔兹曼机RBM) 2020-4-18 ...

  5. (五)使用生成对抗网络 (GAN)生成新的时装设计

    目录 介绍 预测新时尚形象的力量 构建GAN 初始化GAN参数和加载数据 从头开始构建生成器 从头开始构建鉴别器 初始化GAN的损失和优化器 下一步 下载源 - 120.7 MB 介绍 DeepFas ...

  6. 生成式对抗网络GAN生成手写数字

    GAN(Generative Adversarial Networks)是较为火热的一种神经网络,具有较多的优势和特点. 一.GAN 1. 原理 源自于零和博弈(zero-sum game),包括生成 ...

  7. 生成对抗网络(GAN)——MNIST手写数字生成

    前言 正文 一.什么是GAN 二.GAN的应用 三.GAN的网络模型 对抗生成手写数字 一.引入必要的库 一.引入必要的库 二.进行准备工作 三.定义生成器和判别器模型 四.设置损失函数和优化器,以及 ...

  8. 手写文字识别java_java 手写文字图片识别提取 百度API

    package org.fh.util; import org.json.JSONObject; import java.io.BufferedReader; import java.io.Input ...

  9. 手写识别ocr java,OCR 指的是手写文字技术_学小易找答案

    [判断题]双人舞中的动作必须是两个人一模一样的.( ) [单选题]____________________ engaged in the sale of products using marketin ...

最新文章

  1. 为什么脚本执行一行就不动了_Centos7 批量创建用户账号脚本
  2. faster-rcnn处理图片格式
  3. return 返回值的问题
  4. 子空间:群论的角度解释无监督深度学习
  5. Codeforces Round #592 (Div. 2) G. Running in Pairs 构造(水)
  6. 黑盒测试 白盒测试 题 1
  7. 曼昆经济学原理(微经部分)笔记整理
  8. jQuery file upload测试
  9. CoDeSys开发经验总结
  10. 生成Bernese格式的地球自转参数文件-POLUPD
  11. java gui 文本框_【Java GUI】文本框和文本区
  12. 把大写数字转换成阿拉伯数字后排序
  13. python 计算gdp_菜鸟笔记Python3——数据可视化(三)世界GDP分析
  14. 计算机毕业设计Java影片租赁系统(系统+程序+mysql数据库+Lw文档)
  15. VGA与DVI接口以及HDMI
  16. orm查询方式与优化
  17. 想要彻底掌握placement各种技巧,这个一定可以如你所愿
  18. python用泰勒级数计算圆周率_Python中利用进度条求圆周率
  19. 蓝桥杯练习算法题(矩形切割成正方形)
  20. opencv轮廓相关函数

热门文章

  1. ULTRA社区月度更新报告#3
  2. SAP--SD2-后台基础配置笔记
  3. 天天酷跑 服务器维护中,天天酷跑登录异常怎么办?更新角色列表失败怎么办?...
  4. ArcSDE常见问题总结(三)
  5. 学校学生宿舍无线覆盖解决方案
  6. 计算机思维与逻辑思维的区别,逻辑(思维的规律和规则)_百度百科
  7. ui设计移动端字体适配_超全面的UI设计规范整理汇总(包含iPhone X适配)
  8. java power函数怎么用,Java中的Power函数
  9. 高效会议:明确会议主题,紧紧围绕不偏题
  10. 自动注册登录验证机制