https://blog.csdn.net/CoderPai/article/details/70598403?utm_source=blogxgwz0

里面有比较全面的GAN的链接

原始论文链接:http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

一篇不错的理解GAN的文章:https://blog.csdn.net/qq_31531635/article/details/70670271

这篇GAN代码的出处 https://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/

简单用了别人的代码,实现了一下,加入了自己理解的部分:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tensorflow.examples.tutorials.mnist import input_datasess = tf.InteractiveSession()mb_size = 128
Z_dim = 100mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)def weight_var(shape, name):return tf.get_variable(name=name, shape=shape, initializer=tf.contrib.layers.xavier_initializer())def bias_var(shape, name):return tf.get_variable(name=name, shape=shape, initializer=tf.constant_initializer(0))# discriminater net
#普通的两层卷积网络,作为鉴别网络
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')D_W1 = weight_var([784, 128], 'D_W1')
D_b1 = bias_var([128], 'D_b1')D_W2 = weight_var([128, 1], 'D_W2')
D_b2 = bias_var([1], 'D_b2')theta_D = [D_W1, D_W2, D_b1, D_b2]# generator net
# 两层网络,输入为100维的噪声,这里是[-1,1]的均匀噪声,作为生成网络
Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')G_W1 = weight_var([100, 128], 'G_W1')
G_b1 = bias_var([128], 'G_B1')G_W2 = weight_var([128, 784], 'G_W2')
G_b2 = bias_var([784], 'G_B2')theta_G = [G_W1, G_W2, G_b1, G_b2]#具体网络的结构def generator(z):G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)G_log_prob = tf.matmul(G_h1, G_W2) + G_b2G_prob = tf.nn.sigmoid(G_log_prob) #使用sigmoid给出该位置的值return G_probdef discriminator(x):D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)D_logit = tf.matmul(D_h1, D_W2) + D_b2D_prob = tf.nn.sigmoid(D_logit) #使用sigmoid给出该位置的值return D_prob, D_logitG_sample = generator(Z)
#X为实际的样本数据,G_sample为生成的样本数据
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)
'''
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))
'''
#D为辨别器,G为生成器,这里G是有助于辨别器提高性能的,G的输入是随机噪声,如果G没有训练,产出的应该是无关样本
#虽然也会提高一点性能,但是肯定不好,这里是希望G可以将噪声映射到合理的数字图的空间上,这里希望产出应该很接近
#合理图片,那么D很有可能被判别图像为真实图像,所以G_loss最小化对应为G_sample被D网络认定成真实图像
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))#D和G还是比较独立的两部分,分别写开,不过两部分需要互相提高,所以后续的训练应该是交替进行
D_optimizer = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_optimizer = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
#随机数生成
def sample_Z(m, n):'''Uniform prior for G(Z)'''return np.random.uniform(-1., 1., size=[m, n])def plot(samples):fig = plt.figure(figsize=(4, 4))gs = gridspec.GridSpec(4, 4)gs.update(wspace=0.05, hspace=0.05)for i, sample in enumerate(samples):  # [i,samples[i]] imax=16ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28, 28), cmap='Greys_r')return figif not os.path.exists('out/'):os.makedirs('out/')sess.run(tf.global_variables_initializer())i = 0
for it in range(1000000):#每1000次输出一次if it % 1000 == 0:samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})  # 16*784fig = plot(samples)#图像存储plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')i += 1plt.close(fig)X_mb, _ = mnist.train.next_batch(mb_size)#D和G的交替训练,进行性能的互相提高_, D_loss_curr = sess.run([D_optimizer, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})_, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})if it % 1000 == 0:print('Iter: {}'.format(it))print('D loss: {:.4}'.format(D_loss_curr))print('G_loss: {:.4}'.format(G_loss_curr))print()

运行的结果是每1000次生成的图像,前200张做成了视频,结果如下:

总结一下:

1.GAN还是一种解决问题的框架,通过生成网络G产生更高相关度的图像来提升判别网络D的性能

2.本文方法仅仅使用了神经网络,没有使用CNN,使用CNN产生判别网络才是更合适的,本文在产生的200张的动图后半部分图像变化不明显,将Loss显示出来也可以看出来结果提升不高,这是网络结构较差的原因。

3.GAN不止是提高判别网络D的性能,也可以通过GAN的生成网络G产生平价数据(弱监督中,有标签数据较少,无标签数据较多的情形)。可以通过GAN的生成网络产生更多的有标签数据用于训练,论文见SSGAN。博客链接:https://blog.csdn.net/shenxiaolu1984/article/details/75736407

之后要看的一些东西:

1.Fine-tunning,复用别人的网络并进行新的应用开发,学习网址链接:https://blog.csdn.net/u011600477/article/details/78607883

2.RCNN,博客见:https://blog.csdn.net/v1_vivian/article/details/78599229?utm_source=blogxgwz0,https://blog.csdn.net/WoPawn/article/details/52133338。RCNN关键在于预搜索框的选取,论文是Selective Search for Object Recognition,博客链接:https://blog.csdn.net/surgewong/article/details/39316931

3.YOLO: You Only Look Once,和RCNN的功能一致,但是想法不同,博客:https://blog.csdn.net/shenxiaolu1984/article/details/78826995

4.可视化网络结构特征

tensorflow学习(7. GAN实现MNIST分类)相关推荐

  1. TensorFlow 学习(3)——MNIST机器学习入门

    通过对MNIST的学习,对TensorFlow和机器学习快速上手. MNIST:手写数字识别数据集 MNIST数据集 60000行的训练数据集 和 10000行测试集 每张图片是一个28*28的像素图 ...

  2. TensorFlow学习笔记之四(MNIST数字识别)

    文章目录 1. 关于MNIST数据集 2. 前向传播确定网络结构 2.1 涉及的方法 1. 关于MNIST数据集 数据集和input_data文件 有6万张28*28像素点的0~9手写数字图片和标签, ...

  3. pytorch学习之GAN生成MNIST手写数字

    0.简单介绍: 学深度学习的人必然知道,最基本的GAN模型由一个生成器 G 和判别器 D 组成.生成器用于生成假样本,判别器用于判断样本是真实的还是假的. 在整个训练过程中,生成器努力地让生成的图像更 ...

  4. TensorFlow学习笔记——车牌标志识别分类

    TensorFlow--车牌标志识别分类学习笔记 本博客的内容是:在 BlackWalnut Labs 完成车牌标志识别实验的学习笔记 其中主要的部分是 BlackWalnut Labs 的实验过程介 ...

  5. tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)

    mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...

  6. tensorflow学习笔记(十):GAN生成手写体数字(MNIST)

    文章目录 一.GAN原理 二.项目实战 2.1 项目背景 2.2 网络描述 2.3 项目实战 一.GAN原理 生成对抗网络简称GAN,是由两个网络组成的,一个生成器网络和一个判别器网络.这两个网络可以 ...

  7. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  8. TensorFlow学习笔记(二):快速理解Tutorial第一个例子-MNIST机器学习入门 标签: 机器学习SoftmaxTensorFlow教程 2016-08-02 22:12 3729人阅

    TensorFlow学习笔记(二):快速理解Tutorial第一个例子-MNIST机器学习入门 标签: 机器学习SoftmaxTensorFlow教程 2016-08-02 22:12 3729人阅读 ...

  9. [深度学习-实践]GAN基于手写体Mnist数据集生成新图片

    系列文章目录 深度学习GAN(一)之简单介绍 深度学习GAN(二)之基于CIFAR10数据集的例子 深度学习GAN(三)之基于手写体Mnist数据集的例子 深度学习GAN(四)之PIX2PIX GAN ...

最新文章

  1. vue中点击导航栏部分,页面切换
  2. python 因果推断_KDD 2018:微软推出用于因果推断的Python库
  3. [转]Python 命令行参数和getopt模块详解
  4. 锗钛项圈真的可以治颈椎病吗
  5. Wireshark网络抓包(一)——数据包、着色规则和提示
  6. Android Studio开发基础之自定义View组件
  7. Java集合(7)--Map接口的实现类HashMap、LinkHashMap、TreeMap和Properties
  8. Mac 上开启一个简单的服务器
  9. VMware vCloud Director视频教程
  10. uniapp-微信小程序直播插件小记
  11. Origin软件的安装
  12. 照片的分辨率怎么调整?图片分辨率太低怎么调高?
  13. CentOS8 离线安装 汉语拼音
  14. CWnd与CDialog-DoModal与ShowWindow区别
  15. java中多个if语句如何简化_8种if else语句简化方法
  16. RxJS 6 —— Subscription
  17. C++ Opencv binarization thinning and bone processing
  18. 失联客机大致位置确认 美军水下航行器展开搜寻
  19. C++ tuple的介绍及使用
  20. 涨知识!你不知道的中国手机号码的编码和划分规则

热门文章

  1. python装饰器类型错误_有没有办法在继承期间保持装饰器? - python
  2. 【LeetCode】0046.全排列 (递归详解)
  3. 10分钟看懂, Java NIO 底层原理
  4. Tomcat maven 插件启动出现tomcat\conf\tomcat-users.xml cannot be read异常解决方法
  5. Map的4种遍历方法
  6. Spring AOP 简介以及简单用法
  7. Oracle - Log buffer 的相关设置
  8. Oracle 11gR2 安装 example(范例数据库)
  9. k8s 多租户_k8s使用rbac实现多租户
  10. P1525关押罪犯(并查集补集)