Pytorch实现自编码器
原文地址
分类目录——Pytorch
- 什么是编码器
有一中数据压缩的、降维的意思
举个例子来说明,同一张图片,高清的和标清的我们都能识别出图片中的内容(这里就考虑识别这一个需求,其他需求暂不考虑),这是因为即使是标清的图片,也保留了进行识别的关键特征
。但是高清的在无论是在保存,还是在提取上都会更费工夫。深度学习处理起来亦是如此,深度学习会包含很多层,每层节点也很多,这种情况下,如果输入数据的规模太大,神经网络也很难训练出结果。那么,能在保留关键特征的基础上对数据尽心降维,就是一项一劳永逸的活动。
这个编码器要怎么用呢
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1nZbFk6U-1582679206189)(https://morvanzhou.github.io/static/results/ML-intro/auto3.png)]
图片引自 什么是自编码 (Autoencoder)
编码器的构造跟自己的应用(或者分类,或者回归)上两套体系。编码器也是一个完整的训练流程,虽说叫编码器,其实其内部包括编码(上图中的压缩)和解码(上图中的解压)两部分,编码用来降维,解码用来将维度回复,通过维度恢复的数据(上图中的黑色X)与原始数据(上图的白色X)的误差来训练编码器参数,训练完成后编码部分将能压缩到原始数据的关键特征,极大地加速训练过程。
另外我觉得,翻译是一个很好的例子,自己有中思路可不可以做一种压缩(编码)一种万国语,存放在计算机中,计算机能识别的;甚至可以跨越表达方式,比如‘狗’、‘dog’ 另外还有一张狗的图片。他们在计算机中的表现形式是一样的,通过不同的模型可以翻译成‘狗’、‘dog’ 和狗的图片。
下面用一个例子来说明
这个程序的数据是手写数字识别的图片,分辨率为28*28,通过编码器将28*28维度的像素维度降维到3维;然后用3维数据在三维坐标平面内进行了可视化;最后用svm就编码之后的3维数据进行分类,因为压缩之后只有3个维度,为了节约时间只用了1000个训练数据,所以最终的准确率并没有很高。
导入支持包与设置超参数
import torch import torch.nn as nn import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt from matplotlib import cm from mpl_toolkits.mplot3d import Axes3D import os import numpy as np from sklearn import svm from sklearn.model_selection import GridSearchCV# 超参数 EPOCH = 10 BATCH_SIZE = 64 LR = 0.005 if os.path.exists('mnist/'): # 如果已经存在(下载)了就不用下载了DOWNLOAD_MNIST = False else:DOWNLOAD_MNIST = True # 下过数据的话, 就可以设置成 False N_TEST_IMG = 5 # 到时候显示 5张图片看效果, 如上图一
获得手写数字图片数据
####################################### 获取手写数字图片数据 train_data = torchvision.datasets.MNIST(root='./mnist/',train=True, # this is training datatransform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]download=DOWNLOAD_MNIST, # download it if you don't have it )test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28) train_loader = Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True,num_workers=0 )
构造编码器
################################# 构造编码器 class AutoEncoder(nn.Module):def __init__(self):super(AutoEncoder, self).__init__()# 编码网络self.encoder = nn.Sequential(nn.Linear(28*28, 128),nn.Tanh(),nn.Linear(128, 64),nn.Tanh(),nn.Linear(64, 12),nn.Tanh(),nn.Linear(12, 3), # 压缩成3个特征, 是为了寿面好进行 3D 图像可视化# 当然也可以压缩到5个特征,选其中的三个来作图)# 解码网络self.decoder = nn.Sequential(nn.Linear(3, 12),nn.Tanh(),nn.Linear(12, 64),nn.Tanh(),nn.Linear(64, 128),nn.Tanh(),nn.Linear(128, 28*28),nn.Sigmoid(), # 激励函数让输出值在 (0, 1))def forward(self, x):encoded = self.encoder(x)decoded = self.decoder(encoded)return encoded, decoded # 定义一个编码器对象 autoencoder = AutoEncoder()
训练编码器
############################## 训练编码器 optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR) loss_func = nn.MSELoss()for epoch in range(EPOCH):for step, (x, b_label) in enumerate(train_loader):b_x = x.view(-1, 28*28) # batch x, shape (batch, 28*28)# b_y跟b_x是一样encoded_x, decoded_x = autoencoder(b_x)loss = loss_func(decoded_x, b_x) # 这里如果写成b_x会更容易裂解optimizer.zero_grad() # clear gradients for this training steploss.backward() # backpropagation, compute gradientsoptimizer.step() # apply gradients
可视化
########################### 画图的部分 # 取200个数据来作图 view_data = train_data.train_data[:200].view(-1, 28 * 28).type(torch.FloatTensor) / 255. encoded_data, _ = autoencoder(view_data) # 提取压缩的特征值 fig = plt.figure(2) ax = Axes3D(fig) # 3D 图 # x, y, z 的数据值 X = encoded_data.data[:, 0].numpy() Y = encoded_data.data[:, 1].numpy() Z = encoded_data.data[:, 2].numpy() values = train_data.train_labels[:200].numpy() # 标签值 for x, y, z, s in zip(X, Y, Z, values):c = cm.rainbow(int(255 * s / 9)) # 上色ax.text(x, y, z, s, backgroundcolor=c) # 标位子 ax.set_xlim(X.min(), X.max()) ax.set_ylim(Y.min(), Y.max()) ax.set_zlim(Z.min(), Z.max()) plt.show() # 注意这里进行了plt.show(),程序会停在这里,需要把图片关闭之后下面的程序才能进行,也可以调换一下跟下面svm分类部分替换位置
注意这里进行了plt.show(),程序会停在这里,需要把图片关闭之后下面的程序才能进行,也可以调换一下跟下面svm分类部分替换位置
用SVM对编码(压缩)后的数据进行数字识别
################################### 用SVM分类 # 取1000个训练数据来训练svm svm_train = train_data.train_data[:1000].view(-1, 28 * 28).type(torch.FloatTensor) / 255. s_t_x_afterencoder = autoencoder(svm_train)[0].data.numpy() print(s_t_x_afterencoder.shape()) s_t_y = train_data.train_labels[:1000].numpy() # 标签值 print(s_t_y.shape()) # 取1000个训练数据来测试 svm_test = test_data.test_data[:1000].view(-1, 28 * 28).type(torch.FloatTensor) / 255. s_te_x_afterencoder = autoencoder(svm_test)[0].data.numpy() s_te_y = test_data.test_labels[:1000].numpy() # 标签值c_can = np.logspace(-3, 2, 10) gamma_can = np.logspace(-3, 2, 10)model = svm.SVC(kernel='rbf', decision_function_shape='ovr', random_state=1) clf = GridSearchCV(model, param_grid={'C': c_can, 'gamma': gamma_can}, cv=5, n_jobs=5) clf.fit(s_t_x_afterencoder, s_t_y)print('测试集准确率:\t', clf.score(s_te_x_afterencoder, s_te_y)) # 因为压缩到了三个特征,准确率并不是很高 # 测试集准确率: 0.764
参考文献
什么是自编码 (Autoencoder)
AutoEncoder (自编码/非监督学习)
分类目录——Matplotlib
Pytorch实现自编码器相关推荐
- Pytorch:基于转置卷积解码的卷积自编码网络
Pytorch: 图像自编码器-卷积自编码网络(转置卷积解码)和图像去噪 Copyright: Jingmin Wei, Pattern Recognition and Intelligent Sys ...
- 【深度学习】资源:最全的 Pytorch 资源大全
Pytorch资源大全 目录 Pytorch及相关库 NLP和语音处理: 计算机视觉: 概率/生成库: 其他库: 教程和示例 论文的实现 Pytorch其他 Pytorch及相关库 pytorch:P ...
- 跟我一起学PyTorch-07:嵌入与表征学习
前面介绍了深度神经网络和卷积神经网络,这些神经网络有个特点:输入的向量越大,训练得到的模型越大.但是,拥有大量参数模型的代价是昂贵的,它需要大量的数据进行训练,否则由于缺少足够的训练数据,就可能出现过 ...
- 深入浅出Transformer(一)
引言 Transformer的重要性不用多说了吧,NLP现在最火了两个模型--BERT和GPT,一个是基于它的编码器实现的,另一个是基于它的解码器实现的. 凡是我不能创造的,我都不能理解. 为了更好的 ...
- Facebook 开源增强版 LASER 库:可实现 93 种语言的零样本迁移...
雷锋网 AI 科技评论按:去年 12 月份,Facebook 在论文中提出了一种可学习 93 种语言的联合多语言句子表示的架构,该架构仅使用一个编码器,就可以在不做任何修改的情况下实现跨语言迁移,为自 ...
- Facebook 开源增强版 LASER 库:可实现 93 种语言的零样本迁移
雷锋网 AI 科技评论按:去年 12 月份,Facebook 在论文中提出了一种可学习 93 种语言的联合多语言句子表示的架构,该架构仅使用一个编码器,就可以在不做任何修改的情况下实现跨语言迁移,为自 ...
- 【13】变分自编码器(VAE)的原理介绍与pytorch实现
文章目录 1.VAE的设计思路 2.VAE的模型架构 3.VAE的作用原理 4.VAE的Pytorch实现 1)参考代码 2)训练结果展示 3)生成结果展示 5.实现VAE中出现的问题 1.VAE的设 ...
- Pytorch基础-07-自动编码器
自动编码器(AutoEncoder)是一种可以进行无监督学习的神经网络模型.一般而言,一个完整的自动编码器主要由两部分组成,分别是用于核心特征提取的编码部分和可以实现数据重构的解码部分. 1 自动编码 ...
- 【实战】(以色列·特拉维夫大学)将 E4E 成功移植到Windows 10: StyleGAN2图像处理编码器的设计,支持Pytorch Cuda/C++ Extension
StyleClip项目支持写一段文字,指导StyleGAN2生成具备指定特征的图像.但这个项目往往需要一些中间数据,比如:在Pytorch环境下对应于输入原图的StyleGAN2反演(Inversio ...
最新文章
- 魔改Attention大集合
- 以下选项不是python打开方式的是-以下选项中,不是Python对文件的打开模式的是...
- linux mate桌面主题下载_7款Linux桌面环境推荐,你值得拥有!
- ref out param 区别
- View.onMeasured的默认实现 (onMeasure必须调setMeasuredDimension)
- 第 11 章 Paragraphs
- nagios客户端nrped服务方式启动脚本
- android 开发 gradle 自己会容易混淆的东西
- [图解教程]Eclipse不可不知的用法之一:自动生成Getter、Setter和构造方法
- 矩阵运算_SLAM中用到的矩阵计算_基本公式及知识汇总
- 什么是OFD格式文档?一文教你读懂OFD格式文档
- 【朝花夕拾】【编程基础】一 存储单位
- win10安装CUDA和cuDNN
- redis通配符批量删除keys——del
- 狂神学习系列14:SpringCloud
- 淘宝店铺怎么上第四层级?有哪些技巧?
- vue vue-seamless-scroll 无缝滚动依赖
- 程序员的理想桌面装备,少不了一台2K高清的专业显示器
- 我远行,故我在——海陀行点滴感受
- 数据结构知识点思维导图(绪论)