import tensorflow as  tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as keras
from tensorflow.keras import datasets,layers,optimizers,losses,Sequentialbatchsz=128
lr=0.01
h_dim=20
#加载Fashion MNIST图片数据集
(x_train,y_train),(x_test,y_test)=datasets.fashion_mnist.load_data()
#归一化
x_train,x_test=x_train.astype(np.float32)/255.,x_test.astype(np.float32)/255.
#只需要通过图片数据即可构建数据集对象,不需要标签
train_db=tf.data.Dataset.from_tensor_slices(x_train)
test_db=tf.data.Dataset.from_tensor_slices(x_test)
#构建测试集对象
train_db=train_db.shuffle(batchsz*5).batch(batchsz)
test_db=test_db.shuffle(batchsz*5).batch(batchsz)class AE(tf.keras.Model):def __init__(self):super(AE,self).__init__()#创建Enconder网络,实现在自编码器类的初始化函数中self.encoder=Sequential([layers.Dense(256,activation='relu'),layers.Dense(128,activation='relu'),layers.Dense(h_dim)])#创建Deconder网络self.decoder=Sequential([layers.Dense(128,activation='relu'),layers.Dense(256,activation='relu'),layers.Dense(784)])def call(self,inputs,training=None):#前向传播函数#编码获得隐藏向量h,[b,784]->[b,20]h=self.encoder(inputs)#解码获得重建图片,[b,20]->[b,784]x_hat=self.decoder(h)return x_hat
#创建网络对象
model=AE()
#指定输入大小
model.build(input_shape=(4,784))
#打印网络信息
model.summary()
#创建优化器,并放置学习率
optimizer=optimizers.Adam(learning_rate=lr)loss=[]
for epoch in range(10):for step,x in enumerate(train_db):x=tf.reshape(x,[-1,784])with tf.GradientTape() as tape:#梯度记录器#前向计算获得重建的图片x_rec_logits=model(x)#计算重建图片与输入之间的损失函数rec_loss=tf.nn.sigmoid_cross_entropy_with_logits(labels=x,logits=x_rec_logits)rec_loss=tf.reduce_mean(rec_loss)#自动求导,包含两个子网络的梯度grades=tape.gradient(rec_loss,model.trainable_variables)#自动更新,同时更新两个子网络optimizer.apply_gradients(zip(grades,model.trainable_variables))if step%100==0:loss.append(float(rec_loss))print("epoch:{},step:{},rec_loss:{}".format(epoch ,step,rec_loss))
#画出训练误差图像
plt.figure()
x=[i*5 for i in range(len(loss))]
plt.plot(x,loss,color='C1',marker='s',label='训练')
plt.xlabel('step')
plt.ylabel('loss')
plt.legend()
plt.show()


Fashion MNIST自编码器网络实战相关推荐

  1. 深度学习之自编码器(2)Fashion MNIST图片重建实战

    深度学习之自编码器(2)Fashion MNIST图片重建实战 1. Fashion MNIST数据集 2. 编码器 3. 解码器 4. 自编码器 5. 网络训练 6. 图片重建 完整代码  自编码器 ...

  2. tensorflow卷积神经网络实战:Fashion Mnist 图像分类与人马分类

    卷积神经网络实战:Fashion Mnist 图像分类与人马分类 一.FashionMnist的卷积神经网络模型 1.卷积VS全连接 2.卷积网络结构 3.卷积模型结构 1)Output Shape ...

  3. Pytorch初学实战(一):基于的CNN的Fashion MNIST图像分类

    1.引言 1.1.什么是Pytorch PyTorch是一个开源的Python机器学习库. 1.2.什么是CNN 卷积神经网络(Convolutional Neural Networks)是一种深度学 ...

  4. Fashion MNIST图片重建实战(AE)

    Fashion MNIST 是一个定位在比 MNIST 图片识别问题稍复杂的数据集,它的设定与MNIST 几乎完全一样,包含了 10 类不同类型的衣服.鞋子.包等灰度图片,图片大小为28 × 28,共 ...

  5. TensorFlow中的Fashion MNIST图像识别实战

    1.导入相应的库: 关于Fashion MNIST数据集的介绍:看这位博主: https://blog.csdn.net/qq_28869927/article/details/85079808 im ...

  6. fashionmnist数据集_Keras实现Fashion MNIST数据集分类

    本篇用keras构建人工神经网路(ANN)和卷积神经网络(CNN)实现Fashion MNIST 数据集单个物品分类,并从模型预测的准确性方面对ANN和CNN进行简单比较. Fashion MNIST ...

  7. 【人工智能项目】Fashion Mnist识别实验

    [人工智能项目]Fashion Mnist识别实验 本次主要通过四个方法对fashion mnist进行识别实验,主要为词袋模型.hog特征.mlp多层感知器和cnn卷积神经网络.那么话不多说,走起来 ...

  8. Fashion MNIST数据集的处理——“...-idx3-ubyte”文件解析

    Fashion MNIST MNIST数据集可能是计算机视觉所接触的第一个图片数据集.而 Fashion MNIST 是在遵循 MNIST 的格式和大小的基础上,提升了一定的难度,在比较算法的性能时可 ...

  9. 【深度学习】李宏毅2021/2022春深度学习课程笔记 - Auto Encoder 自编码器 + PyTorch实战

    文章目录 一.Basic Idea of Auto Encoder 1.1 Auto Encoder 结构 1.2 Auto Encoder 降维 1.3 Why Auto Encoder 1.4 D ...

最新文章

  1. 将txt文件和excel文件导入SQL2000数据库
  2. android 按下缩小效果松开恢复_iPhone XS/XS Max如何强制重启?如何进入恢复模式或DFU模式?...
  3. 实验五 网络编程与安全-----实验报告
  4. 吃亏受苦、前途未卜,Nature调查显示博士生三分之一可能抑郁
  5. 怎么用nuget程序包管理器安装jquery_Nuget服务器
  6. Linux系统基本操作(一)—光盘挂载/卸载
  7. hdu 1233还是畅通工程 最小生成树(入门题)prim算法
  8. 相片审核处理工具步骤_相片
  9. python常用模块之os
  10. Windows 的数据恢复工具
  11. 计算机分盘介质受写入保护,硬盘介质受写入保护怎么办
  12. 别人都是笑起来很好看,但是你却不一样,你是看起来很好笑。
  13. 评论:后MWC2012的一些感悟
  14. 单元格颜色公式之明细数据项隔行底纹
  15. Fantasy of a Summation LightOJ - 1213
  16. mongodb如何记录慢查询
  17. 论文领读|基于 VQVAE 的长文本生成
  18. 如何修改美食大战老鼠服务器,《美食大战老鼠》联运区组停止运营公告
  19. 管理员身份获得 SYSTEM 权限的四种方法
  20. 弱符号与弱引用 -> 程序员的自我修养 第3,4章笔记

热门文章

  1. Linux内核网络栈1.2.13-网卡设备的初始化流程
  2. F5刷新以及计算几秒钟的代码
  3. OpenCV4 部署DeepLabv3+模型
  4. 【OpenCV 4开发详解】保存和读取XML和YMAL文件
  5. 最大流学习笔记(1)
  6. Python 23天 序列化
  7. python基础(四)集合
  8. SpringBoot操作使用Spring-Data-Jpa
  9. 从Netflix的Hystrix框架理解服务熔断和服务降级
  10. 在js中为图片的src赋值时,src的值不能在开头用 破浪号~