在弄毕设的时候,室友的毕设是基于DCGAN实现音乐的自动生成。那是第一次接触对抗神经网络,当时听室友的描述就是两个CNN,一个生成一个监测,在互相博弈。
最近我关注的一个大神在弄有关于GAN的东西,所以就跟着学了一下,蛮有意思的,和之前的深度学习略有不同。

1.导入库

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import glob
import sys,os,pathlib,imageio

2.基本原理

生成式对抗网络(GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。2014年由lanGoodfellow引入深度学习领域,被评价为“20年来深度学习领域最酷的想法”。
机器学习的模型大体上可分为两类,生成模型和判别模型。判别模型需要输入变量,通过某种模型来预测。生成模型是给定某种隐含信息,来随机产生观测数据。在之前的深度学习实验中,都是使用判别模型,来实现对某种事务的判别,例如:猫狗大战、鸟类识别、手写数字识别等。而生成模型接触的并不多。GAN是更好的生成模型
GAN主要包括了两个部分:生成器generator与判别器discriminator。生成器主要用来学习真实图像分布从而让自身生成的图像更加真实,从而骗过判别器。而判别器则需要对接收的图片进行真假判别。

在训练过程中,生成器努力地令生成的图像更加真实,而判别器则努力地去识别图像的真假,这个过程相当于二人博弈,随着时间的推移,生成器和判别器在不断地进行对抗。最终两个网络达到了一个动态均衡:生成器生成的图像接近于真是图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近0.5(相当于随机猜测类别)。

利用GAN生成手写数字识别的流程图如下所示:

对于给定的真实图片,判别器要为其打上标签1;
对于给定的生成图片,判别器要为其打上标签0;
对于生成器传给辨别器的生成图片,生成器希望辨别器打上标签1.

GAN步骤:

1.生成器(Generator)接收随机数并返回生成图像。
2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。

3.数据准备

在这一阶段我们导入真实的手写数字,对其进行打乱、batch、归一化等操作。

(train_images,train_labels) ,(_,_) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
train_images = (train_images - 127.5)/127.5#归一化到[-1,1]之间
batch_size = 256
buffer_size = 60000
datasets = tf.data.Dataset.from_tensor_slices(train_images)
datasets = datasets.shuffle(buffer_size).batch(batch_size)

4.生成器与判别器的构建

def Generator_model():#最终生成28*28*1的图片model = tf.keras.Sequential([tf.keras.layers.Dense(256,input_shape=(100,)),#传入的数据为长度为100的随机向量tf.keras.layers.BatchNormalization(),#归一化tf.keras.layers.LeakyReLU(),#高级一点的Relu函数tf.keras.layers.Dense(512),tf.keras.layers.BatchNormalization(),tf.keras.layers.LeakyReLU(),tf.keras.layers.Dense(28*28*1,activation='tanh'),tf.keras.layers.BatchNormalization(),tf.keras.layers.Reshape((28,28,1))#最后调整为(28,28,1)形状的数据,与手写数字的shape一致,作为生成器生成的图片])return modeldef Discriminator_model():#判断图片是真正的图片还是生成的model = tf.keras.Sequential([tf.keras.layers.Flatten(),#传入一张图片,将其展开成一维数组tf.keras.layers.Dense(512),tf.keras.layers.BatchNormalization(),tf.keras.layers.LeakyReLU(),tf.keras.layers.Dense(256),tf.keras.layers.BatchNormalization(),tf.keras.layers.LeakyReLU(),tf.keras.layers.Dense(1,activation='sigmoid')])return model
generator = Generator_model()
discriminator = Discriminator_model()

5.生成器与判别器的loss构建

判别器的loss值:判断真实图片为1的loss与判断生成图片为0的loss之和。因为判别器希望将真实图片判别为1,将生成图片判别为0.
生成器的loss值:判断生成图片为1的loss。因为生成器希望生成的图片是真实图片,即判别为1.

#生成器losses
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def Discriminator_loss(real_out,fake_out):real_loss = cross_entropy(tf.ones_like(real_out),real_out)fake_loss = cross_entropy(tf.zeros_like(fake_out), fake_out)return real_loss+fake_loss
def Generator_loss(fake_out):return cross_entropy(tf.ones_like(fake_out), fake_out)
Generator_opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
Discriminator_opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

参数设置

epochs = 100
noise_dim = 100
num_exp_to_generate = 16
seed = tf.random.normal([num_exp_to_generate,noise_dim])#16个长度为100的向量

6.批次训练

对一个batch_size的数据进行训练

def train_step(images):noise = tf.random.normal([batch_size,noise_dim])#生成一个batch_size*noise_dim的数据,相当于生成了batch_size个长度为100的随机向量with tf.GradientTape() as gen_tape,tf.GradientTape() as dis_tape:#两个Tape,一个代表生成器,一个代表判别器。real_out = discriminator(images,training = True)#利用判别器对真实的图片进行训练,得到一个modelgen_image = generator(noise,training = True)#利用生成器对噪声数据生成图片fake_out = discriminator(gen_image, training=True)#利用判别器对生成的图片进行训练gen_loss = Generator_loss(fake_out)#利用判别器对生成图片的判断计算生成器的loss值dis_loss = Discriminator_loss(real_out,fake_out)##利用判别器对生成图片和真实图片的判断计算判别器的loss值gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)#根据生成器的loss值和网络模型计算梯度gradient_dis = dis_tape.gradient(dis_loss, discriminator.trainable_variables)#根据判别器的loss值和网络模型计算梯度Generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))#根据梯度对生成器进行梯度更新Discriminator_opt.apply_gradients(zip(gradient_dis,discriminator.trainable_variables))#根据梯度对判别器进行梯度更新

7.训练&&可视化

def train(dataset,epochs):for epoch in range(epochs):#一共训练epochs次for image_batch in dataset:#对dataset中的每一个batch进行训练train_step(image_batch)print('.',end='')print()Generator_plot_image(generator,seed,epoch)#根据训练好的生成器,对之前生成的seed进行处理,生成图片
train(datasets,epochs)
def Generator_plot_image(gen_model,test_noise,epoch):pre_images = gen_model(test_noise,training = False)#根据test_noise生成图片,生成器设置为不可训练fig = plt.figure(figsize=(4,4))for i in range(pre_images.shape[0]):plt.subplot(4,4,i+1)plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray')#之前归一化为[-1,1]之间,现在+1然后除以2,使之在[0,1]之间plt.axis('off')fig.savefig("E:/tmp/.keras/datasets/number_gen/%05d.png" % epoch)plt.close()

生成图片如下所示:

8.生成动图

该模块参考大神K同学啊

def compose_gif():# 图片地址data_dir = "E:/tmp/.keras/datasets/number_gen"data_dir = pathlib.Path(data_dir)paths = list(data_dir.glob('*'))gif_images = []for path in paths:print(path)gif_images.append(imageio.imread(path))imageio.mimsave("E:/tmp/.keras/datasets/test.gif", gif_images, fps=2)
compose_gif()

文件太大,csdn忍不了无法上传。

由于训练速度等原因,epochs设置的是100,最终展示的效果并不是很好,但是也可以看出生成的图片由一片模糊向逐渐清晰的过渡。

努力加油a啊

深度学习之基于GAN实现手写数字生成相关推荐

  1. 深度学习之基于DCGAN实现手写数字生成

    该篇文章与上篇文章内容相差不多,但是主要的网络结构不同,上篇文章采用的是GAN网络结构,而这篇文章采用的是DCGAN网络结构.两者的差异在于以下几点: (1)使用卷积和去卷积代替池化层. (2)在生成 ...

  2. 深度学习导论(5)手写数字识别问题步骤

    深度学习导论(5)手写数字识别问题步骤 手写数字识别分类问题具体步骤(Training an handwritten digit classification) 加载数据 显示训练集中的图片 定义神经 ...

  3. 深度学习 第三章 tensorflow手写数字识别

    深度学习入门视频-唐宇迪 (笔记加自我整理) 深度学习 第三章 tensorflow手写数字识别 1.tensorflow常见操作 这里使用的是tensorflow1.x版本,tensorflow基本 ...

  4. Java软件研发工程师转行之深度学习(Deep Learning)进阶:手写数字识别+人脸识别+图像中物体分类+视频分类+图像与文字特征+猫狗分类

    本文适合于对机器学习和数据挖掘有所了解,想深入研究深度学习的读者 1.对概率基本概率有所了解 2.具有微积分和线性代数的基本知识 3.有一定的编程基础(Python) Java软件研发工程师转行之深度 ...

  5. 深度学习100例-生成对抗网络(GAN)手写数字生成 | 第18天

    文章目录 一.前期工作 1. 设置GPU 2. 定义训练参数 二.什么是生成对抗网络 1. 简单介绍 2. 应用领域 三.网络结构 四.构建生成器 五.构建鉴别器 六.训练模型 1. 保存样例图片 2 ...

  6. 生成对抗网络(GAN)——MNIST手写数字生成

    前言 正文 一.什么是GAN 二.GAN的应用 三.GAN的网络模型 对抗生成手写数字 一.引入必要的库 一.引入必要的库 二.进行准备工作 三.定义生成器和判别器模型 四.设置损失函数和优化器,以及 ...

  7. [Python人工智能] 三十.Keras深度学习构建CNN识别阿拉伯手写文字图像

    从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN). ...

  8. 深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了

    大家好,我是微学AI,今天给大家带来手写OCR识别的项目.手写的文稿在日常生活中较为常见,比如笔记.会议记录,合同签名.手写书信等,手写体的文字到处都有,所以针对手写体识别也是有较大的需求.目前手写体 ...

  9. [Python图像识别] 四十七.Keras深度学习构建CNN识别阿拉伯手写文字图像

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

最新文章

  1. DroidPilot 发布微信公众帐号啦~
  2. Servlet应用之细节
  3. 命令行执行Junit测试
  4. C学习杂记(六)%2.0f打印输出宽度
  5. EF架构~了解一下,ADO.NET Entity Framework
  6. 复习Java的精华总结
  7. Springmvc借助SimpleUrlHandlerMapping实现接口开关功能
  8. 使用jaxb根据xsd逆向生成java代码
  9. L3-020 至多删三个字符 (30 分) DP
  10. c#调用带有安全认证的java webservice
  11. Mac锁屏的几种方式
  12. 浪潮激荡大时代,存储起航新十年
  13. 计算机网络怎么看ip地址类别,如何查找ip地址 ip地址分类介绍【图解】
  14. IS-IS 路由选择协议入门
  15. 计算机定时开机关机设置,怎样设置电脑定时开机关机
  16. Unity3d Mesh、Texture、UI 压缩降低内存
  17. 2022元宇宙共享大会|倪健中:我们正在开启元宇宙新时代
  18. 程序猿生存定律-六个程序猿的故事(2)
  19. 苏黎世联邦理工学院计算机系研究生,大神offer | 恭喜四位再来人学员斩获苏黎世联邦理工学院-电子工程与信息技术硕士 !...
  20. 程序员的十个等级(最详尽)

热门文章

  1. .NET的Snk使用方法
  2. IOS UIPageControl的设置点为一张图片
  3. linux namespace 工具,Linux Namespace : 简介
  4. java actor和线程有什么区别_Scala Actor与java并发编程的区别
  5. Java基础复习-八大基本数据类型-内存模型-基本算法-网络编程
  6. Android开发之RecyclerView之刷新数据notifyDataSetChanged失败的问题
  7. 计算机设备的热量,帮我计算机一下这块冰能吸收多少热量?
  8. sourcetree 拉取 一直让输入密码
  9. 最终计算供应链管理生产计划排程逻辑管理
  10. win7-安装phantomjs,并添加环境变量。