本节主要介绍自编码的相关内容。

区别于以前内容的是,自编码过程并不需要标签。只需要数据集即可。


import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np# 参数
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005
DOWNLOAD_MNIST = True#true表示已经下载,
N_TEST_IMG = 5## train_data:下载的所有的数据
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,                                     transform=torchvision.transforms.ToTensor(),    download=DOWNLOAD_MNIST,
)#train_loader数据分份,每组BATCH_SIZE个
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)#本节不用定义test_data,因为只需要生成的data与train——data对比class AutoEncoder(nn.Module):def __init__(self):super(AutoEncoder, self).__init__()self.encoder = nn.Sequential(nn.Linear(28*28, 128),nn.Tanh(),nn.Linear(128, 64),nn.Tanh(),nn.Linear(64, 12),nn.Tanh(),nn.Linear(12, 3),   # 压缩到三个特征)self.decoder = nn.Sequential(nn.Linear(3, 12),nn.Tanh(),nn.Linear(12, 64),nn.Tanh(),nn.Linear(64, 128),nn.Tanh(),nn.Linear(128, 28*28),nn.Sigmoid(), #将输出值压缩到0-1的范围      )def forward(self, x):encoded = self.encoder(x)decoded = self.decoder(encoded)return encoded, decodedautoencoder = AutoEncoder()optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   #  view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.
for i in range(N_TEST_IMG):a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray'); a[0][i].set_xticks(()); a[0][i].set_yticks(())for epoch in range(EPOCH):for step, (x, b_label) in enumerate(train_loader):b_x = x.view(-1, 28*28)   b_y = x.view(-1, 28*28)   encoded, decoded = autoencoder(b_x)loss = loss_func(decoded, b_y)      optimizer.zero_grad()               loss.backward()                     optimizer.step()                    if step % 100 == 0:print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy())_, decoded_data = autoencoder(view_data)for i in range(N_TEST_IMG):a[1][i].clear()a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')a[1][i].set_xticks(()); a[1][i].set_yticks(())plt.draw(); plt.pause(0.05)plt.ioff()
plt.show()view_data = train_data.train_data[:200].view(-1, 28*28).type(torch.FloatTensor)/255.
encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2); ax = Axes3D(fig)
X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):c = cm.rainbow(int(255*s/9)); ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max()); ax.set_ylim(Y.min(), Y.max()); ax.set_zlim(Z.min(), Z.max())
plt.show()

PyTorch学习(十一)encoded,decoded相关推荐

  1. PyTorch框架学习十一——网络层权值初始化

    PyTorch框架学习十一--网络层权值初始化 一.均匀分布初始化 二.正态分布初始化 三.常数初始化 四.Xavier 均匀分布初始化 五.Xavier正态分布初始化 六.kaiming均匀分布初始 ...

  2. 莫烦---Pytorch学习

    今天翻翻资料,发现有些地方的说明不太到位,修改过来了. Will Yip 2020.7.29 莫烦大神Pytorch -->> 学习视频地址 2020年开年就遇上疫情,还不能上学,有够难受 ...

  3. 莫烦pytorch学习笔记5

    莫烦pytorch学习笔记5 1 自编码器 2代码实现 1 自编码器 自编码,又称自编码器(autoencoder),是神经网络的一种,经过训练后能尝试将输入复制到输出.自编码器(autoencode ...

  4. PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call

    您的位置 首页 PyTorch 学习笔记系列 PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 发布: 2017年8月4日 7,195阅读 ...

  5. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  6. pytorch 学习中安装的包

    记录pytorch学习遇到的包 1.ImportError: cannot import name 'PILLOW_VERSION' torchvision 模块内import pillow的时候发现 ...

  7. pytorch学习笔记(二):gradien

    pytorch学习笔记(二):gradient 2017年01月21日 11:15:45 阅读数:17030

  8. PyTorch学习笔记(二)——回归

    PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...

  9. 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)

    初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...

最新文章

  1. pandas把dataframe的数据列转化为索引列实战:单列转化为索引、多列转化为复合索引
  2. 腾讯云推出竞价实例 云服务器开销最高下降90%
  3. 假设写一段代码引导PC开机这段代码是 ? Here is a tiny quot;OSquot; :-D
  4. python教程书籍推荐-买Python入门书籍,我推荐这一本
  5. 编写图形界面程序,接受用户输入的5个浮点数据和一个文件目录名,将这五个数据保存在该文件中,再从文件中读取出来并且进行从大到小排序,然后再一次追加保存在该文件中。
  6. 怎样用jQuery拿到select中被选中的option的值
  7. Linux 权限管理: 权限的概念、权限管理、文件访问权限的设置、 粘滞位
  8. 求一个有限长度字符串 最长的有序可重复字符串长度
  9. java mongodb 删除字段类型_Mongodb基本数据类型、常用命令之增加、更新、删除
  10. java long.max_value,Long + Long不大于Long.MAX_VALUE
  11. DeepEye:一个基于深度学习的程序化交易识别与分类方法
  12. 信息熵,条件熵,相对熵,交叉熵
  13. 11gR2conceptes Memory Architecture中文翻译
  14. 【To Understand !!! DP or 递归】LeetCode 87. Scramble String
  15. iOS ReactiveCocoa 最全常用API整理
  16. html相册魔方代码,魔方相册制作方法现成的魔方相册代码:
  17. 刽子手c语言,古代神秘职业:刽子手的祖师爷
  18. 【数据挖掘】从“文本”到“知识”:信息抽取(Information Extraction)
  19. UE4 坐标系坐标轴旋转轴
  20. 从视频中截取图像opencv python

热门文章

  1. 如何与新同事共同成长?
  2. 基于51单片机的校园教室打铃系统
  3. 数据结构之SWUSTOJ954: 单链表的链接
  4. 用 StarRocks on ES 实现 分词
  5. 数码管显示“0~F”的共阳共阴数码管编码表
  6. BIDI单纤双向光模块
  7. Win7 突然没声音 无法播放测试音调
  8. Spring 框架之九阴真经
  9. Redis——Redis入门和一些笔记
  10. post /login.php http/1.1,路由器登录入口:http://192.168.10.1