系列文章目录

深度学习GAN(一)之简单介绍
深度学习GAN(二)之基于CIFAR10数据集的例子
深度学习GAN(三)之基于手写体Mnist数据集的例子
深度学习GAN(四)之PIX2PIX GAN的例子


GAN基于手写体Mnist数据集生成新图片

  • 1. 代码运行结果
  • 2. GAN基于mnist数据集的完整代码

1. 代码运行结果

下图是GAN生成的手写体数字,用了10个epoch

2. GAN基于mnist数据集的完整代码

代码结构很像我的第二篇博客,如果你没看过,请先看那篇博客。里面有详细的代码讲解。

import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt# define the standalone discriminator model
def define_discriminator(in_shape=(28,28,1)):model = keras.models.Sequential()# normalmodel.add(keras.layers.Conv2D(64, (3,3), padding='same', input_shape=in_shape))model.add(keras.layers.LeakyReLU(alpha=0.2))# downsamplemodel.add(keras.layers.Conv2D(128, (3,3), strides=(2,2), padding='same'))model.add(keras.layers.LeakyReLU(alpha=0.2))# downsamplemodel.add(keras.layers.Conv2D(128, (3,3), strides=(2,2), padding='same'))model.add(keras.layers.LeakyReLU(alpha=0.2))# downsamplemodel.add(keras.layers.Conv2D(256, (3,3), strides=(2,2), padding='valid'))model.add(keras.layers.LeakyReLU(alpha=0.2))# classifiermodel.add(keras.layers.Flatten())model.add(keras.layers.Dropout(0.4))model.add(keras.layers.Dense(1, activation='sigmoid'))# compile modelopt = keras.optimizers.Adam(lr=0.0002, beta_1=0.5)model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])model.summary()return model# load and prepare cifar10 training images
def load_real_samples():# load cifar10 dataset(trainX, _), (_, _) = tf.keras.datasets.mnist.load_data()# convert from unsigned ints to floats#X = trainX.astype('float32')X = trainX.reshape(trainX.shape[0], 28, 28, 1).astype('float32')# scale from [0,255] to [-1,1]X = (X - 127.5) / 127.5return X# select real samples
def generate_real_samples(dataset, n_samples):# choose random instancesix = np.random.randint(0, dataset.shape[0], n_samples)# retrieve selected imagesX = dataset[ix]# generate 'real' class labels (1)y = np.ones((n_samples, 1))return X, ydef generate_fake_samples1(n_samples):# generate uniform random numbers in [0,1]X = np.random.rand(28 * 28 * 1 * n_samples)# update to have the range [-1, 1]X = -1 + X * 2# reshape into a batch of color imagesX = X.reshape((n_samples, 28, 28, 1))# generate 'fake' class labels (0)y = np.zeros((n_samples, 1))return X, y# train the discriminator model
def train_discriminator(model, dataset, n_iter=20, n_batch=128):half_batch = int(n_batch / 2)# manually enumerate epochsfor i in range(n_iter):# get randomly selected 'real' samplesX_real, y_real = generate_real_samples(dataset, half_batch)# update discriminator on real samples_, real_acc = model.train_on_batch(X_real, y_real)# generate 'fake' examplesX_fake, y_fake = generate_fake_samples1(half_batch)# update discriminator on fake samples_, fake_acc = model.train_on_batch(X_fake, y_fake)# summarize performanceprint('>%d real=%.0f%% fake=%.0f%%' % (i+1, real_acc*100, fake_acc*100))def test_train_discriminator():# define the discriminator modelmodel = define_discriminator()# load image datadataset = load_real_samples()# fit the modeltrain_discriminator(model, dataset)# define the standalone generator model
def define_generator(latent_dim):model = keras.models.Sequential()# foundation for 4x4 imagen_nodes = 256 * 3 * 3model.add(keras.layers.Dense(n_nodes, input_dim=latent_dim))model.add(keras.layers.LeakyReLU(alpha=0.2))model.add(keras.layers.Reshape((3, 3, 256)))# upsample to 8x8model.add(keras.layers.Conv2DTranspose(128, (3,3), strides=(2,2), padding='valid'))model.add(keras.layers.LeakyReLU(alpha=0.2))# upsample to 16x16model.add(keras.layers.Conv2DTranspose(128, (3,3), strides=(2,2), padding='same'))model.add(keras.layers.LeakyReLU(alpha=0.2))# upsample to 32x32model.add(keras.layers.Conv2DTranspose(64, (3,3), strides=(2,2), padding='same'))model.add(keras.layers.LeakyReLU(alpha=0.2))# output layermodel.add(keras.layers.Conv2D(1, (3,3), activation='tanh', padding='same'))return model# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):# generate points in the latent spacex_input = np.random.randn(latent_dim * n_samples)# reshape into a batch of inputs for the networkx_input = x_input.reshape(n_samples, latent_dim)return x_input# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples):# generate points in latent spacex_input = generate_latent_points(latent_dim, n_samples)# predict outputsX = g_model.predict(x_input)# create 'fake' class labels (0)y = np.zeros((n_samples, 1))return X, ydef show_fake_sample():# size of the latent spacelatent_dim = 100# define the discriminator modelmodel = define_generator(latent_dim)# generate samplesn_samples = 49X, _ = generate_fake_samples(model, latent_dim, n_samples)# scale pixel values from [-1,1] to [0,1]X = (X + 1) / 2.0# plot the generated samplesfor i in range(n_samples):# define subplotplt.subplot(7, 7, 1 + i)# turn off axis labelsplt.axis('off')# plot single imageplt.imshow(X[i])# show the figureplt.show()# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):# make weights in the discriminator not trainabled_model.trainable = False# connect themmodel = tf.keras.models.Sequential()# add generatormodel.add(g_model)# add the discriminatormodel.add(d_model)# compile modelopt = tf.keras.optimizers.Adam(lr=0.0002, beta_1=0.5)model.compile(loss='binary_crossentropy', optimizer=opt)return modeldef show_gan_module():# size of the latent spacelatent_dim = 100# create the discriminatord_model = define_discriminator()# create the generatorg_model = define_generator(latent_dim)# create the gangan_model = define_gan(g_model, d_model)# summarize gan modelgan_model.summary()# train the composite model
def train_gan(gan_model, latent_dim, n_epochs=200, n_batch=128):# manually enumerate epochsfor i in range(n_epochs):# prepare points in latent space as input for the generatorx_gan = generate_latent_points(latent_dim, n_batch)# create inverted labels for the fake samplesy_gan = np.ones((n_batch, 1))# update the generator via the discriminator's errorgan_model.train_on_batch(x_gan, y_gan)# evaluate the discriminator, plot generated images, save generator model
def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=150):# prepare real samplesX_real, y_real = generate_real_samples(dataset, n_samples)# evaluate discriminator on real examples_, acc_real = d_model.evaluate(X_real, y_real, verbose=0)# prepare fake examplesx_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_samples)# evaluate discriminator on fake examples_, acc_fake = d_model.evaluate(x_fake, y_fake, verbose=0)# summarize discriminator performanceprint('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real * 100, acc_fake * 100))# save plot#save_plot(x_fake, epoch)# save the generator model tile filefilename = 'minst_generator_model_%03d.h5' % (epoch + 1)g_model.save(filename)# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=200, n_batch=128):bat_per_epo = int(dataset.shape[0] / n_batch)half_batch = int(n_batch / 2)# manually enumerate epochsfor i in range(n_epochs):# enumerate batches over the training setfor j in range(bat_per_epo):# get randomly selected 'real' samplesX_real, y_real = generate_real_samples(dataset, half_batch)# update discriminator model weightsd_loss1, _ = d_model.train_on_batch(X_real, y_real)# generate 'fake' examplesX_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)# update discriminator model weightsd_loss2, _ = d_model.train_on_batch(X_fake, y_fake)# prepare points in latent space as input for the generatorX_gan = generate_latent_points(latent_dim, n_batch)# create inverted labels for the fake samplesy_gan = np.ones((n_batch, 1))# update the generator via the discriminator's errorg_loss = gan_model.train_on_batch(X_gan, y_gan)# summarize loss on this batchprint('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' %(i + 1, j + 1, bat_per_epo, d_loss1, d_loss2, g_loss))# evaluate the model performance, sometimesif (i + 1) % 10 == 0:summarize_performance(i, g_model, d_model, dataset, latent_dim)def test_train_gan():# size of the latent spacelatent_dim = 100# create the discriminatord_model = define_discriminator()# create the generatorg_model = define_generator(latent_dim)# create the gangan_model = define_gan(g_model, d_model)# load image datadataset = load_real_samples()# train modeltrain(g_model, d_model, gan_model, dataset, latent_dim)# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):# generate points in the latent spacex_input = np.random.randn(latent_dim * n_samples)# reshape into a batch of inputs for the networkx_input = x_input.reshape(n_samples, latent_dim)return x_input# plot the generated images
def create_plot(examples, n):# plot imagesfor i in range(n * n):# define subplotplt.subplot(n, n, 1 + i)# turn off axisplt.axis('off')# plot raw pixel dataplt.imshow(examples[i, :, :], cmap='gray')plt.show()def show_imgs_for_final_generator_model():# load modelmodel = tf.keras.models.load_model('minst_generator_model_010.h5')# generate imageslatent_points = generate_latent_points(100, 100)# generate imagesX = model.predict(latent_points)# scale from [-1,1] to [0,1]X = (X + 1) / 2.0# plot the resultX = X.reshape(X.shape[0], 28,28)create_plot(X, 10)def show_single_imgs():model = tf.keras.models.load_model('minst_generator_model_010.h5')# all 0svector = np.asarray([[0.75 for _ in range(100)]])# generate imageX = model.predict(vector)# scale from [-1,1] to [0,1]X = (X + 1) / 2.0# plot the resultplt.imshow(X[0, :, :])plt.show()if __name__ == '__main__':#define_discriminator()#test_train_discriminator()# show_fake_sample()#show_gan_module()test_train_gan()#g_module = define_generator(100)#print(g_module.summary())show_imgs_for_final_generator_model()# define the size of the latent space

[深度学习-实践]GAN基于手写体Mnist数据集生成新图片相关推荐

  1. [深度学习-实践]GAN入门例子-利用Tensorflow Keras与数据集CIFAR10生成新图片

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之基于CIFAR10数据集的例子; 深度学习GAN(三)之基于手写体Mnist数据集的例子; 深度学习GAN(四)之PIX2PIX G ...

  2. 深度学习入门源代码下载使用mnist数据集出现错误EOFError Compressed file ended before the end-of-stream marker was reached

    深度学习入门:基于Python的理论与实现源代码下载使用mnist数据集出现错误[EOFError: Compressed file ended before the end-of-stream ma ...

  3. TensorFlow:实战Google深度学习框架(四)MNIST数据集识别问题

    第5章 MNIST数字识别问题 5.1 MNIST数据处理 5.2 神经网络的训练以及不同模型结果的对比 5.2.1 TensorFlow训练神经网络 5.2.2 使用验证数据集判断模型的效果 5.2 ...

  4. 深度学习基础: BP神经网络训练MNIST数据集

    BP 神经网络训练MNIST数据集 不用任何深度学习框架,一起写一个神经网络训练MNIST数据集 本文试图让您通过手写一个简单的demo来讨论 1. 导包 import numpy as np imp ...

  5. 【深度学习实践】基于深度学习的车牌识别(python,车牌检测+车牌识别)

    车牌识别具有广泛的应用前景,基于传统方法的车牌识别效果一般比较差,随着计算机视觉技术的快速发展,深度学习的方法能够更好的完成车牌识别任务. 本文提供了车牌识别方案的部署链接,您可以在网页上体验该模型的 ...

  6. [深度学习-实践]tensorflow_hub简单理解模型的生成与加载

    0. 前言 Tensorflow于1.7之后推出了tensorflow hub,其是一个适合于迁移学习的部分,主要通过将tensorflow的训练好的模型进行模块划分,并可以再次加以利用.不过介于推出 ...

  7. [深度学习-实践]条件生成对抗网络cGAN的例子-Tensorflow2.x Keras

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之DCGAN基于CIFAR10数据集的例子 深度学习GAN(三)之DCGAN基于手写体Mnist数据集的例子 深度学习GAN(四)之c ...

  8. [深度学习-实践]CycleGAN的入门例子-Tensorflow2.1-keras

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之DCGAN基于CIFAR10数据集的例子 深度学习GAN(三)之DCGAN基于手写体Mnist数据集的例子 深度学习GAN(四)之c ...

  9. [深度学习-原理]GAN(生成对抗网络)的简单介绍

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之DCGAN基于CIFAR10数据集的例子 深度学习GAN(三)之DCGAN基于手写体Mnist数据集的例子 深度学习GAN(四)之c ...

最新文章

  1. LeetCode 468 validate ip address(正则表达式)
  2. JS弹出窗口窗口的位置和大小
  3. 计算机视觉 | 计算机界国际学术会议和期刊目录
  4. Python3.x和Python2.x的区别[转]
  5. python迭代器好处_python迭代器
  6. mysql 文章内容_假设mysql数据库里面有个字段存的是文章内容,用什么方式查询出所有文章中包含某个特定词语的数据。...
  7. 【java】Java中TypeReference用法说明
  8. 另存为里面没有jpg_CAD图不会转JPG?教你两个方法,从此CAD格式转换不再烦恼
  9. SpringMVC 理论与有用技术(一) 简单、有用、易懂的几个实例
  10. Ubuntu16.04下完美切换Python版本
  11. jQuery als.js 跑马灯
  12. 《精通javascript》-----------------------读书笔记
  13. 算术编码数据压缩Matlab报告,用matlab实现算术编码
  14. Pentaho安装与配置
  15. LCS(最大公共子序列)问题
  16. 第二类斯特林数 - Push Botton Lock(POJ 3088)
  17. 对list的一些常用操作
  18. ibm是被联想收购了吗_联想收购IBM之后为什么出现品牌危机
  19. python twisted安装_图文详解python之twisted模块安装
  20. 华中科技大学有计算机科学与技术学院导师,华中科技大学计算机科学与技术学院导师简介-袁平鹏...

热门文章

  1. 【万里征程——Windows App开发】应用栏
  2. XP+WIN7双系统安装,备份,启动菜单修复
  3. IronPython 与C#交互
  4. 在WEB中实现打印分页
  5. 获取客户端网卡MAC地址和IP地址的几种方法(一)
  6. 容器编排技术 -- Kubernetes kubectl rollout resume 命令详解
  7. 如何在六个月或更短的时间内成为DevOps工程师(三):版本控制
  8. 刷算法题需要的java语法_蓝桥杯java b组需要重点刷什么算法呢?
  9. python读取html文件中的表格数据_使用解析html表pd.read_html文件其中单元格本身包含完整表...
  10. 04737 c++ 自学考试2019版 第二章课后程序设计题 3