TensorFlow实现去噪自编码器(Denoising Autoencoder)

  • 去噪自编码器(Denoising Autoencoder, DAE)
    • DAE模型架构
  • DAE实现
    • 数据预处理
    • 模型构建与模型训练
  • 效果展示

去噪自编码器(Denoising Autoencoder, DAE)

在介绍去噪自编码器 (Denoising Autoencoder, DAE) 之前,首先介绍下DAE的一种使用场景示例,当我们在夜晚拍照时,或者其他黑暗环境时,我们的照片总是被大量的噪点所充斥,严重影响了图像质量,而 DAE 的目的就是用来去除这些图像中的噪声。为了更好的讲解 DAE,使用简单的 MNIST 数据集进行演示,以将我们的重心放在有关 DAE 的知识上。如下图所示,显示了三组 MNIST 数字。每组的顶行是原始图像 (Original Images);中间的行显示 DAE 的输入 (Noised Images),这些输入是被噪声破坏的原始图像,当噪声过多时,我们将很难读懂被破坏的数字;最后一行显示DAE的输出 (Denoised Images)。

Tips:如果对于自编码器还不是很了解的小伙伴,可以参考自编码器模型详解与实现(采用tensorflow2.x实现)。

接下来就让我们实际构建一个 DAE,以消除图像中的噪声。

DAE模型架构

根据 DAE 的介绍可以将输入定义为:
x=xorig+noisex = x_{orig} + noisex=xorig​+noise
其中 xorigx_{orig}xorig​ 表示被噪声 noisenoisenoise 破坏的原始 MNIST 图像,编码器的目的是学习潜矢量 zzz。DAE的损失函数表示为:
L(xorig,x~)=MSE=1m∑i=1i=m(xorigi−x~i)2\mathcal L(x_{orig}, \tilde x)=MSE=\frac 1 m \sum_{i=1} ^{i=m}(x_{orig_i}-\tilde x_i)^2 L(xorig​,x~)=MSE=m1​i=1∑i=m​(xorigi​​−x~i​)2

其中,mmm 是输出的维度,例如在MNIST数据集中,m=width×height×channels=28×28×1=784m=width × height×channels=28 × 28 × 1 = 784m=width×height×channels=28×28×1=784。xorigix_{orig_i}xorigi​​ 和 xix_ixi​ 分别是 xorigx_{orig}xorig​ 和 x~\tilde xx~ 中的元素。

DAE实现

数据预处理

为了实现DAE,首先需要构造训练数据集,输入数据是添加噪声的 MNIST 数字,训练输出数据是原始的干净 MNIST 数字。添加的噪声需要满足高斯分布,均值 μ=0.5μ = 0.5μ=0.5,标准差 σ=0.5σ = 0.5σ=0.5。由于添加随机噪声可能会产生小于0或大于1的无效像素值,因此需要将像素值裁剪为[0.0,1.0]范围内。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
from PIL import Image# 数据加载
(x_train,_),(x_test,_) = keras.datasets.mnist.load_data()# 数据预处理
image_size = x_train.shape[1]
x_train = np.reshape(x_train,[-1,image_size,image_size,1])
x_test = np.reshape(x_test,[-1,image_size,image_size,1])
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.# 产生高斯分布的噪声
noise = np.random.normal(loc=0.5,scale=0.5,size=x_train.shape)
x_train_noisy = x_train + noise
noise = np.random.normal(loc=0.5,scale=0.5,size=x_test.shape)
x_test_noisy = x_test + noise# 将像素值裁剪为[0.0,1.0]范围内
x_train_noisy = np.clip(x_train_noisy,0.0,1.0)
x_test_noisy = np.clip(x_test_noisy,0.0,1.0)

模型构建与模型训练

# 超参数
input_shape = (image_size,image_size,1)
batch_size = 32
kernel_size = 3
latent_dim = 16
layer_filters = [32,64]"""
模型
"""
#编码器
inputs = keras.layers.Input(shape=input_shape,name='encoder_input')
x = inputs
for filters in layer_filters:x = keras.layers.Conv2D(filters=filters,kernel_size=kernel_size,strides=2,activation='relu',padding='same')(x)
shape = keras.backend.int_shape(x)x = keras.layers.Flatten()(x)
latent = keras.layers.Dense(latent_dim,name='latent_vector')(x)
encoder = keras.Model(inputs,latent,name='encoder')
encoder.summary()# 解码器
latent_inputs = keras.layers.Input(shape=(latent_dim,),name='decoder_input')
x = keras.layers.Dense(shape[1]*shape[2]*shape[3])(latent_inputs)
x = keras.layers.Reshape((shape[1],shape[2],shape[3]))(x)
for filters in layer_filters[::-1]:x = keras.layers.Conv2DTranspose(filters=filters,kernel_size=kernel_size,strides=2,padding='same',activation='relu')(x)
outputs = keras.layers.Conv2DTranspose(filters=1,kernel_size=kernel_size,padding='same',activation='sigmoid',name='decoder_output')(x)
decoder = keras.Model(latent_inputs,outputs,name='decoder')
decoder.summaryautoencoder = keras.Model(inputs,decoder(encoder(inputs)),name='autoencoder')
autoencoder.summary()# 模型编译与训练
autoencoder.compile(loss='mse',optimizer='adam')
autoencoder.fit(x_train_noisy,x_train,validation_data=(x_test_noisy,x_test),epochs=10,batch_size=batch_size)# 模型测试
x_decoded = autoencoder.predict(x_test_noisy)rows,cols = 3,9
num = rows * cols
imgs = np.concatenate([x_test[:num],x_test_noisy[:num],x_decoded[:num]])
imgs = imgs.reshape((rows * 3, cols, image_size, image_size))
imgs = np.vstack(np.split(imgs,rows,axis=1))
imgs = imgs.reshape((rows * 3,-1,image_size,image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
imgs = (imgs * 255).astype(np.uint8)
plt.figure()
plt.axis('off')
plt.imshow(imgs,interpolation='none',cmap='gray')
plt.show()

效果展示


如上图所示,当噪声水平从 σ=0.5σ=0.5σ=0.5 增加到 σ=0.75σ=0.75σ=0.75 和 σ=1.0σ=1.0σ=1.0 时,DAE 具有一定的鲁棒性,可以较好的恢复出原始图像。但是,在 σ=1.0σ=1.0σ=1.0 时,某些数字,没有被正确地恢复。

TensorFlow实现去噪自编码器(Denoising Autoencoder)相关推荐

  1. 正则自编码器之去噪自编码器

    图1.自编码器的一般结构 传统自编码器通过最小化如下目标:                                                   (公式1) 公式1中L是一个损失函数,惩 ...

  2. TensorFlow 实现深度神经网络 —— Denoising Autoencoder

    完整代码请见 models/DenoisingAutoencoder.py at master · tensorflow/models · GitHub: 1. Denoising Autoencod ...

  3. 【theano-windows】学习笔记十三——去噪自编码器

    前言 上一章节学习了卷积的写法,主要注意的是其实现在theano.tensor.nnet和theano.sandbox.cuda.dnn中都有对应函数实现, 这一节就进入到无监督或者称为半监督的网络构 ...

  4. SDAE-stacked denoised autoencoder (堆栈去噪自编码器)

    堆栈自编码器 Stacked AutoEncoder_浮生了大白的博客-CSDN博客_堆栈自编码器 为什么稀疏自编码器很少见到多层的? - 知乎 Based on blog which links w ...

  5. tensorflow 卷积、反卷积形式的去噪自编码器

    tensorflow 卷积.反卷积形式的去噪自编码器 对于去噪自编码器,网上好多都是利用全连接神经网络进行构建,我自己写了一个卷积.反卷积形式的去噪自编码器,其中的参数调优如果有兴趣的话,可以自行修改 ...

  6. Tensorflow Day19 Denoising Autoencoder

    今日目標 了解 Denoising Autoencoder 訓練 Denoising Autoencoder 測試不同輸入情形下的 Denoising Autoencoder 表現 Github Ip ...

  7. [自编码器:理论+代码]:自编码器、栈式自编码器、欠完备自编码器、稀疏自编码器、去噪自编码器、卷积自编码器

    写在前面 因为时间原因本文有些图片自己没有画,来自网络的图片我尽量注出原链接,但是有的链接已经记不得了,如果有使用到您的图片,请联系我,必注释. 自编码器及其变形很多,本篇博客目前主要基于普通自编码器 ...

  8. 【TensorFlow-windows】(二) 实现一个去噪自编码器

    主要内容: 1.自编码器的TensorFlow实现代码(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_6 ...

  9. 论文阅读笔记-Gated relational stacked denoising autoencoder with localized author embedding for

    论文题目:Gated relational stacked denoising autoencoder with localized author  embedding for global cita ...

最新文章

  1. hadoop上的pageRank算法
  2. 刘忠范院士:新型研发机构建设成了口号
  3. PostgreSQL — 基于 Recovery 流复制的数据备份
  4. 使用sshfs挂载linux远程服务器目录到windows
  5. green ethernet
  6. [skill] C与C++对于类型转换的验证
  7. a king读后感 love of the_A华语电影高清合集
  8. python前后台tcp/udp通讯示例
  9. 深入理解Python对象(源码深度解析)
  10. SYN flood***的原理及其防御 (一)
  11. 少儿编程之Scratch入门汇总篇
  12. python base_Python base(一)
  13. Deepest Root(dfs深度优先遍历)
  14. ISE在win10中闪退解决方法以及ISE14.7安装包
  15. WIN10系统从睡眠状态唤醒后电脑变卡顿
  16. Android笔记本处理器,惠普或推Android笔记本:配Tegra处理器
  17. HTML超链接文字加粗,Markdown语法之--标题/注释/超链接/下划线/图片/代码/贯穿线/斜体加粗/列表,使你的文本更丰富...
  18. 物联卡的套餐类型有哪些
  19. swi 指令能用在C语言吗,SWI指令---软件中断实例详解
  20. Python爬虫 —— 以北京天气数据爬取为例

热门文章

  1. sql 的 DATE_FORMATE()函数
  2. nodejs初探(四)实现一个多人聊天室
  3. 手机网页 复制信息方法 免费短信
  4. [转载] Python进程——multiprocessing.Event()|Barrier()
  5. MyBatis-Spring-Boot 使用总结
  6. 【现代软件工程】第一次作业——词频统计
  7. Winfrom窗体无法关闭问题--检查是否存在重写
  8. Oracle 11g vs 12c 内存、优化器等默认参数对比
  9. 使用EDITPLUS编写C#控制台应用程序
  10. caffe学习日记--Lesson2:再看caffe的安装和使用、学习过程