Generating Handwritten Digit Images with a GAN
Keras has won me over. I used to write all my neural networks directly in TensorFlow, but since switching I have found Keras very convenient.
So far, Keras's strength seems to be quickly assembling networks from standard components: fully connected layers, convolutional layers, pooling layers, and so on. For experiments that need novel network internals, Keras can be awkward, and TensorFlow is still the better fit.
In this post I use Keras to write a simple generative adversarial network.
The network's goal is to generate handwritten digits.
First, a look at the results:
At epoch = 1000:
At epoch = 10000: the digit 1 is already starting to take shape.
At epoch = 60000: the 1 is crisp, and the other digits are getting clearer too.
At epoch = 80000: digits such as 5 and 7 appear.
As training deepens, the generated digits become increasingly realistic.
The code is open source; project address:
https://github.com/jmhIcoding/GAN_MNIST.git
Model principle
There is little theory to cover: this is the most basic GAN architecture.
The model consists of a generator and a discriminator.
The generator takes noise as input and produces a handwritten-digit image.
The discriminator judges whether an input image was synthesized by the generator.
The generator's goal is to produce images that the discriminator classifies as non-synthetic.
The discriminator's goal is to classify images as real or synthetic with the highest possible accuracy.
That is the whole principle.
The model's loss functions are built around these two objectives.
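Concretely, both objectives reduce to binary cross-entropy on the discriminator's output; only the labels differ. A minimal NumPy sketch of the idea (the probability values below are made up for illustration, not taken from the project):

```python
import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    # binary cross-entropy, the loss both players minimize
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Discriminator: push D(real) toward 1 and D(fake) toward 0
d_real = bce(np.ones(4),  np.array([0.9, 0.8, 0.95, 0.85]))   # small: D is right on real images
d_fake = bce(np.zeros(4), np.array([0.1, 0.2, 0.05, 0.15]))   # small: D is right on fakes
# Generator: feed fakes with label 1, so G is rewarded when D is fooled
g_loss = bce(np.ones(4), np.array([0.1, 0.2, 0.05, 0.15]))    # large: D is not fooled yet
```

A large generator loss here simply means the discriminator is still confidently rejecting the fakes; training drives these two losses against each other.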
Writing the model
Generator
__author__ = 'dk'
# Generator
import sys
import numpy as np
import keras
from keras import layers
from keras import models
from keras import optimizers
from keras import losses

class Generator:
    def __init__(self, height=28, width=28, channel=1, latent_space_dimension=100):
        '''
        :param height: height of the generated image, 28 for MNIST
        :param width: width of the generated image, 28 for MNIST
        :param channel: number of channels of the generated image, 1 for grayscale MNIST
        :param latent_space_dimension: dimensionality of the input noise
        :return:
        '''
        self.latent_space_dimension = latent_space_dimension
        self.height = height
        self.width = width
        self.channel = channel
        self.generator = self.build_model()
        self.generator.summary()

    def build_model(self, block_starting_size=128, num_blocks=4):
        model = models.Sequential(name='generator')
        for i in range(num_blocks):
            if i == 0:
                model.add(layers.Dense(block_starting_size, input_shape=(self.latent_space_dimension,)))
            else:
                block_size = block_starting_size * (2 ** i)
                model.add(layers.Dense(block_size))
                model.add(layers.LeakyReLU())
                model.add(layers.BatchNormalization(momentum=0.75))
        # tanh keeps pixel values in [-1, 1], matching the normalized training data
        model.add(layers.Dense(self.height * self.width * self.channel, activation='tanh'))
        model.add(layers.Reshape((self.width, self.height, self.channel)))
        return model

    def summary(self):
        self.generator.summary()

    def save_model(self):
        self.generator.save("generator.h5")
Note that the generator is trained jointly with the combined GAN model, so it does not need to compile a model of its own.
Discriminator
__author__ = 'dk'
# Discriminator
import sys
import os
import keras
from keras import layers
from keras import optimizers
from keras import models
from keras import losses
class Discriminator:
    def __init__(self, height=28, width=28, channel=1):
        '''
        :param height: height of the input image
        :param width: width of the input image
        :param channel: number of channels of the input image
        :return:
        '''
        self.height = height
        self.width = width
        self.channel = channel
        OPTIMIZER = optimizers.Adam()
        self.discriminator = self.build_model()
        self.discriminator.compile(optimizer=OPTIMIZER,
                                   loss=losses.binary_crossentropy,
                                   metrics=['accuracy'])
        self.discriminator.summary()

    def build_model(self):
        model = models.Sequential(name='discriminator')
        model.add(layers.Flatten(input_shape=(self.width, self.height, self.channel)))
        model.add(layers.Dense(self.height * self.width * self.channel))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(self.height * self.width * self.channel // 2))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(1, activation='sigmoid'))
        return model

    def summary(self):
        return self.discriminator.summary()

    def save_model(self):
        self.discriminator.save("discriminator.h5")
GAN network
This combines the generator and the discriminator:
__author__ = 'dk'
# Generative adversarial network
import keras
from keras import layers
from keras import optimizers
from keras import losses
from keras import models
import sys
import os
from Discriminator import Discriminator
from Generator import Generator

class GAN:
    def __init__(self, latent_space_dimension, height, width, channel):
        self.generator = Generator(height, width, channel, latent_space_dimension)
        self.discriminator = Discriminator(height, width, channel)
        # Inside the combined GAN only the generator is trained; the discriminator
        # is updated separately via explicit discriminator.train_on_batch calls.
        self.discriminator.discriminator.trainable = False
        self.gan = self.build_model()
        OPTIMIZER = optimizers.Adamax()
        self.gan.compile(optimizer=OPTIMIZER, loss=losses.binary_crossentropy)
        self.gan.summary()

    def build_model(self):
        model = models.Sequential(name='gan')
        model.add(self.generator.generator)
        model.add(self.discriminator.discriminator)
        return model

    def summary(self):
        self.gan.summary()

    def save_model(self):
        self.gan.save("gan.h5")
Data preparation module
__author__ = 'dk'
# Dataset loader: a thin wrapper around MNIST
from keras.datasets import mnist
import numpy as np

def sample_latent_space(instances_number, latent_space_dimension):
    return np.random.normal(0, 1, (instances_number, latent_space_dimension))

class Dator:
    def __init__(self, batch_size=None, model_type=1):
        '''
        :param batch_size:
        :param model_type: -1 selects all digits 0-9; e.g. model_type=2 selects only the digit 2
        :return:
        '''
        self.batch_size = batch_size
        self.model_type = model_type
        with np.load("mnist.npz", allow_pickle=True) as f:
            X_train, y_train = f['x_train'], f['y_train']
            #X_test, y_test = f['x_test'], f['y_test']
        if model_type != -1:
            X_train = X_train[np.where(y_train == model_type)[0]]
        if batch_size is None:
            self.batch_size = X_train.shape[0]
        else:
            self.batch_size = batch_size
        # scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
        self.X_train = (np.float32(X_train) - 128) / 128.0
        self.X_train = np.expand_dims(self.X_train, 3)
        self.watch_index = 0
        self.train_size = self.X_train.shape[0]

    def next_batch(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
        # wrap around to the start of the dataset when the window passes the end
        X = np.concatenate([self.X_train[self.watch_index:(self.watch_index + batch_size)],
                            self.X_train[:batch_size]])[:batch_size]
        self.watch_index = (self.watch_index + batch_size) % self.train_size
        return X

if __name__ == '__main__':
    print(sample_latent_space(5, 4))
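A detail worth calling out in Dator: pixels are shifted from [0, 255] to roughly [-1, 1] so that real images occupy the same numeric range as the generator's tanh output. A small sketch of that normalization and its inverse (the denormalize helper is my own addition for illustration, not part of the project code):

```python
import numpy as np

def normalize(images):
    # same transform as Dator: [0, 255] -> [-1.0, ~0.992]
    return (np.float32(images) - 128) / 128.0

def denormalize(images):
    # inverse transform, handy when saving generated samples as 8-bit images
    return np.uint8(np.clip(images * 128 + 128, 0, 255))

raw = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = normalize(raw)         # -1.0, 0.0, and 127/128
restored = denormalize(scaled)  # back to the original pixel values
```

Without this rescaling the discriminator could trivially separate real images (values up to 255) from the tanh-bounded fakes.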
Training main script: train.py
__author__ = 'dk'
# Model training code
from GAN import GAN
from data_utils import Dator,sample_latent_space
import numpy as np
from matplotlib import pyplot as plt
import time

epochs = 50000
height = 28
width = 28
channel =1
latent_space_dimension = 100
batch = 128
dator = Dator(batch_size=batch,model_type=-1)
gan = GAN(latent_space_dimension,height,width,channel)
image_index = 0
for i in range(epochs):
    real_img = dator.next_batch(batch_size=batch * 2)
    real_label = np.ones(shape=(real_img.shape[0], 1))    # real samples are labelled 1
    noise = sample_latent_space(real_img.shape[0], latent_space_dimension)
    fake_img = gan.generator.generator.predict(noise)
    fake_label = np.zeros(shape=(fake_img.shape[0], 1))   # generated images are labelled 0
    ### assemble the batch for the discriminator
    x_batch = np.concatenate([real_img, fake_img])
    y_batch = np.concatenate([real_label, fake_label])
    # one discriminator update; note the generator is untouched here
    discriminator_loss = gan.discriminator.discriminator.train_on_batch(x_batch, y_batch)[0]
    ### assemble the batch for the generator
    noise = sample_latent_space(batch * 2, latent_space_dimension)
    noise_labels = np.ones((batch * 2, 1))  # the generator's goal is to get its images labelled 1
    generator_loss = gan.gan.train_on_batch(noise, noise_labels)
    print('Epoch : {0}, [Discriminator Loss:{1} ], [Generator Loss:{2}]'.format(i, discriminator_loss, generator_loss))
    if i != 0 and (i % 50) == 0:
        print('show time')
        # every 50 steps, render 16 samples to check progress
        noise = sample_latent_space(16, latent_space_dimension)
        images = gan.generator.generator.predict(noise)
        plt.figure(figsize=(10, 10))
        plt.suptitle('epoch={0}'.format(i), fontsize=16)
        for index in range(images.shape[0]):
            plt.subplot(4, 4, index + 1)
            image = images[index, :, :, :]
            image = image.reshape(height, width)
            plt.imshow(image, cmap='gray')
        #plt.tight_layout()
        plt.savefig("./show_time/{0}.png".format(time.time()))
        image_index += 1
        plt.close()
Run the script with:
python3 train.py
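Note that train.py reads mnist.npz from the current directory and writes its preview grids into ./show_time/; as far as I can tell the script does not create that directory itself, so create it first:

```shell
# train.py saves its preview images here and assumes the directory exists
mkdir -p show_time
```

After that, run python3 train.py from the repository root, next to mnist.npz.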
Output:
Model: "generator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 128) 12928
_________________________________________________________________
dense_2 (Dense) (None, 256) 33024
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 256) 0
_________________________________________________________________
batch_normalization_1 (Batch (None, 256) 1024
_________________________________________________________________
dense_3 (Dense) (None, 512) 131584
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 512) 0
_________________________________________________________________
batch_normalization_2 (Batch (None, 512) 2048
_________________________________________________________________
dense_4 (Dense) (None, 1024) 525312
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 1024) 0
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024) 4096
_________________________________________________________________
dense_5 (Dense) (None, 784) 803600
_________________________________________________________________
reshape_1 (Reshape) (None, 28, 28, 1) 0
=================================================================
Total params: 1,513,616
Trainable params: 1,510,032
Non-trainable params: 3,584
_________________________________________________________________
Model: "discriminator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten_2 (Flatten) (None, 784) 0
_________________________________________________________________
dense_9 (Dense) (None, 784) 615440
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 784) 0
_________________________________________________________________
dense_10 (Dense) (None, 392) 307720
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 392) 0
_________________________________________________________________
dense_11 (Dense) (None, 1) 393
=================================================================
Total params: 923,553
Trainable params: 923,553
Non-trainable params: 0
_________________________________________________________________
Model: "gan"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
generator (Sequential) (None, 28, 28, 1) 1513616
_________________________________________________________________
discriminator (Sequential) (None, 1) 923553
=================================================================
Total params: 2,437,169
Trainable params: 1,510,032
Non-trainable params: 927,137
_________________________________________________________________
...
Epoch : 117754, [Discriminator Loss:0.22975191473960876 ], [Generator Loss:2.57688570022583]
Epoch : 117755, [Discriminator Loss:0.26782122254371643 ], [Generator Loss:3.1791584491729736]
Epoch : 117756, [Discriminator Loss:0.2609345614910126 ], [Generator Loss:2.960988998413086]
Epoch : 117757, [Discriminator Loss:0.2673880159854889 ], [Generator Loss:2.317220687866211]
Epoch : 117758, [Discriminator Loss:0.24904575943946838 ], [Generator Loss:1.929720401763916]
Epoch : 117759, [Discriminator Loss:0.25158950686454773 ], [Generator Loss:2.954155683517456]
Epoch : 117760, [Discriminator Loss:0.20324105024337769 ], [Generator Loss:3.5244760513305664]
Epoch : 117761, [Discriminator Loss:0.2849388122558594 ], [Generator Loss:3.195873498916626]
Epoch : 117762, [Discriminator Loss:0.19631560146808624 ], [Generator Loss:2.328411340713501]
Epoch : 117763, [Discriminator Loss:0.20523831248283386 ], [Generator Loss:2.402683973312378]
Epoch : 117764, [Discriminator Loss:0.2625979781150818 ], [Generator Loss:3.2176101207733154]
Epoch : 117765, [Discriminator Loss:0.29969191551208496 ], [Generator Loss:2.9656052589416504]
Epoch : 117766, [Discriminator Loss:0.270328551530838 ], [Generator Loss:2.3880398273468018]
Epoch : 117767, [Discriminator Loss:0.26741161942481995 ], [Generator Loss:2.7729406356811523]
Epoch : 117768, [Discriminator Loss:0.28797847032546997 ], [Generator Loss:2.8959264755249023]
Epoch : 117769, [Discriminator Loss:0.30181047320365906 ], [Generator Loss:2.791097402572632]
Epoch : 117770, [Discriminator Loss:0.26939862966537476 ], [Generator Loss:2.3666043281555176]
Epoch : 117771, [Discriminator Loss:0.26297527551651 ], [Generator Loss:2.895970582962036]
Epoch : 117772, [Discriminator Loss:0.21928083896636963 ], [Generator Loss:3.4627976417541504]
Epoch : 117773, [Discriminator Loss:0.3553962707519531 ], [Generator Loss:3.2194197177886963]
Epoch : 117774, [Discriminator Loss:0.32673510909080505 ], [Generator Loss:2.473867893218994]
Epoch : 117775, [Discriminator Loss:0.31245478987693787 ], [Generator Loss:2.999265193939209]
Epoch : 117776, [Discriminator Loss:0.29536381363868713 ], [Generator Loss:3.733344554901123]
Epoch : 117777, [Discriminator Loss:0.2955515682697296 ], [Generator Loss:3.2467658519744873]
Epoch : 117778, [Discriminator Loss:0.3677394986152649 ], [Generator Loss:1.8517814874649048]
Epoch : 117779, [Discriminator Loss:0.31648850440979004 ], [Generator Loss:2.6385254859924316]
Epoch : 117780, [Discriminator Loss:0.31941041350364685 ], [Generator Loss:3.350475311279297]
Epoch : 117781, [Discriminator Loss:0.47521263360977173 ], [Generator Loss:1.9556307792663574]
Epoch : 117782, [Discriminator Loss:0.44070643186569214 ], [Generator Loss:1.9684114456176758]