GAN生成手写字体识别
这篇文我主要是利用GAN生成手写字体,原理和实现方法和之前的GAN生成抛物线是一样的点击打开链接,我们直接看代码。
首先我是定义了一个可视化的函数
import matplotlib.pyplot as plt
def vis_img(batch_size,samples):fig,axes = plt.subplots(figsize=(7,7),nrows=8,ncols=8,sharey=True,sharex=True)for ax,img in zip(axes.flatten(),samples[batch_size]):ax.xaxis.set_visible(False)ax.yaxis.set_visible(False)im = ax.imshow(img.reshape((28, 28)), cmap='Greys_r')plt.show()return fig, axes
下面就是实现方法:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
#from utils import vis_imgmnist = input_data.read_data_sets('./data/mnist',one_hot=True)def generator(inputs,name,reuse=False):# 输入值# name 表示scope的name# reuse表示是否重用变量with tf.variable_scope(name,reuse=reuse) as scope:fc1 = tf.layers.dense(inputs,units=128,activation=None)#bn1 = tf.layers.batch_normalization(fc1)#ac1= tf.nn.relu(bn1)ac1 = tf.maximum(0.01*fc1,fc1)fc2 = tf.layers.dense(ac1, units=256,activation=None)#bn2 = tf.layers.batch_normalization(fc2)#ac2 = tf.nn.relu(bn2)ac2 = tf.maximum(0.01 * fc2, fc2)# 这个地方不需要激活层,fc3 = tf.layers.dense(ac2, units=784,activation=tf.nn.tanh)return fc3
def discriminator(inputs,name,alpha=0.01,reuse=False):with tf.variable_scope(name,reuse=reuse):fc1 = tf.layers.dense(inputs,256,activation=None)ac1 = tf.maximum(alpha * fc1, fc1)fc2 = tf.layers.dense(ac1, 256, activation=None)ac2 = tf.maximum(alpha * fc2, fc2)logits = tf.layers.dense(ac2, 2, activation=None)out = tf.nn.sigmoid(logits)return out,logits
epochs = 100
lr = 0.002
batch_size = 64
gen_szie = 100
with tf.name_scope('gen_inp') as scope:gen_inp = tf.placeholder(dtype=tf.float32,shape=[None,gen_szie],name='gen_inp')
with tf.name_scope('real_inp') as scope:real_inp = tf.placeholder(dtype=tf.float32,shape=[None,784],name='real_inp')gen_out = generator(gen_inp,'generator',reuse=False)real_out,real_logits = discriminator(real_inp,name='discriminator',alpha=0.01,reuse=False)
fake_out,fake_logits = discriminator(gen_out,name='discriminator',alpha=0.01,reuse=True)with tf.name_scope('metrics') as scope:loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake_logits),logits=fake_logits))loss_d_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(real_logits),logits=fake_logits))loss_d_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logits)*0.99,logits=real_logits))loss_d = loss_d_g + loss_d_realvar_list_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='generator')var_list_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')g_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_g,var_list=var_list_g)d_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_d, var_list=var_list_d)
sum_g = tf.summary.scalar('g_loss',loss_g)
sum_d = tf.summary.scalar('g_loss',loss_g)
mer_g = tf.summary.merge([sum_g])
mer_d = tf.summary.merge([sum_d])
with tf.Session() as sess:sess.run(tf.global_variables_initializer())writer = tf.summary.FileWriter('./graph/mnist',sess.graph)saver = tf.train.Saver()n_batchs = mnist.train.num_examples // batch_sizefor epoch in range(epochs):total_loss_d = 0total_loss_g = 0for ii in range(n_batchs):xs_real,ys = mnist.train.next_batch(batch_size)xs_real = xs_real*2 - 1xs_gen = np.random.uniform(-1,1,[batch_size,gen_szie])_,train_loss_d,summ_d = sess.run([d_optimizer,loss_d,mer_d],feed_dict={gen_inp:xs_gen,real_inp:xs_real})writer.add_summary(summ_d)_, train_loss_g,summ_g = sess.run([g_optimizer, loss_g,mer_g], feed_dict={gen_inp: xs_gen, real_inp: xs_real})writer.add_summary(summ_g)total_loss_d += train_loss_dtotal_loss_g += train_loss_gif epoch % 10 == 0:print('epoch {},loss_g={}'.format(epoch,total_loss_g/n_batchs))print('epoch {},loss_d={}'.format(epoch, total_loss_d/n_batchs))xs_gen = np.random.uniform(-1, 1, [batch_size, gen_szie])gen_img = sess.run(gen_out,feed_dict={gen_inp:xs_gen})vis_img(-1,[gen_img])writer.close()saver.save(sess, "./checkpoints/mnist")
然后我们看一下效果:
可以看出效果还可以。
另外,我还实验了,就是我代码注释部分,generator里面,我使用bn层和relu层,发现效果一点也不好。一直是一堆麻子。
然后我在使用bn层加Leaky ReLU,效果也很好。
最后我有把bn层去掉,感觉影响不是很大,效果还可以。
GAN生成手写字体识别相关推荐
- 人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist)
人工智能入门第一课:手写字体识别及可视化项目(手写画板)(mnist),使用技术(Django+js+tensorflow+html+bootstrap+inspinia框架) 直接上图,项目效果 1 ...
- pytorch CNN手写字体识别
## """CNN手写字体识别"""import torch import torch.nn as nn from torch.autogr ...
- 第六讲 Keras实现手写字体识别分类
一 本节课程介绍 1.1 知识点 1.图像识别分类相关介绍: 2.Mnist手写数据集介绍: 3.标准化数据预处理: 4.实验手写字体识别 二 课程内容 2.1 图像识别分类基本介绍 计算机的图像识别 ...
- Android Studio编写一个手写字体识别程序
1.activity_main.xml 的代码 <?xml version="1.0" encoding="utf-8"?> <LinearL ...
- python手写字体程序_深度学习---手写字体识别程序分析(python)
我想大部分程序员的第一个程序应该都是"hello world",在深度学习领域,这个"hello world"程序就是手写字体识别程序. 这次我们详细的分析下手 ...
- pytorch rnn 实现手写字体识别
pytorch rnn 实现手写字体识别 构建 RNN 代码 加载数据 使用RNN 训练 和测试数据 构建 RNN 代码 import torch import torch.nn as nn from ...
- 《MATLAB 神经网络43个案例分析》:第19章 基于SVM的手写字体识别
<MATLAB 神经网络43个案例分析>:第19章 基于SVM的手写字体识别 1. 前言 2. MATLAB 仿真示例 3. 小结 1. 前言 <MATLAB 神经网络43个案例分析 ...
- 手写字体识别 --MNIST数据集
Matlab 手写字体识别 忙过这段时间后,对于上次读取的Matlab内部数据实现的识别,我回味了一番,觉得那个实在太小.所以打算把数据换成[MNIST数据集][1]. 基础思想还是相同的,使用Tre ...
- 神经网络学习(二)Tensorflow-简单神经网络(全连接层神经网络)实现手写字体识别
神经网络学习(二)神经网络-手写字体识别 框架:Tensorflow 1.10.0 数据集:mnist数据集 策略:交叉熵损失 优化:梯度下降 五个模块:拿数据.搭网络.求损失.优化损失.算准确率 一 ...
- MNIST手写字体识别入门编译过程遇到的问题及解决
MNIST手写字体识别入门编译过程遇到的问题及解决 以MNIST手写字体识别作为神经网络及各种网络模型的作为练手,将遇到的问题在这里记录与交流. 激活tensorflow环境后,运行spyder或者j ...
最新文章
- Pycharm中tensorflow框架下tqdm的安装
- 了解了解一下SQLSERVER里的鬼影记录
- DokiCam 360°4K相机:为极致运动爱好者而生
- HTTP 错误 404.3 – Not Found 由于扩展配置问题而无法提供您请求的页面。如果该页面是脚本,请添加处理程序...
- 两个inline-block消除间距和对齐(vertical-align)
- SAP电商云CCV2 Restful API enablement
- 使用Def文件导出dll
- js获取屏幕宽高和下拉加载更多
- ASP.NETLinkButton的Click事件中获取CommandArgument的值
- 【EMNLP2020】超越MLM,微软打造全新预训练任务
- Day16:C++之STL应用篇(推箱子cxk限定)
- uniapp 蓝牙通讯(搜索/连接蓝牙、读、写)
- 【Houdini MAYA】从MAYA到Houdini入门学习笔记(三)
- 电脑网络通过usb分享给手机
- 仿照苏宁易购小程序页面
- python 最速曲线
- 华为机试(JAVA)真题Od【A卷+B卷】
- 翡翠手链更能够突显佩戴者的非凡气质
- c语言说明函数的作用是,C语言中rewind函数的作用是什么?
- linux双网卡双路由配置,linux配置双网卡双路由