Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的。
目前感觉keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接层,卷积层,池化层等等。对于需要对网络本身做创新的实验,keas可能不是很方便,还是得用tensorflow来搭建。

这篇博客,我想用Keras写一个简单的生成对抗网络。
生成对抗网络的目标是生成手写体数字。

先看看实验的效果:
epoch=1000的时候:

epoch=10000的时候:数字1已经有点像了

epoch=60000,数字1就很清晰了 ,而且其他数字也越来越清晰了

epoch=80000: 生成了5,7 啥的了。

随着训练的加深,生成的数字会越来越真实了。
代码已经开源,项目地址:

https://github.com/jmhIcoding/GAN_MNIST.git

模型原理

模型原理就不说了,就是使用最基础GAN结构。
模型由一个生成器和一个鉴别器组成。
生成器用于输入噪声,然后生成一个手写体数字图片。
鉴别器用于判断某个输入给它的图片是不是生成器合成的。

生成器的目标是生成让鉴别器判断为非合成的图片。
鉴别器的目标则是以尽量高的正确率分类某种图片是否为合成的。

总的原理就是这些了。
模型的损失函数就是围绕着这两个目标来展开的。

模型编写

生成器

__author__ = 'dk'
#生成器import sys
import numpy as npimport  keras
from  keras import layers
from keras import models
from  keras import optimizers
from keras import lossesclass Generator:def __init__(self,height=28,width=28,channel=1,latent_space_dimension=100):''':param height:    生成图片的高,minist为28:param width:     生成图片的宽,minist为28:param channel:   生成器所生成的图片的通道数目,对于mnist灰度图来说,channel为1:param latent_space_dimension:  噪声的维度:return:'''self.latent_space_dimension = latent_space_dimensionself.height = heightself.width = widthself.channel = channelself.generator = self.build_model()self.generator.summary()def build_model(self,block_starting_size=128,num_blocks=4):model = models.Sequential(name='generator')for i in range(num_blocks):if i ==0 :model.add(layers.Dense(block_starting_size,input_shape=(self.latent_space_dimension,)))else:block_size = block_starting_size * (2**i)model.add(layers.Dense(block_size))model.add(layers.LeakyReLU())model.add(layers.BatchNormalization(momentum=0.75))model.add(layers.Dense(self.height*self.channel*self.width,activation='tanh'))model.add(layers.Reshape((self.width,self.height,self.channel)))return  modeldef summary(self):self.model.summary()def save_model(self):self.generator.save("generator.h5")

注意,generator是和整个模型一起训练的,它可以不需要compile模型。

鉴别器

__author__ = 'dk'
#判别器
import sys
import os
import keras
from  keras import layers
from keras import optimizers
from keras import models
from keras import losses
class Discriminator:def __init__(self,height=28,width=28,channel=1):''':param height:  输入图片的高:param width:   输入图片的宽:param channel: 输入图片的通道数:return:'''self.height = heightself.width = widthself.channel = channelself.discriminator = self.build_model()OPTIMIZER = optimizers.Adam()self.discriminator = self.build_model()self.discriminator.compile(optimizer=OPTIMIZER,loss=losses.binary_crossentropy,metrics =['accuracy'])self.discriminator.summary()def build_model(self):model = models.Sequential(name='discriminator')model.add(layers.Flatten(input_shape=(self.width,self.height,self.channel)))model.add(layers.Dense(self.height*self.width*self.channel,input_shape=(self.width,self.height,self.channel)))model.add(layers.LeakyReLU(0.2))model.add(layers.Dense(self.height*self.width*self.channel//2))model.add(layers.LeakyReLU(0.2))model.add(layers.Dense(1,activation='sigmoid'))return modeldef summary(self):return self.discriminator.summary()def save_model(self):self.discriminator.save("discriminator.h5")

gan网络

把生成器和鉴别器合并起来

__author__ = 'dk'
#生成对抗网络import keras
from keras import layers
from  keras import optimizers
from  keras import  losses
from  keras import modelsimport  sys
import osfrom Discriminator import Discriminator
from Generator import Generator
class GAN:def __init__(self,latent_space_dimension,height,width,channel):self.generator  = Generator(height,width,channel,latent_space_dimension)self.discriminator = Discriminator(height,width,channel)self.discriminator.discriminator.trainable = False #gan部分,只训练生成器,鉴别器通过显式discriminator.train_on_batch调用来训练self.gan =  self.build_model()OPTIMIZER = optimizers.Adamax()self.gan.compile(optimizer = OPTIMIZER,loss = losses.binary_crossentropy)self.gan.summary()def build_model(self):model  = models.Sequential(name='gan')model.add(self.generator.generator)model.add(self.discriminator.discriminator)return  modeldef summary(self):self.gan.summary()def save_model(self):self.gan.save("gan.h5")

数据准备模块

__author__ = 'dk'
#数据集采集器,主要是对mnist进行简单的封装
from keras.datasets import mnist
import numpy as np
def sample_latent_space(instances_number,latent_space_dimension):return  np.random.normal(0,1,(instances_number,latent_space_dimension))class Dator:def __init__(self,batch_size=None,model_type=1):''':param batch_size::param model_type:  当model_type为-1的时候,表示0-9个数字都选;当model_type=2,说明只选择数字2:return:'''self.batch_size = batch_sizeself.model_type = model_typewith np.load("mnist.npz", allow_pickle=True) as f:X_train, y_train = f['x_train'], f['y_train']#X_test, y_test = f['x_test'], f['y_test']if model_type != -1:X_train = X_train[np.where(y_train==model_type)[0]]if batch_size == None:self.batch_size = X_train.shape[0]else:self.batch_size = batch_sizeself.X_train = (np.float32(X_train)-128)/128.0self.X_train = np.expand_dims(self.X_train,3)self.watch_index = 0self.train_size = self.X_train.shape[0]def next_batch(self,batch_size = None):if batch_size == None:batch_size  =self.batch_sizeX=np.concatenate([self.X_train[self.watch_index:(self.watch_index+batch_size)], self.X_train[:batch_size]])[:batch_size]self.watch_index  = (self.watch_index + batch_size) % self.train_sizereturn  Xif __name__ == '__main__':print(sample_latent_space(5,4))

训练main脚本:train.py

__author__ = 'dk'
#模型训练代码
from  GAN import GAN
from data_utils import Dator,sample_latent_space
import  numpy as np
from matplotlib import pyplot as plt
import timeepochs = 50000
height = 28
width = 28
channel =1
latent_space_dimension = 100
batch = 128
dator = Dator(batch_size=batch,model_type=-1)
gan = GAN(latent_space_dimension,height,width,channel)
image_index = 0
for i in range(epochs):real_img = dator.next_batch(batch_size=batch*2)real_label = np.ones(shape=(real_img.shape[0],1))       #真实的样本设置为1的标签noise = sample_latent_space(real_img.shape[0],latent_space_dimension)fake_img = gan.generator.generator.predict(noise)fake_label = np.zeros(shape=(fake_img.shape[0],1))      #生成器生成的假图片标注为0###合成给gan的鉴别器的数据x_batch = np.concatenate([real_img,fake_img])y_batch = np.concatenate([real_label,fake_label])#训练一次discriminator_loss = gan.discriminator.discriminator.train_on_batch(x_batch,y_batch)[0]###注意,此时训练的是鉴别器,生成器部分不动。###合成训练生成器的数据noise = sample_latent_space(batch*2,latent_space_dimension)noise_labels = np.ones((batch*2,1))           #生成器的目标是把图片的label越来越像1generator_loss = gan.gan.train_on_batch(noise,noise_labels)print('Epoch : {0}, [Discriminator Loss:{1} ], [Generator Loss:{2}]'.format(i,discriminator_loss,generator_loss))if i!=0 and (i%50)==0:print('show time')#每50次输入16张图片看看效果noise = sample_latent_space(16,latent_space_dimension)images = gan.generator.generator.predict(noise)plt.figure(figsize=(10,10))plt.suptitle('epoch={0}'.format(i),fontsize=16)for index in range(images.shape[0]):plt.subplot(4,4,index+1)image  =images[index,:,:,:]image = image.reshape(height,width)plt.imshow(image,cmap='gray')#plt.tight_layout()plt.savefig("./show_time/{0}.png".format(time.time()))image_index += 1plt.close()

运行脚本

python3 train.py

即可。
输出:

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_1 (Dense)              (None, 128)               12928
_________________________________________________________________
dense_2 (Dense)              (None, 256)               33024
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 256)               0
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024
_________________________________________________________________
dense_3 (Dense)              (None, 512)               131584
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 512)               0
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048
_________________________________________________________________
dense_4 (Dense)              (None, 1024)              525312
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 1024)              0
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096
_________________________________________________________________
dense_5 (Dense)              (None, 784)               803600
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0
=================================================================
Total params: 1,513,616
Trainable params: 1,510,032
Non-trainable params: 3,584
_________________________________________________________________
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_2 (Flatten)          (None, 784)               0
_________________________________________________________________
dense_9 (Dense)              (None, 784)               615440
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 784)               0
_________________________________________________________________
dense_10 (Dense)             (None, 392)               307720
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 392)               0
_________________________________________________________________
dense_11 (Dense)             (None, 1)                 393
=================================================================
Total params: 923,553
Trainable params: 923,553
Non-trainable params: 0
_________________________________________________________________
Model: "gan"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
generator (Sequential)       (None, 28, 28, 1)         1513616
_________________________________________________________________
discriminator (Sequential)   (None, 1)                 923553
=================================================================
Total params: 2,437,169
Trainable params: 1,510,032
Non-trainable params: 927,137
_________________________________________________________________
····
···
··Epoch : 117754, [Discriminator Loss:0.22975191473960876 ], [Generator Loss:2.57688570022583]
Epoch : 117755, [Discriminator Loss:0.26782122254371643 ], [Generator Loss:3.1791584491729736]
Epoch : 117756, [Discriminator Loss:0.2609345614910126 ], [Generator Loss:2.960988998413086]
Epoch : 117757, [Discriminator Loss:0.2673880159854889 ], [Generator Loss:2.317220687866211]
Epoch : 117758, [Discriminator Loss:0.24904575943946838 ], [Generator Loss:1.929720401763916]
Epoch : 117759, [Discriminator Loss:0.25158950686454773 ], [Generator Loss:2.954155683517456]
Epoch : 117760, [Discriminator Loss:0.20324105024337769 ], [Generator Loss:3.5244760513305664]
Epoch : 117761, [Discriminator Loss:0.2849388122558594 ], [Generator Loss:3.195873498916626]
Epoch : 117762, [Discriminator Loss:0.19631560146808624 ], [Generator Loss:2.328411340713501]
Epoch : 117763, [Discriminator Loss:0.20523831248283386 ], [Generator Loss:2.402683973312378]
Epoch : 117764, [Discriminator Loss:0.2625979781150818 ], [Generator Loss:3.2176101207733154]
Epoch : 117765, [Discriminator Loss:0.29969191551208496 ], [Generator Loss:2.9656052589416504]
Epoch : 117766, [Discriminator Loss:0.270328551530838 ], [Generator Loss:2.3880398273468018]
Epoch : 117767, [Discriminator Loss:0.26741161942481995 ], [Generator Loss:2.7729406356811523]
Epoch : 117768, [Discriminator Loss:0.28797847032546997 ], [Generator Loss:2.8959264755249023]
Epoch : 117769, [Discriminator Loss:0.30181047320365906 ], [Generator Loss:2.791097402572632]
Epoch : 117770, [Discriminator Loss:0.26939862966537476 ], [Generator Loss:2.3666043281555176]
Epoch : 117771, [Discriminator Loss:0.26297527551651 ], [Generator Loss:2.895970582962036]
Epoch : 117772, [Discriminator Loss:0.21928083896636963 ], [Generator Loss:3.4627976417541504]
Epoch : 117773, [Discriminator Loss:0.3553962707519531 ], [Generator Loss:3.2194197177886963]
Epoch : 117774, [Discriminator Loss:0.32673510909080505 ], [Generator Loss:2.473867893218994]
Epoch : 117775, [Discriminator Loss:0.31245478987693787 ], [Generator Loss:2.999265193939209]
Epoch : 117776, [Discriminator Loss:0.29536381363868713 ], [Generator Loss:3.733344554901123]
Epoch : 117777, [Discriminator Loss:0.2955515682697296 ], [Generator Loss:3.2467658519744873]
Epoch : 117778, [Discriminator Loss:0.3677394986152649 ], [Generator Loss:1.8517814874649048]
Epoch : 117779, [Discriminator Loss:0.31648850440979004 ], [Generator Loss:2.6385254859924316]
Epoch : 117780, [Discriminator Loss:0.31941041350364685 ], [Generator Loss:3.350475311279297]
Epoch : 117781, [Discriminator Loss:0.47521263360977173 ], [Generator Loss:1.9556307792663574]
Epoch : 117782, [Discriminator Loss:0.44070643186569214 ], [Generator Loss:1.9684114456176758]

GAN网络生成手写体数字图片相关推荐

  1. 不服就GAN:GAN网络生成 cifar10 的图片实例(keras 详细实现步骤),GAN 的训练的各种技巧总结,GAN的注意事项和大坑汇总

    GAN 的调参技巧总结 生成器的最后一层不使用 sigmoid,使用 tanh 代替 使用噪声作为生成器的输入时,生成噪声的步骤使用 正态分布 的采样来产生,而不使用均匀分布 训练 discrimin ...

  2. 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)

    图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow) 文章目录 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网 ...

  3. GAN (生成对抗网络) 手写数字图片生成

    GAN (生成对抗网络) 手写数字图片生成 文章目录 GAN (生成对抗网络) 手写数字图片生成 Discriminator Network Generator Network 简单版本的生成对抗网络 ...

  4. tensorflow学习笔记(十):GAN生成手写体数字(MNIST)

    文章目录 一.GAN原理 二.项目实战 2.1 项目背景 2.2 网络描述 2.3 项目实战 一.GAN原理 生成对抗网络简称GAN,是由两个网络组成的,一个生成器网络和一个判别器网络.这两个网络可以 ...

  5. python制作图片数据集,Python 3 生成手写体数字数据集

    0.引言 平时上网干啥的基本上都会接触验证码,或者在机器学习学习过程中,大家或许会接触过手写体识别/验证码识别之类问题,会用到手写体的数据集: 自己尝试写了一个生成手写体图片的python程序,在此分 ...

  6. python数字1 3怎么表示_Python3生成手写体数字方法

    0.引言 平时上网干啥的基本上都会接触验证码,或者在机器学习学习过程中,大家或许会接触过手写体识别/验证码识别之类问题,会用到手写体的数据集: 自己尝试写了一个生成手写体图片的python程序,在此分 ...

  7. python写数字,Python3生成手写体数字方法

    0.引言 平时上网干啥的基本上都会接触验证码,或者在机器学习学习过程中,大家或许会接触过手写体识别/验证码识别之类问题,会用到手写体的数据集: 自己尝试写了一个生成手写体图片的python程序,在此分 ...

  8. 如何从TensorFlow的mnist数据集导出手写体数字图片

    在TensorFlow的官方入门课程中,多次用到mnist数据集. mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx3-ubyte的二 ...

  9. 基于Keras的生成对抗网络(1)——利用Keras搭建简单GAN生成手写体数字

    目录 0.前言 一.GAN结构 二.函数代码 2.1 生成器Generator 2.2 判别器Discriminator 2.3 train函数 三.结果演示 四.完整代码 五.常见问题汇总 0.前言 ...

最新文章

  1. openresty编译添加stream-lua-nginx-module模块
  2. python input与返回值-Python 详解基本语法_函数_返回值
  3. Orecle基本概述(2)
  4. Windows 32位程序在64位操作系统下运行
  5. Solr配置IK分词器
  6. 使用Memcached提高.NET应用程序的性能
  7. SecureCRT配置proxy连接云主机
  8. SQL server插入数据后,获取自增长字段的值
  9. vb.net 数据集设计器 新增列_SQLPro for MSSQL for Mac(数据库客户端)
  10. jQueryUI modal dialog does not show close button (x) JQueryUI和BootStrap混用时候,右上角关闭按钮显示不出图标的解决办法...
  11. 苹果全面开放漏洞奖励计划:最高100万美元等你拿
  12. 用隐喻来更充分地理解软件开发
  13. linux内核之数据机构
  14. 2016-408-计组-有如下c语言程序段
  15. 【洛谷试炼场】普及练习场——贪心
  16. iOS 获取APP名称 版本等
  17. 【数据压缩】使用Audacity软件分析浊音、清音爆破音的时域及频域特性。
  18. 《人生的智慧》-叔本华著[韦启昌-(译)]
  19. 【电源专题】开关电源的控制器和稳压器的区别
  20. newman的基本使用

热门文章

  1. 品牌连锁企业如何突破技术壁垒对接分账系统?
  2. 进程同步C语言p v实验报告,操作系统实验报告模板
  3. Python函数和模块总结
  4. 芝法酱躺平攻略(4)—— powerdesigner与mybatis-plus生成代码
  5. 关于Oracle的参数是游标,如何处理(mirth)
  6. 碰到国内外虚拟机无法识别usb加密狗或者银行U盾问题,大家请进入!!!
  7. jsp和Java后台数据如何交互
  8. Ps学习(多边形套索工具使用)
  9. Unity客户端框架收集
  10. Android集成7z极限压缩