Autoencoder是常见的一种非监督学习的神经网络。它实际由一组相对应的神经网络组成(可以是普通的全连接层,或者是卷积层,亦或者是LSTMRNN等等,取决于项目目的),其目的是将输入数据降维成一个低维度的潜在编码,再通过解码器将数据还原出来。因此autoencoder总是包含了两个部分,编码部分以及解码部分。编码部分负责将输入降维编码,解码部分负责让输出层通过潜在编码还原出输入层。我们的训练目标就是使得输出层与输入层之间的差距最小化。

我们会发现,有一定的风险使得训练出的AE模型是一个恒等函数,这是一个需要尽量避免的问题。

Autoencoder CNN 卷积自编码器

下面我们就用一个简单的基于mnist数据集的实现,来更好地理解autoencoder的原理。
首先是import相关的模块,定义一个用于对比显示输入图像与输出图像的可视化函数。

# Le dataset MNIST
from tensorflow.keras.datasets import mnist
import tensorflow as tf
from tensorflow.keras.layers import Input,Dense, Conv2D, Conv2DTranspose, MaxPooling2D, Flatten, UpSampling2D, Reshape
from tensorflow.keras.models import Model,Sequential
import numpy as np
import matplotlib.pyplot as pltdef MNIST_AE_disp(img_in, img_out, img_idx):num_img = len(img_idx)plt.figure(figsize=(18, 4))for i, image_idx in enumerate(img_idx):# 显示输入图像ax = plt.subplot(2, num_img, i + 1)plt.imshow(img_in[image_idx].reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)# 显示输出图像ax = plt.subplot(2, num_img, num_img + i + 1)plt.imshow(img_out[image_idx].reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)plt.show()

加载数据并对mnist图像数据进行预处理,包括正则化以及将图片扩充成28,28,1的三维。

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 正则化 [0, 255] à [0, 1]
x_train=x_train.astype('float32')/float(x_train.max())
x_test=x_test.astype('float32')/float(x_test.max())x_train=x_train.reshape(len(x_train),x_train.shape[1], x_train.shape[2], 1)
x_test=x_test.reshape(len(x_test),x_test.shape[1], x_test.shape[2], 1)

接下来就是自编码器神经网络的构建了。这里编码器与解码器都由两个卷积层构成,编码部分的池化层,对应了解码部分的upsampling层,以此来保证输入输出层的维度是一致的。

from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model# 编码
input_img = Input(shape=(28,28,1))
x = Conv2D(filters=16, kernel_size=(3,3), activation='relu', padding='same')(input_img)
x = MaxPooling2D(pool_size=(2,2))(x)
encoded = Conv2D(filters=8, kernel_size=(3,3), activation='relu', padding='same')(x)# 解码
x = Conv2D(filters=16, kernel_size=(3,3), activation='relu', padding='same')(encoded)
x = UpSampling2D(size=(2,2))(x)
decoded = Conv2D(filters=1,kernel_size=(3,3), activation='sigmoid', padding='same')(x)autoencodeur = Model(input_img, decoded)
autoencodeur.summary()
Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_10 (InputLayer)        [(None, 28, 28, 1)]       0
_________________________________________________________________
conv2d_36 (Conv2D)           (None, 28, 28, 16)        160
_________________________________________________________________
max_pooling2d_17 (MaxPooling (None, 14, 14, 16)        0
_________________________________________________________________
conv2d_37 (Conv2D)           (None, 14, 14, 8)         1160
_________________________________________________________________
conv2d_38 (Conv2D)           (None, 14, 14, 16)        1168
_________________________________________________________________
up_sampling2d_9 (UpSampling2 (None, 28, 28, 16)        0
_________________________________________________________________
conv2d_39 (Conv2D)           (None, 28, 28, 1)         145
=================================================================
Total params: 2,633
Trainable params: 2,633
Non-trainable params: 0
_________________________________________________________________

接下来就是AE神经网络的训练,与一般的神经网络不同的地方在于,在上述问题中训练自编码器时,输入输出都是同样的mnist图像,以保证在最后输出层能够无限接近输入层,损失降低到最小。

autoencodeur.compile(optimizer='Adam',loss='binary_crossentropy')
autoencodeur.fit(x_train, x_train, batch_size=256, epochs=5)

由于mnist数据集较为简单,在经过五个epoch之后AE模型基本收敛。

Epoch 1/5
235/235 [==============================] - 59s 250ms/step - loss: 0.3742
Epoch 2/5
235/235 [==============================] - 59s 250ms/step - loss: 0.0706
Epoch 3/5
235/235 [==============================] - 59s 250ms/step - loss: 0.0676
Epoch 4/5
235/235 [==============================] - 59s 251ms/step - loss: 0.0666
Epoch 5/5
235/235 [==============================] - 59s 249ms/step - loss: 0.0658

我们从数据集中随机选取10张图片,来对比一下通过自编码器后输入输出的图片的区别。

# 挑选十个随机的图片
num_images=10
np.random.seed(42)
random_test_images=np.random.randint(x_test.shape[0], size=num_images)
# 预测输出图片
decoded_img=autoencodeur.predict(x_test)
# 显示并对比输入与输出图片
MNIST_AE_disp(x_test, decoded_img, random_test_images)


我们从上述例子中可以看到,输出层与输入层相差无几,但是也并不是完全一致的,这说明了我们的自编码器运作正常且并没有生成一个恒等模型。接下来我们通过AE来构建一个去噪模型。

Autoencoder denoising 降噪自编码器

在这个部分中,我们将利用自编码器来实现对图片的降噪功能。首先我们生成一些带噪点的图片。

noise_factor = 0.4
x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape)
x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape) # clip 用于规定最小值和最大值,array中的值如果小于0则变为0 如果大于1则变为1
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)

直接使用上述已经训练好的模型,来看看当输入是带噪点的图像时,我们的卷积自编码器的输出是什么样的。

num_images=10
np.random.seed(42)
random_test_images_noisy=np.random.randint(x_test_noisy.shape[0], size=num_images) # list that contains the index of images chosen
print(random_test_images)
# On détermine l'image encodée et l'image décodée
decoded_img_noisy=autoencodeur.predict(x_test_noisy)
# visialisation
MNIST_AE_disp(x_test_noisy, decoded_img_noisy, random_test_images_noisy)


第一行对应的图片是我们手动生成的有噪点的图像,第二行对应的图片则是我们通过卷积自编码器后的输出图像。可以发现,输出的图像并没有完全一致,而是一定程度上已经去噪了,其实这可以进一步地佐证卷积神经网络处理带噪数据体现出的鲁棒性,即相较全连接层而言,对噪声的敏感程度更低。当然,这个降噪效果还不是很理想,因此我们创建一个新的autoencoder用于处理这一类降噪问题。

DAE = Model(input_img, decoded)
DAE.summary()
DAE.compile(optimizer='Adam', loss='binary_crossentropy')
DAE.fit(x_train_noisy, x_train, batch_size=256, epochs=5)

这个降噪用的自编码器,其架构与上述卷积自编码器相同,唯一有区别的地方在于训练时,我们的输入层变成了带噪图片,而输出层是没有噪声的图片,以此来达到降噪的训练目的。

同样的随机在数据集中选取图片进行对比,我们发现通过这个降噪自编码器后,图像的噪点明显减少了,而且与使用单纯的卷积自编码器不同的是,图像没有明显的钝化,清晰度很高。

在mnist数据集上的实现,同样可以给我们在其他的图片降噪问题上以启发,可以推测的是,更复杂的有噪图片通过类似的处理,也可以达到类似优秀的降噪效果。

autoencoder自编码器原理以及在mnist数据集上的实现相关推荐

  1. DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本

    DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本 目录 输出结果 设计思路 实现部分代码 说明:所有图片文件丢失 输出结果 更新-- 设计思路 更新-- 实现部分代码 更 ...

  2. 使用mnist数据集_使用MNIST数据集上的t分布随机邻居嵌入(t-SNE)进行降维

    使用mnist数据集 It is easy for us to visualize two or three dimensional data, but once it goes beyond thr ...

  3. 在MNIST数据集上训练一个手写数字识别模型

    使用Pytorch在MNIST数据集上训练一个手写数字识别模型, 代码和参数文件 可下载 1.1 数据下载 import torchvision as tvtraining_sets = tv.dat ...

  4. 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练

    文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...

  5. tensorflow(七)实现mnist数据集上图片的训练和测试

    本文使用tensorflow实现在mnist数据集上的图片训练和测试过程,使用了简单的两层神经网络,代码中涉及到的内容,均以备注的形式标出. 关于文中的数据集,大家如果没有下载下来,可以到我的网盘去下 ...

  6. keras笔记-mnist数据集上的简单训练

    学习了keras已经好几天了,之前一直拒绝使用keras,但是现在感觉keras是真的好用啊,可以去尝试一下啊. 首先展示的第一个代码,还是mnist数据集的训练和测试,以下是代码: from ker ...

  7. TensorFlow:实战Google深度学习框架(四)MNIST数据集识别问题

    第5章 MNIST数字识别问题 5.1 MNIST数据处理 5.2 神经网络的训练以及不同模型结果的对比 5.2.1 TensorFlow训练神经网络 5.2.2 使用验证数据集判断模型的效果 5.2 ...

  8. Paddle 环境中 使用LeNet在MNIST数据集实现图像分类

    简 介: 测试了在AI Stuio中 使用LeNet在MNIST数据集实现图像分类 示例.基于可以搭建其他网络程序. 关键词: MNIST,Paddle,LeNet #mermaid-svg-FlRI ...

  9. [caffe(一)]使用caffe训练mnist数据集

    1.数据集的下载与转换 1)我们在mnist数据集上做测试,MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris ...

最新文章

  1. java 新建double对象_java-如何在BlueJ“创建对象”对话框中输入...
  2. 在SQL SERVER中实现Split功能的函数,并在存储过程中使用
  3. [BUUCTF-pwn]——pwn2_sctf_2016
  4. 基于html5的旅游交流系统,基于HTML5的旅游移动导览系统的研究与实现
  5. Proxmark3教程1:小白如何用PM3破解复制M1全加密门禁IC卡
  6. Python基础——文件拷贝(从手动实现到shutil的使用)
  7. springboot读取resources目录下文件
  8. 巨美国际教您如何开好网店?
  9. 连续翻页浏览器面临的共同问题
  10. start request repeated too quickly for filebeat.service
  11. 统计学知识:相关系数
  12. Java8-排序方法(正序、倒序)
  13. 键盘特殊符号输入小技巧
  14. win7局域网自建ftp服务器,win7系统搭建FTp服务器局域网内传输文件的解决教程
  15. C++ typename详解
  16. 下拉推广系统立择火星推荐_下拉词推广权威易速达
  17. 链化未来共识协议详解(下)
  18. Log4J使用说明书
  19. CANON废墨清零方法
  20. AVS+标准应用现状

热门文章

  1. css设置a连接禁用样式_使用CSS禁用链接
  2. stl中copy()函数_std :: rotate_copy()函数以及C ++ STL中的示例
  3. 车牌识别与计算机编程,基于MATLAB的车牌识别程序详解.ppt
  4. MGraph图(代码、分析、汇编)
  5. python安全攻防---scapy使用
  6. python的opencv模块_OpenCV Python - 没有名为cv2的模块(再次)
  7. Linux系统编程----7(信号集,信号屏蔽,信号捕捉)
  8. 基于单链表的生产者消费者问题
  9. 【Linux系统编程学习】 动态库的制作与使用
  10. 【大牛疯狂教学】mysqlinnodb和myisam