变分自编码(VAE)的东西,将一些理解记录在这,不对的地方还请指出。

在论文《Auto-Encoding Variational Bayes》中介绍了VAE。

训练好的VAE可以用来生成图像。

在Keras 中提供了一个VAE的Demo:variational_autoencoder.py

'''This script demonstrates how to build a variational autoencoder with Keras.
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
'''
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import normfrom keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
from keras.utils.visualize_util import plot
import syssaveout = sys.stdout
file = open('variational_autoencoder.txt','w')
sys.stdout = filebatch_size = 100
original_dim = 784   #28*28
latent_dim = 2
intermediate_dim = 256
nb_epoch = 50
epsilon_std = 1.0#my tips:encoding
x = Input(batch_shape=(batch_size, original_dim))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)#my tips:Gauss sampling,sample Z
def sampling(args): z_mean, z_log_var = argsepsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.,std=epsilon_std)return z_mean + K.exp(z_log_var / 2) * epsilon# note that "output_shape" isn't necessary with the TensorFlow backend
# my tips:get sample z(encoded)
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])# we instantiate these layers separately so as to reuse them later
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)#my tips:loss(restruct X)+KL
def vae_loss(x, x_decoded_mean):#my tips:loglossxent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean)#my tips:see paper's appendix Bkl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)return xent_loss + kl_lossvae = Model(x, x_decoded_mean)
vae.compile(optimizer='rmsprop', loss=vae_loss)# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data(path='mnist.pkl.gz')x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))vae.fit(x_train, x_train,shuffle=True,nb_epoch=nb_epoch,verbose=2,batch_size=batch_size,validation_data=(x_test, x_test))# build a model to project inputs on the latent space
encoder = Model(x, z_mean)# display a 2D plot of the digit classes in the latent space
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()# build a digit generator that can sample from the learned distribution
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# linearly spaced coordinates on the unit square were transformed through the inverse CDF (ppf) of the Gaussian
# to produce values of the latent variables z, since the prior of the latent space is Gaussian
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))for i, yi in enumerate(grid_x):for j, xi in enumerate(grid_y):z_sample = np.array([[xi, yi]])x_decoded = generator.predict(z_sample)digit = x_decoded[0].reshape(digit_size, digit_size)figure[i * digit_size: (i + 1) * digit_size,j * digit_size: (j + 1) * digit_size] = digitplt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()plot(vae,to_file='variational_autoencoder_vae.png',show_shapes=True)
plot(encoder,to_file='variational_autoencoder_encoder.png',show_shapes=True)
plot(generator,to_file='variational_autoencoder_generator.png',show_shapes=True)sys.stdout.close()
sys.stdout = saveout

代码实验MNIST手写数字数据集

这个Demo中,VAE的loss函数最小化了两项之和:

一是重构数据x_decoded_mean与原始数据X的binary_crossentropy

二是近似后验与真实后验的KL散度,至于KL散度为何简化成代码中的形式,看论文《Auto-Encoding Variational Bayes》中的附录B有证明。

Demo中VAE的形状如下:

实验中的编码器形状:

代码中将编码得到的均值U可视化结果:

同一颜色为同一数字,发现编码后的二维U值聚类效果很好

实验中的生成器形状:

可将从二维高斯分布中随机采样得到的Z,解码成手写数字图片

代码中将解码得到的图像可视化:

VAE【keras实现】相关推荐

  1. Tensorflow Auto-encoder + VAE 实战

    让我们来康康作为base_line的Auto_Encoder import os import tensorflow as tf import numpy as np from tensorflow ...

  2. 变分自编码器(VAE)详解与实现(tensorflow2.x)

    变分自编码器(VAE)详解与实现(tensorflow2.x) VAE介绍 VAE原理 变分推理 VAE核心方程 优化方式 重参数化技巧(Reparameterization trick) VAE实现 ...

  3. 深度学习之自编码器(5)VAE图片生成实战

    深度学习之自编码器(5)VAE图片生成实战 1. VAE模型 2. Reparameterization技巧 3. 网络训练 4. 图片生成 VAE图片生成实战完整代码  本节我们基于VAE模型实战F ...

  4. 各种 AI 数据增强方法,都在这儿了

    来源 | 算法进阶 责编 | 寇雪芹 头图 | 下载于视觉中国 数据.算法.算力是人工智能发展的三要素.数据决定了Ai模型学习的上限,数据规模越大.质量越高,模型就能够拥有更好的泛化能力. 然而在实际 ...

  5. 【机器学习】一文归纳AI数据增强之法

    数据.算法.算力是人工智能发展的三要素.数据决定了Ai模型学习的上限,数据规模越大.质量越高,模型就能够拥有更好的泛化能力.然而在实际工程中,经常有数据量太少(相对模型而言).样本不均衡.很难覆盖全部 ...

  6. 【机器学习基础】一文归纳AI数据增强之法

    数据.算法.算力是人工智能发展的三要素.数据决定了Ai模型学习的上限,数据规模越大.质量越高,模型就能够拥有更好的泛化能力.然而在实际工程中,经常有数据量太少(相对模型而言).样本不均衡.很难覆盖全部 ...

  7. TensorFlow2-自编码器

    TensorFlow2自编码器 简介 深度学习中也有很多无监督学习的算法,其中,自编码器是最为典型的代表.事实上,人工标注的数据毕竟是少数,互联网每天都在产生海量的无标签数据,如何利用这些数据就是无监 ...

  8. Tensorflow实现变分自编码器

    自编码器一般功能是压缩数据.但是变分自编码器(variational autoencoder,VAE)的主要作用却是可以生成数据.变分自编码器的这种功能可以应用到很多领域,从生成手写文字到合成音乐. ...

  9. 【深度学习】用变分自编码器生成图像和生成式对抗网络

    目录 问题描述: 代码展示: VAE代码段 GAN部分(仅供参考) 运行截图: 参考: 问题描述: 从图像的潜在空间中采样,并创建全新图像或编辑现有图像,这是目前最流行也是最成 功的创造性人工智能应用 ...

最新文章

  1. 自动驾驶解决方案架构
  2. Linux五种IO模型性能分析
  3. 【Android Gradle 插件】ProductFlavor 配置 ( ProductFlavor#manifestPlaceholders 清单文件占位符配置 )
  4. 微服务框架下的思维变化-OSS.Core基础思路
  5. GCC 编译报错:程序中有游离的 \357’ \273’ \277’ 等
  6. java形状函数_java基础:10.4 Java FX之形状
  7. 利用辗转相除法求两个数的最大公约数
  8. 如何将Python程序打包成linux可执行文件
  9. CANTest软件安装成功经验
  10. QML QtLocation地图应用学习-4:行政区划
  11. vue 点击图标旋转
  12. C++ 函数模板 实例化和具体化
  13. 隐函数求导(一元和二元)
  14. 【2017.11.30】3. Longest Substring Without Repeating Characters-最长字串不重复字符
  15. 中心对称图形——平行四边形·复习整理
  16. Vue生命周期及store
  17. MySQL空间函数——ST_AsText走过的坑
  18. 一个潜藏4年之久的内核bug
  19. 手机adb调试出现Not running as root. Try“adb root“ first.
  20. 程序员的自我进化:补上最短的那块情商木板

热门文章

  1. pthread_cond_t
  2. 第二章 Qt Widgets项目的创建、运行和发布的过程
  3. c++程序设计中的多态与虚函数知识点
  4. matlab实现层次分析法
  5. Windows编程—杀死指定路径程序文件的进程
  6. jqgrid treegrid 重新加载数据
  7. 编程学习记录13:Oracle数据库,表的查询
  8. 「译」JUnit 5 系列:环境搭建
  9. js获取几个月前,几周前时间。
  10. K线理论--单根K线形态