GAN是什么?
其全称是Generative Adversarial Networks,即生成式对抗网络,这是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型中有两大模块,生成模型(Generative Model,我们用G来简称),和判别模型(Discriminative,我们用D来简称),GAN的学习过程便是这两个过程之间的博弈对抗,在GAN的理论中,并不要求G和D都是神经网络,只要是能拟合相应生成和判别的函数即可。在这篇中我们会配合简单的代码来解说,并能够实现GAN,这里的GAN代码可以在自己的笔记本等运行。

GAN内部的简单介绍:
首先我们先介绍一下GAN,如图:数据方面,我们有真实的数据,例如一些图片,还有我们自己定义的噪声,也就是一些随机数而已,这些随机数一般是一维(可能是几十个元素),输入到G中,
G对噪声不断进行编码,也就从低维到高维,最终形成一张图片,然后我们将噪声给D和真实图片给D,D进行判别,最后再对结果进行优化。
接下来我们更细的介绍GAN的训练,首先,我们明确一点,G和D是分开训练的,两者是在训练的过程中分别进步的,一开始,我们可以对G和D中的权重随便设置,这个时候,G和D都是几层网络而已,接下来我们将噪声输入到G中,G这个时候生成的东西也不知道是什么牛鬼蛇神,反正是乱七八糟的数据,然后我们人为的去操作,把这乱七八糟的数据输入到D中,我们自己让D对这些数据判别为“假”,然后输入真实图片,我们自己让D判别为真,然后对D的权重和偏置进行优化,再然后就是对G的权重和偏置进行优化,优化的目标是以真实图片训练权重和偏置,之后便是按照此过程不断的训练,也就不断地运行,在运行到100000次(随便说个次数)后,这些权重和偏置也被训练到稳定的状态了,这个时候如果我们输入随机数到G中,G便可以生成图片,而这个图片已经达到能够让D(非人为的,让D自行判断)判断为“真”。
GAN到此时也便完成了。
下面是代码解说
1)首先是导入包(也不用这么多)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

2)然后是读取数据

mnist=input_data.read_data_sets("./fashion_mnist",one_hot=True)

3)再定义一个函数,按照正太分布,专门输出随机值,这个随机值是用在权重和偏置的初始化,而不是噪声

def xavier_init(size):in_dim=size[0]xavier_stddev=1./tf.sqrt(in_dim/2.)return tf.random.normal(shape=size,stddev=xavier_stddev)

4)再下面是判别器各个权重的设定,当然你也可以设置成两层网络,这里的X只生成器生成的图片。

X=tf.placeholder(tf.float32,shape=[None,784])D_w1=tf.Variable(xavier_init([784,256]))
D_b1=tf.Variable(tf.zeros(shape=[256]))D_w2=tf.Variable(xavier_init([256,128]))
D_b2=tf.Variable(tf.zeros(shape=[128]))D_w3=tf.Variable(xavier_init([128,1]))
D_b3=tf.Variable(tf.zeros(shape=[1]))
theta_D=[D_w1,D_w2,D_w3,D_b1,D_b2,D_b3]

5)之后便是判别器的函数定义

def discriminator(x):D_h1=tf.nn.relu(tf.matmul(x,D_w1)+D_b1)D_h2=tf.matmul(D_h1,D_w2)+D_b2D_logit=tf.matmul(D_h2,D_w3)+D_b3D_prob=tf.nn.sigmoid(D_logit)return D_prob,D_logit

6)然后是生成器的输入和权重,偏置的设定

Z=tf.placeholder(tf.float32,shape=[None,100])G_w1=tf.Variable(xavier_init([100,256]))
G_b1=tf.Variable(tf.zeros(shape=[256]))G_w2=tf.Variable(xavier_init([256,784]))
G_b2=tf.Variable(tf.zeros(shape=[784]))theta_G=[G_w1,G_w2,G_b1,G_b2]

7)下面再定义一个函数,专门输出随机数,也就是噪声

def sample_Z(m,n):return np.random.uniform(-1.,1.,size=[m,n])

8)之后当然就是生成器的函数定义,Z就是噪声

def generator(z):G_h1=tf.nn.relu(tf.matmul(z,G_w1)+G_b1)G_log_prob=tf.matmul(G_h1,G_w2)+G_b2G_prob=tf.nn.sigmoid(G_log_prob)return G_prob

9)接下来就是定义一个显示图片的函数,因为我们在跑完数据后,要看看生成数据,也就是生成图片从一开始到最后有着什么样的变化,所以定义这个函数,方便后面实时保存图片,在这个程序中,我们在运行代码时在目录中建立的一个名为out的文件夹,专门存放生成图片

def plot(samples):fig=plt.figure(figsize=(4,4))gs=gridspec.GridSpec(4,4)gs.update(wspace=0.05,hspace=0.05)for i,sample in enumerate(samples):ax=plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28,28),cmap='Greys_r')return fig

10)这一块则是对权重的优化,

G_sample=generator(Z)
D_real,D_logit_real=discriminator(X)
D_fake,D_logit_fake=discriminator(G_sample)D_loss_real=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real,labels=tf.ones_like(D_logit_real)))
D_loss_fake=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake,labels=tf.zeros_like(D_logit_fake)))
D_loss=D_loss_real+D_loss_fakeG_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake,labels=tf.ones_like(D_logit_fake)))D_solver=tf.train.AdamOptimizer().minimize(D_loss,var_list=theta_D)
G_solver=tf.train.AdamOptimizer().minimize(G_loss,var_list=theta_G)mb_size=128
Z_dim=100

11)最后便是开始跑数据了,在tensorflow中,前面的步骤其实就是在建模型,模型建好了,我们最后一步便是初始化,赋值训练模型,我们训练次数为100000,但我觉得,如果要训练好,最好不要低于500000次,大概两三个小时吧可以跑完

sess=tf.Session()
sess.run(tf.global_variables_initializer())
if not os.path.exists('out/'):os.makedirs('out/')
i=0
#开始训练
for it in range(100000):if it%1000==0:samples=sess.run(G_sample,feed_dict={Z: sample_Z(16,Z_dim)})fig=plot(samples)plt.savefig('out/{}.png'.format(str(i).zfill(3),bbox_inches='tight'))i+=1plt.close(fig)X_mb,_=mnist.train.next_batch(mb_size)_,D_loss_curr=sess.run([D_solver,D_loss],feed_dict={X:X_mb,Z:sample_Z(mb_size,Z_dim)})_,G_loss_curr=sess.run([G_solver,G_loss],feed_dict={Z:sample_Z(mb_size,Z_dim)})if it%100==0:print('iter:{}'.format(it))print('D_loss:{:.4}'.format(D_loss_curr))print('G_loss:{:.4}'.format(G_loss_curr))

GAN的介绍和简单代码的实现相关推荐

  1. iptable_netfilter介绍以及简单代码分析

    前言 Linux kernel 版本:5.4.1 Netfilter特指内核中的netfilter框架 iptables指用户空间的配置工具 概念 Linux 上最常用的防火墙工具是 iptables ...

  2. 生成对抗式网络 GAN及其衍生CGAN、DCGAN、WGAN、LSGAN、BEGAN等原理介绍、应用介绍及简单Tensorflow实现

    生成式对抗网络(GAN,Generative Adversarial Networks)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.学界大牛Yann Lecun 曾说,令他最激 ...

  3. DIV布局——仿英雄联盟LOL首页(11页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码 DW学生个人网站制作成品下载

    HTML5期末大作业:仿英雄联盟网站设计--仿英雄联盟LOL首页(11页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码 文章目录 HTML5期末大作业:仿英雄联盟网站设计- ...

  4. HTML5期末大作业:生活服务网站设计——生活服务同城商城(33页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码 DW学生个人网站制作成品下载

    HTML5期末大作业:生活服务网站设计--生活服务同城商城(33页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码 DW学生个人网站制作成品下载 常见网页设计作业题材有 个人 ...

  5. HTML5期末大作业:仿英雄联盟网站设计——仿英雄联盟LOL首页(11页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码 DW学生个人网站制作成品下载

    HTML5期末大作业:仿英雄联盟网站设计--仿英雄联盟LOL首页(11页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码 DW学生个人网站制作成品下载 常见网页设计作业题材有 ...

  6. HTML5期末大作业:食品超市网站设计——食品超市-功能齐全(31页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码 DW学生个人网站制作成品下载

    HTML5期末大作业:食品超市网站设计--食品超市-功能齐全(31页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码 DW学生个人网站制作成品下载 常见网页设计作业题材有 个 ...

  7. 25TML5期末大作业:影视网站设计——电影请以你的名字呼唤我(4页) 大学生简单个人静态HTML网页设计作品 DIY布局个人介绍网页模板代码 DY学生个人网站制作成品下载

    HTML5期末大作业:影视网站设计--电影请以你的名字呼唤我(4页) 大学生简单个人静态HTML网页设计作品 DIY布局个人介绍网页模板代码 DY学生个人网站制作成品下载 常见网页设计作业题材有 个人 ...

  8. TML5期末大作业:影视网站设计——电影请以你的名字呼唤我(4页) 大学生简单个人静态HTML网页设计作品 DIY布局个人介绍网页模板代码 DY学生个人网站制作成品下载

    HTML5期末大作业:影视网站设计--电影请以你的名字呼唤我(4页) 大学生简单个人静态HTML网页设计作品 DIY布局个人介绍网页模板代码 DY学生个人网站制作成品下载 常见网页设计作业题材有 个人 ...

  9. HTML+CSS大作业web网页设计实例作业 ——中国民间年画 (5页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码

    web网页设计实例作业 --中国民间年画 (5页) 大学生简单个人静态HTML网页设计作品 DIV布局个人介绍网页模板代码 常见网页设计作业题材有 个人. 美食. 公司. 学校. 旅游. 电商. 宠物 ...

最新文章

  1. 利用WiFi模块实现MicroPython远程开发
  2. 人工智能将再创新高,清华发布人工智能白皮书
  3. python 示例_Python使用示例设置add()方法
  4. https p12证书请求解决问题过程
  5. Python使用pip自动升级所有第三方库
  6. 图片 过度曝光_实际拍摄中,经常遇到曝光不足或过曝的结果,6种手段帮你解决...
  7. 【Oracle】SQLPLUS命令
  8. STM32单片机使用注意事项
  9. Android包管理机制2 PackageInstaller安装APK
  10. 如何防止三分钟热度?给自己的目标定个阶段性奖励吧
  11. error C4716 必须返回一个值 处理
  12. Ansible之管理windows主机
  13. web——216中安全色
  14. 简单的Java 16方格排序游戏
  15. 小丁的Spring笔记一(概述)
  16. unity内部自带局域网制作
  17. hadoop 报错 there appears to be a gap in the edit log. we expected txitd 1, but got txid 14444
  18. 微信开发者工具中的版本管理功能搭配gitee使用
  19. 朗润外盘国际期货:ChatGPT这个人工智能有点东西
  20. 关于DBeaver stored procedure中print语句的内容看不见,smss可以的问题

热门文章

  1. 3030. 天黑请闭眼
  2. 嵌入式—LM3S1138介绍
  3. 《 指数基金投资指南 》by 银行螺丝钉 - 笔记 - 4 - 第一部分
  4. 【赛码网 牛客网】输入输出总结(python版)
  5. 使用浏览器转化ASCII码为字符
  6. 如何快速恢复最近关闭的浏览器标签页面
  7. kubernetes Pod Lifecycle生命周期与livenessProbe、 readinessProbe探测方法
  8. 【Cocos2D-x 3.5实战】坦克大战(2)游戏开始界面
  9. 高中数学向量巨难题型四心问题解题技巧
  10. C#获取字符串的拼音和首字母