图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)

文章目录

  • 图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)
    • 1.前言
    • 2.GAN网络主体架构介绍
    • 3.模型搭建
      • 3.1 生成器的搭建
      • 3.2 判别器模型的搭建
    • 4.数据预处理
    • 5.定义训练各项参数以及训练步骤
    • 6.训练以及效果可视化评估
    • 结语

1.前言

​ 深度学习,中有一个较为成熟并且非常重要的方向,GAN图像对抗生成网络,该网络在图像生成,图像增强,风格化领域,以及在艺术的图像创造(博主也是在看到一个关于中国山水画的GAN生成上,有了学习GAN的兴趣)有重要的作用。

​ 那么正所谓柿子要挑软的捏,学习从最简单的开始,在GAN方面完全是萌新的博主,今天介绍的自然也不是什么太难的架构,在本篇博客中,我会介绍GAN的大致架构,并用较为简单的方式从头到尾 (模型搭建,定义训练参数,训练步骤)实现他,如果本篇博客对你有帮助的话,别忘记点个赞。

( ̄▽ ̄)~■干杯□~( ̄▽ ̄)

2.GAN网络主体架构介绍

GAN的网络总体架构其实非常简单,他的中文名字对抗生成网络,意思是在他模型中包含两个网络,生成网络,对抗网络,总体结构如下图:

​ 我们可以看到这张图上包含了两个网络Generator图像生成器和Discriminator图像分辨器,他的工作原理简单来说是这样的:我们的目标是想要一个生成图片那么我们如何去训练这个呢,这里GAN的开发者提出了这么一个想法,我们训练一个判别器,训练一个生成器,输入噪声(也就是我们提前规定好形状的随机初始化的向量)然后产生了图片,然后我们将真实的图片与虚假图片一起输入判别器,判别图片是否是真实的,利用在这里产生的损失去训练生成器,与判别器。那么我们可以想想如果这样的话我们最终产生的理想结果就是,判别器最终无法判别生成器生成的图片是真是假,最终预测的概率只有0.5(真假 二分类随机乱猜的概率)。

​ 当刚看懂网络工作方式的时候,我简直惊呆了,这是多么神奇的思维啊,生成器在训练中由于损失控制会努力希望生成的图片被判别为真,而判别器是希望能完全给出正确的判断(给生成的图片的判断全为0,真的图片判断全为1),那么在这两个模型的训练之间,他们在互相对抗,我们最终得到的将会是一个非常好的图像生成器,和 自编码器相比,(直接计算生成图与原图的差距)效果会更好(这里我会在之后的mnist数据生成展示中展示编码器与GAN网络的差别)。

3.模型搭建

​ 那么在介绍完模型之后,我们趁热打铁,直接开始模型的搭建,在上文中,我提到了两个模型负责生成图片的生成器,负责判别图片真假的判别器,接下来我开始分别搭建这两个模型。

3.1 生成器的搭建

​ 这里我们要搭建的是一个能够接收我们产生的随机初始化的向量,然后产生图片(这里我们产生的数据是mnist的手写体数字)的模型,这里我为了简单化全部采用全连接层来写模型

import tensorflow as tf
keras=tf.keras
layers=keras.layers
def generator_model():model=keras.Sequential()#model.add(layers.Dense(256,input_shape=(100,),use_bias=False))#输入形状100是我输入噪声的形状,生成器一般都不使用BIASmodel.add(layers.BatchNormalization())model.add(layers.LeakyReLU())#GAN中一般使用LeakyRelu函数来激活model.add(layers.Dense(512,use_bias=False))#生成器一般都不使用BIASmodel.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))#生成能够调整成我们想要的图片形状的向量model.add(layers.BatchNormalization())model.add(layers.Reshape((28,28,1)))#这里进行修改向量的形状,可以直接使用layers的reshapereturn model

可以看到经过全连接层的这样处理,我们输出的会是一个形状大小为(28,28,1)的图片,那么生成器的任务就是判断输入图片是否是生成的,也就是输入图片,输出0,1一个非常简单的二分类问题,那么我们就按照这个思路搭建我们的判别器网络。

3.2 判别器模型的搭建

判别器这里我也使用最基础的全连接层来创建(一方面是减少计算量,一方面是测试一下Dense层的效果)

def discriminator_model():model=keras.Sequential()model.add(layers.Flatten())#图片是一个三维数据,要输入到全连接层之前,先使用flatten层压平为一维的model.add(layers.Dense(512,use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(256,use_bias=False))model.add(layers.BatchNormalization())model.add(layers.LeakyReLU())model.add(layers.Dense(1))#最后输出为0,1只需要一层return model

那么定义完组成模型的两个重要架构之后,我们为了接下来的训练需要准备处理好的数据,所以这里我们开始处理数据。

4.数据预处理

在本篇最简单的实战中,我采用深度学习中使用次数最多,入门级Hello World数据集,mnist手写体数据集,由数万张手写体数字组成

这里的数字都是手写之后,经过特殊处理最终保存下来的。可以看到这样生成的图片是非常带有个人风格的(这写的也不太整齐。。),那么利用GAN生成网络去生成能类似人写的数据,达到可以欺骗人眼的效果,就是我此次的目的,那么废话少说,就开始我们此次的数据准备。

(x_train,y_train),_=keras.datasets.mnist.load_data()
x_train=tf.expand_dims(x_train,axis=-1)#这里由于输入的手写体是只有两个维度的,所以这里我扩展最后一个维度
x_train.shape
TensorShape([60000, 28, 28, 1])

扩展完维度后,为了方便模型运算,我们需要将数据进行归一化,规定数据集的BATCH_SIZE

x_train=tf.cast(x_train,tf.float32)
x_train=x_train/255.0
x_train=x_train*2-1#将图片数据规范到[-1,1]
BATCH_SIZE=256
BUFFER_SIZE=60000#每次训练弄乱的大小
dataset=tf.data.Dataset.from_tensor_slices(x_train)
dataset=dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

5.定义训练各项参数以及训练步骤

在搭建完模型,预处理好数据之后,接下来就需要定义模型所需优化器,损失计算函数,以及训练步骤

loss_object=keras.losses.BinaryCrossentropy(from_logits=True)#损失这里使用二分类交叉熵损失,没有激活是logits
def discriminator_loss(real_out,fake_out):     real_loss=loss_object(tf.ones_like(real_out),  real_out)fake_loss=loss_object(tf.zeros_like(fake_out),fake_out)return real_loss+fake_loss
#这里判别器使用的损失是计算我们人为制造的0,1标签与判别器模型输出的做计算,最终返回二者相加
def generator_loss(fake_out):fake_loss=loss_object(tf.ones_like(fake_out),fake_out)return fake_loss
#生成器计算损失当然是希望判别器都把他当真,所以是与1做计算generator_opt=keras.optimizers.Adam(1e-4)
discriminator_opt=keras.optimizers.Adam(1e-4)#定义两个模型的优化器

定义完了在训练吗中需要用到的优化器损失函数,我们这里接下来定义模型,训练步骤并开始训练(这里我们会在每次训练后绘画随机生成的图片,来观察我们图像生成模型的效果,所以这里我会提前制作一个随机种子)

EPOCHS=100
noise_dim=100 #输入噪声的维度
num=16 #每次随机绘画16张图
seed=tf.random.normal(shape=([num,noise_dim])) #制作用于生成图片的向量
gen_model=generator_model()
dis_model=discriminator_model()
#初始化这两个模型
#定义训练步骤
@tf.function
def train_step(images):noise=tf.random.normal([BATCH_SIZE,noise_dim])with tf.GradientTape() as gentape, tf.GradientTape() as disctape:real_output=dis_model(images,training=True)fake_image=gen_model(noise,training=True)fake_output=dis_model(fake_image,training=True)gen_loss=generator_loss(fake_output)dis_loss=discriminator_loss(real_output,fake_output)grad_gen=gentape.gradient(gen_loss,gen_model.trainable_variables)grad_dis=disctape.gradient(dis_loss,dis_model.trainable_variables)generator_opt.apply_gradients(zip(grad_gen,gen_model.trainable_variables))discriminator_opt.apply_gradients(zip(grad_dis,dis_model.trainable_variables))#在每次训练后绘图
def generate_plot_img(gen_model,test_noise):pre_img=gen_model(test_noise,training=False)fig=plt.figure(figsize=(4,4))for i in range(pre_img.shape[0]):plt.subplot(4,4,i+1)plt.imshow((pre_img[i, :, :, 0]+1)/2,cmap='gray')#这里cmap限定绘图的颜色空间,灰度图plt.axis('off')plt.show()#将16张图片一起显示出来

6.训练以及效果可视化评估

那么我们开始训练

def train(dataset, epochs):for epoch in range(epochs):for img in dataset:train_step(img)print('-',end='')generate_plot_img(gen_model,seed)#绘制图片
train(dataset,EPOCHS)#这里EPOCHS我设置为100

那么由于我的随机数种子是固定的,所以这里我们随机生成的图片每次都是固定的数字,所以我们是可以看到效果在不断变好,如下

这是第一次训练结束后生成的一团浆糊

这是第五次训练产生的图像,可以看到已经渐渐产生了有数字的轮廓,

在经过100次训练后最终我们看到我们的图像生成器,最后产生的图片已经非常有手写数字的轮廓。

虽然效果仍然不是很好,但其实是由于我这里完全使用了全连接层,在图像处理领域使用卷积神经网络会更好的效果,下图是我使用了卷积神经网络后的效果:

结语

在本篇博客中,我完成了一个非常简单的GAN生成对抗网络,并训练该模型使得他可以生成非常接近的手写体的真实数据,对本篇博客有疑问或者建议的同学欢迎评论区交流。

图像对抗生成网络 GAN学习01:从头搭建最简单的GAN网络,利用神经网络生成手写体数字数据(tensorflow)相关推荐

  1. ufldl学习笔记与编程作业:Multi-Layer Neural Network(多层神经网络+识别手写体编程)...

    ufldl学习笔记与编程作业:Multi-Layer Neural Network(多层神经网络+识别手写体编程) ufldl出了新教程,感觉比之前的好,从基础讲起,系统清晰,又有编程实践. 在dee ...

  2. 2018-4-15摘录笔记,《网络表征学习前沿与实践》 崔鹏以及《网络表征学习中的基本问题初探》 王啸 崔鹏 朱文武

    1.来源:<网络表征学习前沿与实践>  崔鹏 (1)随着数据的增加以及计算机计算速度的增加,想当然的以为速度快了,数据再多也是可以自己算的,但是若是数据之间存在着复杂的关系,那么处理一个样 ...

  3. 搭建一个简单的SDN网络环境

    第1小题:简单网络 说明:由于对于SDN架构的理解在学界和业界并没有统一,为了方便参赛队员选择,对于初学者,大赛推荐OpenFlow作为南向接口来实现SDN环境,以下给出分别针对采用OpenFlow和 ...

  4. GAN学习指南(通俗易懂):从原理入门到制作生成Demo

    这篇博客是本人参考了别的作者的文章所编写的,里面有自己总结的也有参考的,讲得比较通俗易懂,对GAN入门非常有帮助. 本文主要分为三个部分: 介绍原始的GAN的原理 同样非常重要的DCGAN的原理 如何 ...

  5. 千锋教育+计算机四级网络-计算机网络学习-01

    目录 课程链接 最早的广域网 计算机网络发展阶段 计算机网络的定义与要点 英文单词网络术语与解释 计算机网络分类 广域网技术 城域网 局域网 个人局域网 五种基本的网络拓扑结构​ 误码率 电路交换网特 ...

  6. 【Linux网络编程学习】使用socket实现简单服务器——多进程多线程版本

    此为牛客Linux C++课程和黑马Linux系统编程笔记. 1. 多进程版 1.1 思路 大体思路与上一篇的单进程版服务器–客户端类似,都是遵循下图: 多进程版本有以下几点需要注意: 由于TCP是点 ...

  7. Scala学习笔记-环境搭建以及简单语法

    关于环境的搭建,去官网下载JDK8和Scala的IDE就可以了,Scala的IDE是基于Eclipse的. 下面直接上代码: 这是项目目录: A是scala写的: package first impo ...

  8. 最简单的验证码(利用JSP生成验证码)

    可以直接利用jsp输出验证码.jsp文件如下: <%@ page contentType="image/jpeg"import="java.awt.*, java. ...

  9. 【论文翻译 IJCAI-20】Heterogeneous Network Representation Learning 异构网络表示学习

    文章目录 摘要 1 引言 2 异构网络挖掘 3 异构网络表示 3.1 异构网络嵌入 3.2 异构图神经网络 3.3 知识图谱与属性网络 3.4 应用 4 挑战.方向和开源数据 4.1 未来方向 避免设 ...

最新文章

  1. 微软最新启动了一个 I'm 活动
  2. “vector”: 不是“std”的成员_C++ vector成员函数实现[持续更新]
  3. 修改IIS7并发连接数目限制
  4. 使用Google Weather API查询天气预报
  5. CF1119H-Triple【FWT】
  6. qt与JAVA服务器通信_Qt实现的SSL通信客户端和服务器
  7. OpenShift 4 - Istio-Tutorial (1) 教程说明和准备环境
  8. 工欲善其事,必先利其器之sublime
  9. Spring事务异常回滚
  10. pymysql executemany()函数
  11. CTeX书写规范、WinEdt编写XeLaTeX、数模格式编写总结
  12. SHELL编程基础 By jackie
  13. 夜神模拟器——最好用的安卓模拟器
  14. ChatGPT|微信快速接入ChatGPT
  15. 软件测试面试题【变态逻辑题】,盘点那些大厂面试必出变态逻辑题
  16. 自制万能xp镜像让重做系统变得简单
  17. npm 包管理及 registry 或 proxy 配置
  18. [三星移动硬盘] 磁盘必须经过格式化(无法显示)
  19. 今日头条引流脚本,微商引流工具
  20. mybatisplus--getOne和逻辑删除问题详解

热门文章

  1. 拼题A基础篇 10 统计字符
  2. DITA与DocBook对比分析
  3. 数字集群对讲系统服务器,图解数字集群对讲机通讯系统
  4. 机器学习方法的PPT
  5. mondrian mysql 实例_mondrian 导入demo数据到mysql
  6. 入职两年的人写给刚入职的人
  7. 协调计算机硬件和软件的中间层次,计算机的软件和网络基础
  8. 2019第十届蓝桥杯国赛c++B组真题及题解
  9. 精益软件开发的思想_精益软件开发原理快速指南
  10. 联想微型计算机b320电源线,联想B320:拥有电视机的接口