版权声明:本文为博主原创文章,未经博主允许不得转载。

这几天在看GAN模型的时候,顺便关注了另外一种生成模型——VAE。其实这种生成模型在早几年就有了,而且有了一些应用。著名黑客George Hotz在其开源的自主驾驶项目中就应用到了VAE模型。这其中的具体应用在我上一篇转载的博客comma.ai中有详细介绍。在对VAE基本原理有一个认识后,就开始啃代码了。在代码中最能直观的体现其思想,结合理论来看,有利于理解。理论介绍,网上资料很多,就不赘述了。

基本网络框架:

下面是基于keras的VAE代码解读:(自己加了一层,发现收敛的快一点了,附实验结果)

[html] view plaincopy
  1. <span style="font-size:14px;">from __future__ import division
  2. from __future__ import print_function
  3. import os.path
  4. import numpy as np
  5. np.random.seed(1337)  # for reproducibility
  6. import tensorflow as tf
  7. from tensorflow.examples.tutorials.mnist import input_data
  8. mnist = input_data.read_data_sets('MNIST')
  9. input_dim = 784
  10. hidden_encoder_dim_1 = 1000
  11. hidden_encoder_dim_2 = 400
  12. hidden_decoder_dim = 400
  13. latent_dim = 20#(latent Variable)
  14. lam = 0
  15. def weight_variable(shape):
  16. initial = tf.truncated_normal(shape, stddev=0.001)
  17. return tf.Variable(initial)
  18. def bias_variable(shape):
  19. initial = tf.constant(0., shape=shape)
  20. return tf.Variable(initial)
  21. x = tf.placeholder("float", shape=[None, input_dim])##input x
  22. l2_loss = tf.constant(0.0)
  23. #encoder1 W b
  24. W_encoder_input_hidden_1 = weight_variable([input_dim,hidden_encoder_dim_1])##784*1000
  25. b_encoder_input_hidden_1 = bias_variable([hidden_encoder_dim_1])#1000
  26. l2_loss += tf.nn.l2_loss(W_encoder_input_hidden_1)
  27. # Hidden layer1 encoder
  28. hidden_encoder_1 = tf.nn.relu(tf.matmul(x, W_encoder_input_hidden_1) + b_encoder_input_hidden_1)##w*x+b
  29. #encoder2 W b
  30. W_encoder_input_hidden_2 = weight_variable([hidden_encoder_dim_1,hidden_encoder_dim_2])##1000*400
  31. b_encoder_input_hidden_2 = bias_variable([hidden_encoder_dim_2])#400
  32. l2_loss += tf.nn.l2_loss(W_encoder_input_hidden_2)
  33. # Hidden layer2 encoder
  34. hidden_encoder_2 = tf.nn.relu(tf.matmul(hidden_encoder_1, W_encoder_input_hidden_2) + b_encoder_input_hidden_2)##w*x+b
  35. W_encoder_hidden_mu = weight_variable([hidden_encoder_dim_2,latent_dim])##400*20
  36. b_encoder_hidden_mu = bias_variable([latent_dim])##20
  37. l2_loss += tf.nn.l2_loss(W_encoder_hidden_mu)
  38. # Mu encoder=+
  39. mu_encoder = tf.matmul(hidden_encoder_2, W_encoder_hidden_mu) + b_encoder_hidden_mu##mu_encoder:1*20(1*400 400*20)
  40. W_encoder_hidden_logvar = weight_variable([hidden_encoder_dim_2,latent_dim])##W_encoder_hidden_logvar:400*20
  41. b_encoder_hidden_logvar = bias_variable([latent_dim])#20
  42. l2_loss += tf.nn.l2_loss(W_encoder_hidden_logvar)
  43. # Sigma encoder
  44. logvar_encoder = tf.matmul(hidden_encoder_2, W_encoder_hidden_logvar) + b_encoder_hidden_logvar#logvar_encoder:1*20(1*400 400*20)
  45. # Sample epsilon
  46. epsilon = tf.random_normal(tf.shape(logvar_encoder), name='epsilon')
  47. # Sample latent variable
  48. std_encoder = tf.exp(0.5 * logvar_encoder)
  49. z = mu_encoder + tf.mul(std_encoder, epsilon)##z_mu+epsilon*z_std=z,as decoder's input;z:1*20
  50. W_decoder_z_hidden = weight_variable([latent_dim,hidden_decoder_dim])#W_decoder_z_hidden:20*400
  51. b_decoder_z_hidden = bias_variable([hidden_decoder_dim])##400
  52. l2_loss += tf.nn.l2_loss(W_decoder_z_hidden)
  53. # Hidden layer decoder
  54. hidden_decoder = tf.nn.relu(tf.matmul(z, W_decoder_z_hidden) + b_decoder_z_hidden)##hidden_decoder:1*400(1*20 20*400)
  55. W_decoder_hidden_reconstruction = weight_variable([hidden_decoder_dim, input_dim])##400*784
  56. b_decoder_hidden_reconstruction = bias_variable([input_dim])
  57. l2_loss += tf.nn.l2_loss(W_decoder_hidden_reconstruction)
  58. KLD = -0.5 * tf.reduce_sum(1 + logvar_encoder - tf.pow(mu_encoder, 2) - tf.exp(logvar_encoder), reduction_indices=1)##KLD
  59. x_hat = tf.matmul(hidden_decoder, W_decoder_hidden_reconstruction) + b_decoder_hidden_reconstruction##x_hat:1*784(reconstruction x)
  60. BCE = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(x_hat, x), reduction_indices=1)##sum cross_entropy
  61. loss = tf.reduce_mean(BCE + KLD)##average value
  62. regularized_loss = loss + lam * l2_loss
  63. loss_summ = tf.scalar_summary("lowerbound", loss)##Record the stored value of loss
  64. train_step = tf.train.AdamOptimizer(0.01).minimize(regularized_loss)##Optimization Strategy
  65. # add op for merging summary
  66. summary_op = tf.merge_all_summaries()
  67. # add Saver ops
  68. saver = tf.train.Saver()
  69. n_steps = int(1e5+1)##step:1000000
  70. batch_size = 100
  71. with tf.Session() as sess:
  72. summary_writer = tf.train.SummaryWriter('experiment',
  73. graph=sess.graph)##draw graph in tensorboard
  74. #if os.path.isfile("save/model.ckpt"):
  75. # print("Restoring saved parameters")
  76. #saver.restore(sess, "save/model.ckpt")
  77. #else:
  78. # print("Initializing parameters")
  79. sess.run(tf.initialize_all_variables())
  80. for step in range(1, n_steps):
  81. batch = mnist.train.next_batch(batch_size)
  82. feed_dict = {x: batch[0]}
  83. _, cur_loss, summary_str = sess.run([train_step, loss, summary_op], feed_dict=feed_dict)
  84. summary_writer.add_summary(summary_str, step)
  85. if step % 50 == 0:
  86. save_path = saver.save(sess, "save/model.ckpt")
  87. print("Step {0} | Loss: {1}".format(step, cur_loss))
  88. # save weights every epoch
  89. #if step % 100==0 :
  90. # generator.save_weights(
  91. #          'mlp_generator_epoch_{0:03d}.hdf5'.format(epoch), True)
  92. #critic.save_weights(
  93. #        'mlp_critic_epoch_{0:03d}.hdf5'.format(epoch), True)
  94. ##Step 999900 | Loss: 114.41309356689453
  95. ##Step 999950 | Loss: 115.09370422363281
  96. ##Step 100000 | Loss: 124.32205200195312 ##Step 99700 | Loss: 116.05304718017578
  97. #1000 encode hidden layer=1 Step 950 | Loss: 159.3329620361328
  98. #1000 encode hidden layer=2 Step 950 | Loss: 128.81312561035156
  99. </span>

变分自编码器参考资料:

1.变分自编码器(Variational Autoencoder, VAE)通俗教程 – 邓范鑫——致力于变革未来的智能技术 http://www.dengfanxin.cn/?p=334

2.VAE variation inference变分推理 清爽介绍 http://mp.weixin.qq.com/s/9lNWkEEOk5vEkJ1f840zxA

3.深度学习变分贝叶斯自编码器(上下) https://zhuanlan.zhihu.com/p/25429082

变分自编码(VAE)及代码解读相关推荐

  1. 单指标时间序列异常检测——基于重构概率的变分自编码(VAE)代码实现(详细解释)

    1. 编写目的 不少论文都是基于VAE完成的异常检测,比如 Donut .Bagel.尽管 Donut 实现的模型很容易通过继承于重写父类方法的方式实现一个 VAE-baseline,并且 Bagel ...

  2. 变分自编码器VAE ——公式推导(含实现代码)

    目录 一.什么是变分自编码器 二.VAE的公式推导 三.重参数化技巧 一.什么是变分自编码器    在讲述VAE(variational auto-encoder)之前,有必要先看一下AE(auto- ...

  3. 变分自编码器VAE:一步到位的聚类方案

    作者丨苏剑林 单位丨广州火焰信息科技有限公司 研究方向丨NLP,神经网络 个人主页丨kexue.fm 由于 VAE 中既有编码器又有解码器(生成器),同时隐变量分布又被近似编码为标准正态分布,因此 V ...

  4. VGAE(Variational graph auto-encoders)论文及代码解读

    一,论文来源 论文pdf Variational graph auto-encoders 论文代码 github代码 二,论文解读 理论部分参考: Variational Graph Auto-Enc ...

  5. 变分自编码器VAE:这样做为什么能成?

    作者丨苏剑林 单位丨广州火焰信息科技有限公司 研究方向丨NLP,神经网络 个人主页丨kexue.fm 话说我觉得我自己最近写文章都喜欢长篇大论了,而且扎堆地来.之前连续写了三篇关于 Capsule 的 ...

  6. 【Pytorch神经网络实战案例】14 构建条件变分自编码神经网络模型生成可控Fashon-MNST模拟数据

    1 条件变分自编码神经网络生成模拟数据案例说明 在实际应用中,条件变分自编码神经网络的应用会更为广泛一些,因为它使得模型输出的模拟数据可控,即可以指定模型输出鞋子或者上衣. 1.1 案例描述 在变分自 ...

  7. 【Pytorch神经网络实战案例】13 构建变分自编码神经网络模型生成Fashon-MNST模拟数据

    1 变分自编码神经网络生成模拟数据案例说明 变分自编码里面真正的公式只有一个KL散度. 1.1 变分自编码神经网络模型介绍 主要由以下三个部分构成: 1.1.1 编码器 由两层全连接神经网络组成,第一 ...

  8. 【神经网络】变分自编码大杂烩

    1.变分自编码 变分是数学上的概念,大致含义是寻求一个中间的函数,通过改变中间函数来查看目标函数的改变.变分推断是变分自编码的核心,那么变分推断是要解决的是什么问题?? 问题描述如下,假如我们有一批样 ...

  9. 4.keras实现--生成式深度学习之用变分自编码器VAE生成图像(mnist数据集和名人头像数据集)...

    1.VAE和GAN 变分自编码器(VAE,variatinal autoencoder)   VS    生成式对抗网络(GAN,generative adversarial network) 两者不 ...

  10. BERT:代码解读、实体关系抽取实战

    目录 前言 一.BERT的主要亮点 1. 双向Transformers 2.句子级别的应用 3.能够解决的任务 二.BERT代码解读 1. 数据预处理 1.1 InputExample类 1.2 In ...

最新文章

  1. JAVA如何检测GC日志
  2. ThreadLocal使用
  3. ECCV 2020《Linguistic Structure Guided Context Modeling for Referring Image Segmentation》论文笔记
  4. 面试突击 | Redis 如何从海量数据中查询出某一个 Key?视频版
  5. 如何在Jupyter中运行R语言(两种解决方案)
  6. python如何搭建环境_Python基础环境如何搭建
  7. RabbitMQ中basicConsume、basicCancel、basicPublish方法
  8. 斐波那契数列-爬楼梯算法
  9. 算法笔记_面试题_11.正则表达式匹配
  10. GIS案例练习-----------第九天
  11. 2021-06-21>字体样式风格font
  12. 计算机考研初试/复试——软件工程
  13. 【opencv】 报错:C2065 “CV_COVAR_ROWS”、“CV_COVAR_NORMAL”、“CV_COVAR_SCALE”: 未声明的标识符、
  14. QQ免费企业邮箱申请配置
  15. 原生代码开发小米官网首页
  16. shiro集成jwt
  17. student dictionary
  18. 高可用架构:异地多活
  19. mysqld --defaults-file=/myfolder/my.cnf --defaults-extra-file=/myfolder2/my.cnf
  20. ubuntu kylin 分辨率不对

热门文章

  1. 大神论坛 利用活跃变量分析来去掉vmp的大部分垃圾指令
  2. Android SDK ADB命令行总结
  3. Oracle的Replace函数与translate函数详解与比较
  4. operator的理解
  5. 【CSS3学习笔记】16:边框图片效果
  6. 计算机绘画小房子教案,小班美术教案小房子
  7. google工具栏新览
  8. unity二維碼生成(新)
  9. jink下载出现:Failed to download RAMCode . Failed to prepare for programming .
  10. 潇潇六月雨 input file里的JQ change() 事件的只生效一次