第一部分:这一部分是来自《TensorFlow深度学习》龙良曲老师的书籍中的代码:


import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tensorflow.keras import  losses,optimizers,Sequential
#加载Fashion MNIST图片数据集
(train_x,train_y),(test_x,test_y)=tf.keras.datasets.fashion_mnist.load_data()
#归一化
train_x,test_x=train_x.astype(np.float32)/255.0,test_x.astype(np.float32)/255.0

EPOCHES=2
batch_size=64
learning_rate=0.0001
#只需要通过图片数据即可构建数据集对象,不需要标签
train_db=tf.data.Dataset.from_tensor_slices(train_x)
train_db=train_db.shuffle(10000)
train_db=train_db.batch(batch_size)
# train_db=train_db.repeat(5)
#构建测试集对象
test_db=tf.data.Dataset.from_tensor_slices(test_x)
test_db=test_db.batch(batch_size)
class AE(tf.keras.Model):def __init__(self):super(AE, self).__init__()# 创建Enconder网络,实现在自编码器类的初始化函数中self.encoder=Sequential([tf.keras.layers.Dense(256),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(128),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(20)])# 创建Deconder网络self.decoder=Sequential([tf.keras.layers.Dense(128),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(256),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(784)])def call(self,inputs,training=None):# 前向传播函数# 编码获得隐藏向量h,[b,784]->[b,20]out=self.encoder(inputs)# 解码获得重建图片,[b,20]->[b,784]out_put=self.decoder(out)return out_put

#创建网络对象
model=AE()
#指定输入大小
model.build(input_shape=(4,784))
#打印网络信息
model.summary()
#创建优化器,并放置学习率
optimizer=optimizers.Adam(learning_rate=learning_rate)
#保存图像
def Save_Image(img,filename):new_in = Image.new('L', (280, 280))index = 0for i in range(0, 280, 28):for j in range(0, 280, 28):im = img[index]im = Image.fromarray(im, mode='L')new_in.paste(im, (i, j))index += 1new_in.save(filename)
LOSS=[]for epoch in range(EPOCHES):for step,x in enumerate(train_db):x=tf.reshape(x,[-1,784])with tf.GradientTape() as tape:# 前向计算获得重建的图片out=model(x)# 计算重建图片与输入之间的损失函数loss=tf.losses.binary_crossentropy(x,out,from_logits=True)loss=tf.reduce_mean(loss)# 自动求导,包含两个子网络的梯度grads=tape.gradient(loss,model.trainable_variables)# 自动更新,同时更新两个子网络optimizer.apply_gradients(zip(grads,model.trainable_variables))if step%100==0:LOSS.append(float(loss))print(epoch,step,float(loss))x=next(iter(test_db))logits=model(tf.reshape(x,[-1,784]))x_hat=tf.sigmoid(logits)x_hat=tf.reshape(x_hat,[-1,28,28])x_concat=tf.concat([x[:50],x_hat[:50],x_hat],axis=0)x_concat=x_concat.numpy()*255.0x_concat=x_concat.astype(np.uint8)Save_Image(x_concat,r'E:\python教学\图像识别\自编码器\images_AE\\%d.png'%epoch)plt.figure(figsize=(6,6))
x=[i for i in range(len(LOSS))]
plt.plot(x,LOSS,label='loss',linestyle='-',color='blue')
plt.xlabel('X')
plt.ylabel('loss')
plt.legend()
plt.show()

实验效果:

以上在写代码的时候需要注意一个问题:就是在Save_Image中的图像点阵图大小的问题,Save_Image中定义的步长为28,所以最后输出的图像是10*10的图像阵列,每一次x=next(iter(test_db))迭代得到图像是(32,28,28),意味着每一次的迭代得到的图像数为32,通过后面的拼接:x_concat=tf.concat([x[:50],x_hat[:50]],axis=0)之后的图像为(64,28,28)虽然这个地方写的是x[:50]读取前50张图像,但是实际上得到的前32张,所以10*10张图像肯定是超出了索引的范围,也就是在运行代码的时候会出现:索引越界的错误
所以这个地方该怎么改呢:可以将Save_Image中的步长设置为56,增加步长,只不过得到的图像数少一点,25张。将这个x_concat=tf.concat([x[:50],x_hat[:50]],axis=0)可以修改为这个x_concat=tf.concat([x[:15],x_hat[:15]],axis=0)。就可以了,只要索引的范围不越界就行。
第二部分:对上面的代码进行改进

import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow.keras import  losses,optimizers,Sequential
#加载Fashion MNIST图片数据集
(train_x,train_y),(test_x,test_y)=tf.keras.datasets.fashion_mnist.load_data()
#归一化
train_x,test_x=train_x.astype(np.float32)/255.0,test_x.astype(np.float32)/255.0
train_x=tf.reshape(train_x,[-1,784])
print(np.shape(train_x))
batch_size=64
learning_rate=0.0001
encoder=Sequential([tf.keras.layers.Dense(256),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(128),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(20)
])decoder=Sequential([tf.keras.layers.Dense(128),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(256),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(784)
])autoencoder=Sequential([encoder,decoder
])
#创建优化器,并放置学习率
optimizer=optimizers.Adam(learning_rate=learning_rate)
#保存图像
def Save_Image(img,filename):new_in=Image.new('L',(280,280))index=0for i in range(0,280,56):for j in range(0,280,56):im=img[index]im=Image.fromarray(im,mode='L')new_in.paste(im,(i,j))index+=1new_in.save(filename)
autoencoder.compile(optimizer=optimizer,loss=losses.binary_crossentropy)
autoencoder.fit(train_x,train_x,batch_size=batch_size*2,epochs=2,verbose=1)

for epoch in range(2):x=next(iter(test_db))print(np.shape(x))logits=model.predict(tf.reshape(x,[-1,784]))x_hat=tf.sigmoid(logits)x_hat=tf.reshape(x_hat,[-1,28,28])x_concat=tf.concat([x[:15],x_hat[:15]],axis=0)print(np.shape(x_concat))x_concat=x_concat.numpy()*255.0x_concat=x_concat.astype(np.uint8)Save_Image(x_concat,r'E:\python教学\图像识别\自编码器\images_AE\%d.png'%epoch)

实验效果:


最好还是多训练几代,我这里只训练了2代。
第三部分:采用卷积神经网络实现:

import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow.keras import  losses,optimizers,Sequential
#加载Fashion MNIST图片数据集
(train_x,train_y),(test_x,test_y)=tf.keras.datasets.fashion_mnist.load_data()
#归一化
train_x,test_x=train_x.astype(np.float32)/255.0,test_x.astype(np.float32)/255.0
train_x=tf.expand_dims(train_x,axis=-1)
print(np.shape(train_x))
batch_size=64
learning_rate=0.0001
def build_Model(input_shape=(28,28,1)):input=tf.keras.layers.Input(input_shape)#接收大小为[b,28,28,1]=>[b,14,14,64]x=tf.keras.layers.Conv2D(8,kernel_size=[1,1],strides=[1,1],padding='same')(input)x = tf.keras.layers.MaxPool2D(pool_size=[2,2], padding='same')(x)print(x.shape)#[b,14,14,64]=>[b,7,7,128]x=tf.keras.layers.Conv2D(16,kernel_size=[1,1],strides=[1,1],padding='same')(x)x = tf.keras.layers.MaxPool2D(pool_size=[2,2], padding='same')(x)print(x.shape)#[b,7,7,128]=>[b,4,4,256]x=tf.keras.layers.Conv2D(32,kernel_size=[3,3],strides=[1,1],padding='same')(x)x = tf.keras.layers.MaxPool2D(pool_size=[2,2], padding='same')(x)print(x.shape)#[b,4,4,256]=>[b,2,2,512]x=tf.keras.layers.Conv2D(64,kernel_size=[1,1],strides=[1,1],padding='same')(x)x = tf.keras.layers.MaxPool2D(pool_size=[2,2], padding='same')(x)print(x.shape)#[b,2,2,256]=>[b,4,4,512]x=tf.keras.layers.Conv2DTranspose(64,kernel_size=[3,3],strides=[1,1],padding='valid')(x)print(x.shape)#[b,4,4,128]=>[b,7,7,256]x=tf.keras.layers.Conv2DTranspose(32,kernel_size=[4,4],strides=[1,1],padding='valid')(x)
#     x=tf.keras.layers.BatchNormalization()(x)
#     x=tf.keras.layers.Activation('relu')(x)print(x.shape)#[b,7,7,64]=>[b,14,14,128]x=tf.keras.layers.Conv2DTranspose(16,kernel_size=[2,2],strides=[2,2],padding='valid')(x)
#     x=tf.keras.layers.BatchNormalization()(x)
#     x=tf.keras.layers.Activation('relu')(x)print(x.shape)#[b,14,14,32]=>[b,28,28,64]x=tf.keras.layers.Conv2DTranspose(8,kernel_size=[2,2],strides=[2,2],padding='valid')(x)
#     x=tf.keras.layers.BatchNormalization()(x)
#     x=tf.keras.layers.Activation('relu')(x)print(x.shape) out=tf.keras.layers.Conv2D(1, kernel_size=[3,3],  padding='same')(x)
#     out=tf.keras.layers.Activation('relu')(out)print(out.shape)model=tf.keras.Model(input,out)return modelmodel=build_Model()
model.summary()

#创建优化器,并放置学习率
optimizer=optimizers.Adam(learning_rate=learning_rate)

#保存图像
def Save_Image(img,filename):new_in=Image.new('L',(280,280))index=0for i in range(0,280,28):for j in range(0,280,28):im=img[index]im=Image.fromarray(im,mode='L')new_in.paste(im,(i,j))index+=1new_in.save(filename)
model.compile(optimizer=optimizer,loss=losses.binary_crossentropy)
model.fit(train_x,train_x,batch_size=batch_size*2,epochs=10,verbose=1)
test=tf.expand_dims(test_x,axis=3)
print(np.shape(test[:50]))
import cv2
x=test[:50]
print(np.shape(x))
for epoch in range(10):x=tf.reshape(x,[-1,28,28])logits=model.predict(x)plt.imshow(tf.reshape(logits[:1],[28,28]))
#     print('logits: ',np.shape(logits))print(type(logits))x_hat=tf.sigmoid(logits)x_hat=tf.reshape(x_hat,[-1,28,28])
#     print('x_hat: ',np.shape(x_hat))x_concat=tf.concat([x[:50],x_hat[:50]],axis=0)
#     print('x_concat: ',np.shape(x_concat))x_concat=x_concat.numpy()*255.0x_concat=x_concat.astype(np.uint8)Save_Image(x_concat,r'E:\python教学\图像识别\自编码器\images_AE\%d.png'%(epoch+2))
plt.show()



感觉效果比较差,需要继续对卷积神经网络中的参数进行调节。

TensorFlow自编码器(AE)实战相关推荐

  1. 《TensorFlow技术解析与实战》——第3章 可视化TensorFlow 3.1PlayGround

    本节书摘来自异步社区<TensorFlow技术解析与实战>一书中的第3章,第3.1节,作者李嘉璇,更多章节内容可以访问云栖社区"异步社区"公众号查看 第3章 可视化Te ...

  2. 《TensorFlow技术解析与实战》——第3章 可视化TensorFlow

    本节书摘来异步社区<TensorFlow技术解析与实战>一书中的第3章,作者:李嘉璇,更多章节内容可以访问云栖社区"异步社区"公众号查看. 第3章 可视化TensorF ...

  3. 《TensorFlow技术解析与实战》——1.2 什么是深度学习

    本节书摘来异步社区<TensorFlow技术解析与实战>一书中的第1章,第1.2节,作者:李嘉璇,更多章节内容可以访问云栖社区"异步社区"公众号查看. 1.2 什么是深 ...

  4. HEVC编码器设计实战-梅奥-专题视频课程

    HEVC编码器设计实战-342人已学习 课程介绍         该课程属于实战课程,通过该课程的学习,学员们可以开发一个简单的HEVC编码器,并在这个过程中加深对HEVC标准的理解. 课程收益    ...

  5. 斯坦福大学Tensorflow与深度学习实战课程

    分享一套Stanford University 在2017年1月份推出的一门Tensorflow与深度学习实战的一门课程.该课程讲解了最新版本的Tensorflow中各种概念.操作和使用方法,并且给出 ...

  6. 《纯干货-6》Stanford University 2017年最新《Tensorflow与深度学习实战》视频课程分享

    分享一套Stanford University 在2017年1月份推出的一门Tensorflow与深度学习实战的一门课程.该课程讲解了最新版本的Tensorflow中各种概念.操作和使用方法,并且给出 ...

  7. tensorflow自编码器+softmax对凯斯西储大学轴承数据进行故障分类

    先放参考链接,感谢大神们带来的启发: 凯斯西储大学轴承数据故障分类(使用卷积神经网络) TensorFlow实现MNIST识别(softmax) 前情回顾: tensorflow 自编码器+softm ...

  8. [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98%+

    [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98.8%+ 我们在博文,使用CNN做Kaggle比赛手写数字识别准确率99%+,在此基础之 ...

  9. 《TensorFlow技术解析与实战》——3.1 PlayGround

    本节书摘来异步社区<TensorFlow技术解析与实战>一书中的第3章,第3.1节,作者:李嘉璇,更多章节内容可以访问云栖社区"异步社区"公众号查看. 3.1 Play ...

  10. 【深度学习】李宏毅2021/2022春深度学习课程笔记 - Auto Encoder 自编码器 + PyTorch实战

    文章目录 一.Basic Idea of Auto Encoder 1.1 Auto Encoder 结构 1.2 Auto Encoder 降维 1.3 Why Auto Encoder 1.4 D ...

最新文章

  1. php模拟post上传图片,php模拟post上传图片解决方法
  2. 为什么说 Serverless 引领云的下一个十年?
  3. 数据库监听。数据库一次notify,Activity多次接收
  4. C++的三种容器适配器
  5. 阿里巴巴DevOps实践指南 | 数字化转型下,DevOps的根本目标是什么?
  6. TCP协议下 Socket 与 ServerSocket
  7. 深度相机之TOF原理详解
  8. 今年颜宁在《自然》发表三篇论文仍归清华,网友:可惜以后不是了
  9. Linux| |对于UDP的学习
  10. 服务器并发性能报告,一般的服务器瞬时并发应该怎么样才算是合格呢?
  11. 机器人技术与人工智能有什么区别?
  12. delay() 方法
  13. 贝叶斯网络(数据预测)Python代码资源推荐
  14. 计算机文件自定义排序6,文件夹如何自定义排序
  15. c语言 close,C++ close()关闭文件方法详解
  16. 视觉识别真是火得发烫,依图科技宣布完成2亿美元融资
  17. 华为p4不是鸿蒙吗怎么又改为安卓_华为已将“基于安卓10”变成“兼容安卓10”,EMUI就是鸿蒙OS...
  18. html崩溃手机代码15,这12行代码分分钟让你电脑崩溃手机重启
  19. 百度的冬天:曾梦想成伟大公司 却为何陷入危机
  20. 育碧信条:AI 在手,天下我有

热门文章

  1. Python pip安装第三方库的国内镜像
  2. Linux之获取管理员权限的相关命令
  3. 参考别人博客,自己实现用idea运行eclipse项目--学生管理系统-
  4. 反思深度学习与传统计算机视觉的关系
  5. 基于OpenCV的实战:轮廓检测(附代码解析)
  6. 【OpenCV 4开发详解】均值滤波
  7. EP936E的IIC
  8. Python之操作RabbitMQ
  9. 将SQL for xml path('')中转义的字符正常显示
  10. 使用 Xbrowser4远程连接到 CentOS 7