深度学习之自编码器(2)Fashion MNIST图片重建实战

  • 1. Fashion MNIST数据集
  • 2. 编码器
  • 3. 解码器
  • 4. 自编码器
  • 5. 网络训练
  • 6. 图片重建
  • 完整代码

 自编码器算法原理非常简单,实现方便,训练也较稳定,相对于PCA算法,神经网络的强大表达能力可以学习到输入的高层抽象的隐层特征向量 z\boldsymbol zz,同时也能够基于 z\boldsymbol zz重建出输入。这里我们基于Fashion MNIST数据集进行图片重建实战。

1. Fashion MNIST数据集

 Fashion MNIST是一个定位比MNIST图片识别问题稍微复杂的数据集,它的设定与MNIST几乎完全一样,包含了10类不同类型的衣服、鞋子、包等灰度图片,图片大小为28×2828\times2828×28,共70000张图片,其中60000张用于训练集,10000张用于测试集,如下图所示,每行都是一种类别图片。

Fashion MNIST数据集

 可以看到,Fashion MNIST除了图片内容与MNIST不一样,其它设定都相同,大部分情况可以直接替换掉原来基于MNIST训练的算法代码,而不需要额外修改。由于Fashion MNIST图片识别相对于MNIST图片更难,因此可以用于测试稍微复杂的算法性能。

 在TensorFlow中,加载Fashion MNIST数据集同样非常方便,利用keras.datasets.fashion_mnist.load_data()函数即可在线下载、管理和加载。由于在线加载十分缓慢,我使用了本地加载。数据加载和测试代码如下:

import os
import tensorflow as tf
import numpy as np
import sslfrom Chapter12.Fashion_MNIST_dataload import get_dataos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
ssl._create_default_https_context = ssl._create_unverified_contextbatchsz = 512# 加载Fashion MNIST图片数据集
(x_train, y_train), (x_test, y_test) = get_data()
# 归一化
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# 只需要通过图片数据即可构建数据集对象,不需要标签
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
# 构建测试集对象
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)
# 打印训练集和测试集的shape
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

运行结果如下所示:

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)

其中,数据加载函数get_data()实现如下:

import numpy as np
import gzipdef get_data():# 文件获取train_image = r"/Users/XXX/.keras/datasets/fashion_mnist/train-images-idx3-ubyte.gz"test_image = r"/Users/XXX/.keras/datasets/fashion_mnist/t10k-images-idx3-ubyte.gz"train_label = r"/Users/XXX/.keras/datasets/fashion_mnist/train-labels-idx1-ubyte.gz"test_label = r"/Users/XXX/.keras/datasets/fashion_mnist/t10k-labels-idx1-ubyte.gz"  # 文件路径paths = [train_label, train_image, test_label, test_image]with gzip.open(paths[0], 'rb') as lbpath:y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[1], 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)with gzip.open(paths[2], 'rb') as lbpath:y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[3], 'rb') as imgpath:x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)return (x_train, y_train), (x_test, y_test)

参考:
[1] fashion-mnist简介和使用及下载
[2] 从本地加载FASHION MNIST数据集并输入到模型进行训练

2. 编码器

 我们利用编码器将输入图片x∈R784\boldsymbol x\in R^{784}x∈R784降维到较低维度的隐藏向量:h∈R20\boldsymbol h\in R^{20}h∈R20,并基于隐藏向量h\boldsymbol hh利用解码器重建图片,自编码器模型如下图所示,编码器由3层全连接层网络组成,输出节点数分别为256、128、20,解码器同样由3层全连接网络组成,输出节点数分别为128、256、784。

Fashion MNIST自编码器网络结构

 首先是编码器子网络的实现。利用3层的神经网络将长度为784的图片向量数据一次降维到256、128,最后降维到h_dim维度,每层使用ReLU激活函数,最后一层不使用激活函数。代码如下:

# 创建Encoders网络,实现在自编码器类的初始化函数中
self.encoder = Sequential([layers.Dense(256, activation=tf.nn.relu),layers.Dense(128, activation=tf.nn.relu),layers.Dense(h_dim)
])

3. 解码器

 然后再来创建解码器子网络,这里基于隐藏向量h_dim一次升维到128、256、784长度,除最后一层,激活函数使用ReLU函数。解码器的输出为784长度的向量,代表了打平后的28×2828\times2828×28大小图片,通过Reshape操作即可恢复为图片矩阵。代码如下:

# 创建Decoders网络
self.decoder = Sequential([layers.Dense(128, activation=tf.nn.relu),layers.Dense(256, activation=tf.nn.relu),layers.Dense(784)
])

4. 自编码器

 上述的编码器和解码器2个子网络均实现在自编码器类AE中,我们在初始化函数中同时创建这两个子网络。代码如下:

class AE(keras.Model):def __init__(self):super(AE, self).__init__()# 创建Encoders网络,实现在自编码器类的初始化函数中self.encoder = Sequential([layers.Dense(256, activation=tf.nn.relu),layers.Dense(128, activation=tf.nn.relu),layers.Dense(h_dim)])# 创建Decoders网络self.decoder = Sequential([layers.Dense(128, activation=tf.nn.relu),layers.Dense(256, activation=tf.nn.relu),layers.Dense(784)])

 接下来将前向传播过程实现在call函数中,输入图片首先通过encoder子网络得到隐藏向量h,再通过decoder得到重建图片。一次调用编码器和解码器的前向传播函数即可,代码如下:

def call(self, inputs, training=None):# [b, 784] => [b, 10]h = self.encoder(inputs)# [b, 10] => [b, 784]x_hat = self.decoder(h)return x_hat

5. 网络训练

 自编码器的训练过程与分类器的基本一致,通过误差函数计算出重建向量xˉ\bar\boldsymbol xxˉ与原始输入x\boldsymbol xx之间的距离,再利用TensorFlow的自动求导机制同时求出encoder和decoder的梯度,循环更新即可。

 首先创建自编码器实例和优化器,并设置合适的学习率。例如:

# 创建网络对象
model = AE()
# 指定输入大小
model.build(input_shape=(None, 784))
# 打印网络信息
model.summary()
# 创建优化器,并设置学习率
optimizer = tf.optimizers.Adam(lr=lr)

 这里固定训练100个Epoch,每次通过前向计算获得重建图片向量,并利用tf.nn.sigmoid_cross+entropy_with_logits损失函数计算城建图片与原始图片直接的误差,实际上利用MSE误差函数也是可行的。代码如下:

for epoch in range(100):  # 训练100个Epochfor step, x in enumerate(train_db):  # 遍历训练集#  打平,[b, 28, 28] => [b, 784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits = model(x)# 计算重建图片与输入之间的损失函数rec_loss = tf.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)# 计算均值rec_loss = tf.reduce_mean(rec_loss)# 自动求导,包含了2个子网络的梯度grads = tape.gradient(rec_loss, model.trainable_variables)# 自动更新,同时更新2个子网络optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:# 间隔性打印训练误差print(epoch, step, float(rec_loss))

6. 图片重建

 与分类问题不同的是,自编码器的模型性能一般不好量化评价,尽管L值可以在一定程度上代表网络的学习效果,但我们最终希望获得还原度较高、样式较丰富的重建样本。因此一般需要更具具体问题来讨论自编码器的学习效果,比如对于图片重建,一般依赖于人工主管评价图片生成的质量,或利用某些图片逼真度计算方法(如Inception Score和Frechet Inception Distance)来辅助评估。

 为了测试图片重建效果,我们把数据集切分为训练集与测试集,其中测试集不参与训练。我们从测试集中随机采样测试图片x∈Dtest\boldsymbol x\in \mathbb{D}^{test}x∈Dtest,经过自编码器计算得到重建后的图片,然后将真实图片与重建图片保存为图片阵列,并可视化,方便对比。代码如下:

# 重建图片,从测试集采样一批图片
x = next(iter(test_db))
logits = model(tf.reshape(x, [-1, 784]))  # 打平并送入自编码器
x_hat = tf.sigmoid(logits)  # 将输出转换为像素值,使用sigmoid函数
# 恢复为28×28,[b, 784] => [b, 28, 28]
x_hat = tf.reshape(x_hat, [-1, 28, 28])# 输入的前50张+重建的前50张图片合并,[b, 28, 28] => [2b, 28, 28]
x_concat = tf.concat([x, x_hat], axis=0)
x_concat = x_hat
x_concat = x_concat.numpy() * 255.  # 恢复为0~255范围
x_concat = x_concat.astype(np.uint8)  # 转换为整型
save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)  # 保存图片

 图片重建的效果如下图所示,其中每张图片的左边5列为真实图片,右边5列为对应的重建图片。

第1个Epoch

第50个Epoch

第100个Epoch

可以看到,第一个Epoch时,图片重建效果交叉,图片非常模糊,逼真度较差;随着训练的进行,重建图片边缘越来越清晰,第100个Epoch时,重建的图片效果以及比较接近真实图片。

 这里的save_images函数负责将多张图片合并并保存为一张大图,这部分代码使用PIL图片库完成图片阵列逻辑,代码如下:

def save_images(imgs, name):# 创建280×280大小的图片阵列new_im = Image.new('L', (280, 280))index = 0for i in range(0, 280, 28):  # 10行图片阵列for j in range(0, 280, 28):  # 10列图片阵列im = imgs[index]im = Image.fromarray(im, mode='L')new_im.paste(im, (i, j))  # 写入对应位置index += 1# 保存图片阵列new_im.save(name)

完整代码

import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt
import sslfrom Chapter12.Fashion_MNIST_dataload import get_datassl._create_default_https_context = ssl._create_unverified_context
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')def save_images(imgs, name):# 创建280×280大小的图片阵列new_im = Image.new('L', (280, 280))index = 0for i in range(0, 280, 28):  # 10行图片阵列for j in range(0, 280, 28):  # 10列图片阵列im = imgs[index]im = Image.fromarray(im, mode='L')new_im.paste(im, (i, j))  # 写入对应位置index += 1# 保存图片阵列new_im.save(name)h_dim = 20
batchsz = 512
lr = 1e-3(x_train, y_train), (x_test, y_test) = get_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)class AE(keras.Model):def __init__(self):super(AE, self).__init__()# 创建Encoders网络,实现在自编码器类的初始化函数中self.encoder = Sequential([layers.Dense(256, activation=tf.nn.relu),layers.Dense(128, activation=tf.nn.relu),layers.Dense(h_dim)])# 创建Decoders网络self.decoder = Sequential([layers.Dense(128, activation=tf.nn.relu),layers.Dense(256, activation=tf.nn.relu),layers.Dense(784)])def call(self, inputs, training=None):# [b, 784] => [b, 10]h = self.encoder(inputs)# [b, 10] => [b, 784]x_hat = self.decoder(h)return x_hat# 创建网络对象
model = AE()
# 指定输入大小
model.build(input_shape=(None, 784))
# 打印网络信息
model.summary()
# 创建优化器,并设置学习率
optimizer = tf.optimizers.Adam(lr=lr)for epoch in range(100):  # 训练100个Epochfor step, x in enumerate(train_db):  # 遍历训练集#  打平,[b, 28, 28] => [b, 784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits = model(x)# 计算重建图片与输入之间的损失函数rec_loss = tf.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)# 计算均值rec_loss = tf.reduce_mean(rec_loss)# 自动求导,包含了2个子网络的梯度grads = tape.gradient(rec_loss, model.trainable_variables)# 自动更新,同时更新2个子网络optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:# 间隔性打印训练误差print(epoch, step, float(rec_loss))# evaluation# 重建图片,从测试集采样一批图片x = next(iter(test_db))logits = model(tf.reshape(x, [-1, 784]))  # 打平并送入自编码器x_hat = tf.sigmoid(logits)  # 将输出转换为像素值,使用sigmoid函数# 恢复为28×28,[b, 784] => [b, 28, 28]x_hat = tf.reshape(x_hat, [-1, 28, 28])# 输入的前50张+重建的前50张图片合并,[b, 28, 28] => [2b, 28, 28]x_concat = tf.concat([x, x_hat], axis=0)x_concat = x_hatx_concat = x_concat.numpy() * 255.  # 恢复为0~255范围x_concat = x_concat.astype(np.uint8)  # 转换为整型save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)  # 保存图片

深度学习之自编码器(2)Fashion MNIST图片重建实战相关推荐

  1. Fashion MNIST图片重建实战(AE)

    Fashion MNIST 是一个定位在比 MNIST 图片识别问题稍复杂的数据集,它的设定与MNIST 几乎完全一样,包含了 10 类不同类型的衣服.鞋子.包等灰度图片,图片大小为28 × 28,共 ...

  2. 深度学习之自编码器(5)VAE图片生成实战

    深度学习之自编码器(5)VAE图片生成实战 1. VAE模型 2. Reparameterization技巧 3. 网络训练 4. 图片生成 VAE图片生成实战完整代码  本节我们基于VAE模型实战F ...

  3. 深度学习之自编码器AutoEncoder

    深度学习之自编码器AutoEncoder 原文:http://blog.csdn.net/marsjhao/article/details/73480859 一.什么是自编码器(Autoencoder ...

  4. 深度学习之自编码器(1)自编码器原理

    深度学习之自编码器(1)自编码器原理 自编码器原理  前面我们介绍了在给出样本及其标签的情况下,神经网络如何学习的算法,这类算法需要学习的是在给定样本 x\boldsymbol xx下的条件概率 P( ...

  5. 深度学习之自编码器(4)变分自编码器

    深度学习之自编码器(4)变分自编码器 1. VAE原理  基本的自编码器本质上是学习输入 x\boldsymbol xx和隐藏变量 z\boldsymbol zz之间映射关系,它是一个 判别模型(Di ...

  6. 深度学习之自编码器(3)自编码器变种

    深度学习之自编码器(3)自编码器变种 1. Denoising Auto-Encoder 2. Dropout Auto-Encoder 3. Adversarial Auto-Encoder  一般 ...

  7. 【深度学习】 自编码器(AutoEncoder)

    目录 RDAE稳健深度自编码 自编码器(Auto-Encoder) DAE 深度自编码器 RDAE稳健深度自编码 自编码器(Auto-Encoder) AE算法的原理 Auto-Encoder,中文称 ...

  8. 深度学习(主要是CNN)用于图片的分类和检测总结

     深度学习(主要是CNN)用于图片的分类和检测总结 2014-12-4阅读920 评论0 前言: 主要总结一下自己最近看文章和代码的心得. 1. CNN用于分类:具体的过程大家都知道,无非是卷积, ...

  9. 深度学习之生成对抗网络(8)WGAN-GP实战

    深度学习之生成对抗网络(8)WGAN-GP实战 代码修改 完整代码 WGAN WGAN_train 代码修改  WGAN-GP模型可以在原来GAN代码实现的基础上仅做少量修改.WGAN-GP模型的判别 ...

最新文章

  1. python csv读取-Python对于CSV文件的读取与写入
  2. asp.net导出数据到Excel
  3. 表的插入、更新、删除、合并操作_8_手工插入数据
  4. Java UDP协议传输
  5. enquire.js-响应css媒体查询的轻量级javascript库
  6. 滴滴Booster移动APP质量优化框架 学习之旅 三
  7. ad09只在一定范围内查找相似对象_dxp查找相似对象
  8. php上传文件 服务器内部错误,php – 在将图像上传到S3时遇到内部服务器错误500...
  9. android layerlist bitmap,android shape类似的 另一个 高端用法:layer-list
  10. Linux系统编程40:多线程之基于环形队列的生产者与消费者模型
  11. drool 7.x 属性 : lock-on-active
  12. PHP开源CMS介绍
  13. object对象进行深拷贝
  14. python基础编程语法-Python编程入门——基础语法详解
  15. 哪有简明python教程下载_简明python教程在哪买!《简明python教程》 下载地址?
  16. Landsat8处理小工具(python)
  17. Dw cs6的详细下载安装教程对网页设计需要cs6的同学
  18. 基于Matlab的棋盘光栅的设计
  19. getc/fgetc
  20. postgresql.conf log_rotation_size

热门文章

  1. matlab swt函数,matlab swt 函数出错
  2. 锁定计算机的mad命令,本次操作由于这台计算机的限制而被取消
  3. The executable was signed with invalid entitlements
  4. java插入数据库字符串拼接_JAVA字符串怎么连接?
  5. 200(强缓存)和304(协商缓存)的区别
  6. SpringBoot是如何解析参数的
  7. Django模板渲染——(二)
  8. angular 在IIS部署运行
  9. 2.平凡之路-初识MyBatis
  10. Spring Boot 动态数据源(Spring 注解数据源)