VAE学习笔记

  1. 普通的编码器可以将图像这类信息编码成为特征向量.

  2. 但通常这些特征向量不具有空间上的连续性.

  3. VAE(变分自编码器)可以将图像信息编码成为具有空间连续性的特征向量.

  4. 方法是向编码器和解码器中加入统计信息,即特征向量代表的的是一个高斯分布,强迫特征向量服从高斯分布.

  5. 编码器是将图片信息编码成为一个高斯分布.

  6. 解码器则是从特征空间中进行采样,再经过全连接层,反卷积层,卷积层等恢复成一张与输入图片大小相等的图片.

  7. 损失函数有两个目标:即(1)拟合原始图片以及(2)使得特征空间具有良好的结构及降低过拟合.因此,我们的损失函数由两部分构成.其中第二部分需要使得编码出的正态分布围绕在标准正态分布周围.


实现代码

import keras
from keras import layers
from keras import backend as K
from keras.models import Model
import numpy as npimg_shape = (28,28,1)
batch_size = 16
latent_dim = 2##### 图片编码器部分
input_img = keras.Input(shape = img_shape)x = layers.Conv2D(32,3,padding = 'same', activation = 'relu')(input_img)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu',strides = (2,2))(x)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu')(x)
x = layers.Conv2D(64,3,padding = 'same', activation = 'relu')(x)shape_before_flatting = K.int_shape(x)x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)#输入图像最终被编码为如下两个参数
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)
# 需要注意的是 z_log_var = 2log(sigma)
##### 编码器部分结束##### 采样函数,用于在给定的正态分布中进行采样,这也就是编码器加入统计信息的地方.
def sampling(args):z_mean, z_log_var = argsepsilon = K.random_normal(shape = (K.shape(z_mean)[0],latent_dim),mean = 0.,stddev = 1.)return z_mean + K.exp(0.5*z_log_var) * epsilon##### VAE解码器部分
decoder_input = layers.Input(K.int_shape(z)[1:])
x = layers.Dense(np.prod(shape_before_flatting[1:]),activation='relu')(decoder_input)
x = layers.Reshape(shape_before_flatting[1:])(x)
x = layers.Conv2DTranspose(32,3,padding='same',activation = 'relu',strides = (2,2))(x)
x = layers.Conv2D(1,3,padding='same',activation='sigmoid')(x)
##### 到这里x被恢复成为一张图片##### 下面两句话将编码器和解码器通过上采样函数连接到了一起
decoder = Model(decoder_input,x)
z = layers.Lambda(sampling)([z_mean,z_log_var])
z_decoder = decoder(z)###### 自定义损失函数层
def vae_loss(y_true,y_pred,e = 0.1):x = K.flatten(y_true)z_decoded = K.flatten(y_pred)xent_loss = keras.metrics.binary_crossentropy(x,z_decoded)kl_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),axis = -1)return K.mean(xent_loss + kl_loss)from keras.datasets import mnistvae = Model(input_img,z_decoder)
vae.compile(optimizer = 'rmsprop',loss = vae_loss)
vae.summary()##### 训练模型
from keras.datasets import mnist
(x_train,_),(x_test,y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))vae.fit(x = x_train,y = x_train,shuffle = True,epochs = 10,batch_size = batch_size,validation_data = (x_test,x_test))##### 在特征空间进行连续采样,观察输出图片
import matplotlib.pyplot as plt
from scipy.stats import norm
n = 24
digit_size = 28
figure = np.zeros((digit_size*n,digit_size*n))
grid_x = norm.ppf(np.linspace(0.02,0.98,n))
grid_y = norm.ppf(np.linspace(0.02,0.98,n))print(batch_size)for i,yi in enumerate(grid_x):for j,xi in enumerate(grid_y):z_sample = np.array([[xi,yi]])z_sample = np.tile(z_sample,1).reshape(1, 2)x_decoded = decoder.predict(z_sample, batch_size = 1)digit = x_decoded[0].reshape(digit_size,digit_size)figure[i*digit_size:(i+1)*digit_size,j*digit_size:(j+1)*digit_size] = digitplt.figure(figsize = (15,15))
plt.imshow(figure,cmap = 'Greys_r')
plt.show()

输出结果

VAE(变分自编码器)学习笔记相关推荐

  1. PyTorch 实现 VAE 变分自编码器 含代码

    编码器 自编码器 自编码器网络结构图 线性自编码器代码如下: 卷积自编码器代码如下: 变分自编码器 变分自编码器网络结构图 变分自编码器代码如下: Ref 自编码器 自编码器网络结构图 线性自编码器代 ...

  2. VAE 变分自编码器

    收集了几篇文章,介绍VAE变分自编码器,如下: 1.[干货]深入理解变分自编码器 - 知乎转载自:机器学习研究组订阅 原文链接:[干货]深入理解变分自编码器[导读]自编码器是一种非常直观的无监督神经网 ...

  3. 通俗易懂——VAE变分自编码器原理

    变分自编码器(Variational Auto Encoder, VAE) 李宏毅机器学习笔记.转载请注明出处. 自编码器(Autoencoder): Autoencoder = Encoder + ...

  4. VAE变分自编码器实现

    变分自编码器(VAE)组合了神经网络和贝叶斯推理这两种最好的方法,是最酷的神经网络,已经成为无监督学习的流行方法之一. 变分自编码器是一个扭曲的自编码器.同自编码器的传统编码器和解码器网络一起,具有附 ...

  5. 入门到精通!珍藏资源!VAE变分自编码器

    过去虽然没有细看,但印象里一直觉得变分自编码器(Variational Auto-Encoder,VAE)是个好东西.趁着最近看概率图模型的三分钟热度,我决定也争取把 VAE 搞懂.    于是乎照样 ...

  6. 深入理解VAE(变分自编码器)

    原文地址:https://pan.baidu.com/s/1LNolV-_SZcEhV0vz2RkDRQ : 本文进行翻译和总结. VAE VAE是两种主要神经网络生成模型中的一种,另一种典型的方法是 ...

  7. 绝对旋转编码器学习笔记(基本原理,与PC通信等,不定期更新中)

    目录 前言 1.编码器的定义 2.编码器的分类 3.绝对式编码器 3.1工作原理 3.2多圈绝对值编码器 3.3绝对式编码器的特点 3.4绝对式编码器的输出 4.编码器与计算机的通信 4.1绝对值信号 ...

  8. 【AI绘图学习笔记】变分自编码器VAE

    无监督学习之VAE--变分自编码器详解 机器学习方法-优雅的模型(一):变分自编码器(VAE) 无需多言,看这两篇文章即可.本文主要是总结一下我在看这篇文章和其他视频时没能看懂的部分解读. 文章目录 ...

  9. 【论文阅读-3】生成模型——变分自编码器(Variational Auto-Encoder,VAE)

    [论文阅读]生成模型--变分自编码器 1. VAE设计思路:从PCA到VAE 1.1 PCA 1.2 自编码器(Auto-Encoder, AE) 1.3 从AE到VAE 2. VAE模型框架 2.1 ...

最新文章

  1. linux 读取大量图片 内存,10 张图帮你搞定 TensorFlow 数据读取机制
  2. 四月青少年编程组队学习(图形化四级)Task05
  3. 磁盘分区标为活动的方法及取消磁盘分区标为活动的方法
  4. eeglab中文教程系列(10)-利用光谱选项绘制ERP图像
  5. mysql date_format 按不同时间单位进行分组统计
  6. 论文,范围管理(2017上)
  7. ios markdown 解析_Shortcuts 教程:正则表达式修改 Markdown 链接
  8. iOS-开发记录-UIView属性
  9. x86 vs x64
  10. 如何联网获取北京时间
  11. Expression Designer系列工具汇总 [转载]
  12. windows中安装zookeeper
  13. RK 利用SARADC 来做多个按键
  14. pt100温度传感器c语言,pt100测温程序-LCD1602
  15. RTSP鉴权认证之基础认证和摘要认证
  16. Linux安全手册(转载)
  17. java循环树_for循环输出树木的形状【java】
  18. 射频电路中三种基本接收机结构
  19. 计算机专业技术面试题
  20. 2021kali系列 -- 破解无线密码

热门文章

  1. springmvc如何使用视图解析器_SpringMVC工作原理
  2. java 监听文件内容_java 监听文件内容变化
  3. linux如何实现网络高级编程,嵌入式Linux网络编程之:网络高级编程-嵌入式系统-与非网...
  4. qt制作一个画板_如何直接用Sketch制作动画|Sketch插件|
  5. python while循环true_Python while循环,pause while not,true时继续?
  6. leetcode-445. 两数相加 II
  7. 78. 子集022(回溯法)
  8. [mybatis]typeHandlers日期类型的处理
  9. [JavaWeb-MySQL]DCL管理用户,授权
  10. 交通标志识别项目教程