自动编码器(autoencoder) 是神经网络的一种,该网络可以看作由两部分组成:一个编码器函数h = f(x) 和一个生成重构的解码器r = g(h)。传统上,自动编码器被用于降维或特征学习

自编码器原理示意图

编码器:将原始高维特征数据映射为低维度表征数据

解码器:将低纬度的压缩特征重构回原始数据

核心:输入特征等于输出特征

那么我们就会有一个疑问:压缩特征为什么小于输入特征?这里我们使用的是欠完备自动编码器:输入特征大于压缩特征.

下面我们手动编写一个自动编码器, 我们先来整理一下编写流程:

1. 获取数据

2. 模型搭建

3. 最优化设置

4. 模型训练

5. 数据3D可视化

然后呢, 我们就按照这个流程来编写我们的代码, 此代码的前提是有GPU, 代码才不会报错.

1. 获取数据

import numpy as np

import torch

import torchvision

import torchvision.transforms as tansforms

import torch.nn as nn

import torch.utils.data as data

# 数据集下载和导入

train_data = torchvision.datasets.MNIST(

root="./data/MNIST_data",

train=True,

transform=tansforms.ToTensor(),

download=True,

)

print(train_data.data.size()) # [60000,28,28]

print(train_data.targets.size()) # [60000]

# 显示图片

import matplotlib.pyplot as plt

# plt.imshow(train_data.data[10].numpy(),cmap="gray")

# plt.title("img_label:"+str(train_data.targets[10].numpy()))

# plt.show()

2. 模型搭建

class AutoEncoder(nn.Module):

def __init__(self):

super(AutoEncoder, self).__init__()

self.encoder = nn.Sequential(

nn.Linear(28 * 28, 3),

nn.Tanh(),

)

self.decoder = nn.Sequential(

nn.Linear(3, 28 * 28),

nn.ReLU(),

)

def forward(self, x):

encode = self.encoder(x)

decode = self.decoder(encode)

return encode, decode

3. 最优化设置

# 网络初始化

EPOCH = 10 # 训练周期

BATCH_SIZE = 64

LR = 0.005 # 学习率

autoencoder = AutoEncoder()

if torch.cuda.is_available():

autoencoder = autoencoder.cuda()

optim = torch.optim.Adam(autoencoder.parameters(), lr=LR)

loss_func = nn.MSELoss()

# 图像初始化,用于动态展示训练结果

N_test_img = 5

f, a = plt.subplots(2, N_test_img, figsize=(5, 2))

plt.ion() # 持续绘图

# 原始图像

view_data = train_data.data[:N_test_img].view(-1, 28 * 28).type(torch.FloatTensor)

for i in range(N_test_img):

a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap="gray")

a[0][i].set_xticks(())

a[0][i].set_yticks(())

上图为显示的效果图

4. 模型训练

# 训练网络

train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

for epoch in range(EPOCH):

for i, (x, y) in enumerate(train_loader):

if torch.cuda.is_available():

x = x.cuda()

b_x = x.view(-1, 28 * 28) # [batch_size,28*28]

encode, decode = autoencoder(b_x)

loss = loss_func(decode, b_x) # 均方误差损失函数

optim.zero_grad() # 清空梯度缓存

loss.backward()

optim.step()

if i % 500 == 0:

print("Epoch:", epoch, " train loss: ", loss.item())

# 绘制解码图像,第二行

_, decode = autoencoder(view_data.cuda())

for i in range(N_test_img):

a[1][i].clear()

a[1][i].imshow(np.reshape(decode.cpu().data.numpy()[i], (28, 28)), cmap="gray")

a[1][i].set_xticks(())

a[1][i].set_yticks(())

plt.draw()

plt.pause(0.05)

plt.ioff()

plt.show()

下图为显示的效果图:

5. 数据3D可视化

# 3D可视化

from mpl_toolkits.mplot3d import Axes3D

from matplotlib import cm

view_data = train_data.data[:200].view(-1,28*28).type(torch.FloatTensor)/255.

encode,_ = autoencoder(view_data.cuda())

fig =plt.figure()

ax = Axes3D(fig)

X = encode.cpu().data[:,0].numpy()

Y = encode.cpu().data[:,1].numpy()

Z = encode.cpu().data[:,2].numpy()

values = train_data.targets[:200].numpy()

for x,y,z,s in zip(X,Y,Z,values):

c = cm.rainbow(int(255*s/9))

ax.text(x,y,z,s,backgroundcolor=c)

ax.set_xlim(X.min(),X.max())

ax.set_ylim(Y.min(),Y.max())

ax.set_zlim(Z.min(),Z.max())

plt.show()

plt.pause(50)

下图为显示的效果图:

原文链接:https://blog.csdn.net/junjunzai123/article/details/107006799

自动编码器python_算法进阶(一)之自动编码器相关推荐

  1. 深度学习——无监督,自动编码器——尽管自动编码器与 PCA 很相似,but自动编码器既能表征线性变换,也能表征非线性变换;而 PCA 只能执行线性变换...

    自动编码器是一种有三层的神经网络:输入层.隐藏层(编码层)和解码层.该网络的目的是重构其输入,使其隐藏层学习到该输入的良好表征. 自动编码器神经网络是一种无监督机器学习算法,其应用了反向传播,可将目标 ...

  2. 阿里大佬总结的算法进阶指南,助你进大厂!

    大家好,我是林哥! 最近一个来自阿里的大佬总结了一份秋招算法进阶指南<LeetCode-Go>,全文一共有150多页,包含了所有常见的核心算法题目,助力大家在秋招末期拿到满意的Offer. ...

  3. 卧槽!阿里《算法进阶指南》火了,完整版 开放下载!

    最近一个来自阿里的大佬总结了一份秋招算法进阶指南<LeetCode-Go>,全文一共有150多页,包含了所有常见的核心算法题目,助力大家在秋招末期拿到满意的Offer. 以下是这份阿里秋招 ...

  4. Algorithm:【Algorithm算法进阶之路】之十大经典排序算法

    Algorithm:[Algorithm算法进阶之路]之十大经典排序算法 相关文章 Algorithm:[Algorithm算法进阶之路]之数据结构二十多种算法演示 Algorithm:[Algori ...

  5. Algorithm:【Algorithm算法进阶之路】之算法中的数学编程相关习题(时间速度、进制转换、排列组合、条件概率、斐波那契数列)

    Algorithm:[Algorithm算法进阶之路]之算法中的数学编程相关习题(时间速度.进制转换.排列组合.条件概率.斐波那契数列) 目录 时间速度 排列组合 进制转换 条件概率 斐波那契数列 时 ...

  6. Algorithm:【Algorithm算法进阶之路】之数据结构基础知识

    Algorithm:[Algorithm算法进阶之路]之数据结构基础知识 相关文章 Algorithm:[Algorithm算法进阶之路]之数据结构二十多种算法演示 Algorithm:[Algori ...

  7. Algorithm:【Algorithm算法进阶之路】之数据结构二十多种算法演示

    Algorithm:[Algorithm算法进阶之路]之数据结构二十多种算法演示 目录 一.数据结构算法 1.顺序表 2.链表 3.栈和队列 4.串的模式匹配 5.稀疏矩阵 6.广义表 7.二叉树 8 ...

  8. DL之NN/CNN:NN算法进阶优化(本地数据集50000张训练集图片),六种不同优化算法实现手写数字图片识别逐步提高99.6%准确率

    DL之NN/CNN:NN算法进阶优化(本地数据集50000张训练集图片),六种不同优化算法实现手写数字图片识别逐步提高99.6%准确率 目录 设计思路 设计代码 设计思路 设计代码 import mn ...

  9. 写给前端的算法进阶指南,我是如何两个月零基础刷200题 等推荐

    大家好,我是若川. 话不多说,这一次花了几小时精心为大家挑选了20余篇好文,供大家阅读学习.本文阅读技巧,先粗看标题,感兴趣可以都关注一波,一起共同进步. 前端从进阶到入院 作者ssh就职于字节跳动基 ...

最新文章

  1. 蓝牙模块引起电路干扰
  2. Microsoft Bot Framework 上手
  3. Qt 模式视图框架解读之委托
  4. 独家解读 | 滴滴机器学习平台架构演进之路
  5. NLP 训练及推理一体化工具(TurboNLPExp)
  6. 关于 $ Super $ $ 和 $ Sub $ $ 的用法
  7. Java访问静态常量_Java如何在Spring EL中访问静态方法或常量?
  8. mysql 创建分区索引吗_MySQL分区字段列有必要再单独建索引吗?
  9. 超高并发优化技能001--隔离
  10. Docker简单入门
  11. linux查询表空间脚本,通过Shell脚本查看数据库表空间使用情况
  12. 数的计数【Noip2001】
  13. 在软件开发中应用80:20原则
  14. 第四课曲面与曲线方程
  15. vue HTML内使用触底加载
  16. 阿里云OSS使用Java上传文件
  17. 柔性电子 压力传感器 strain-pressure sensor MoS2/graphene
  18. 少儿学单词软件android,推荐4款免费的自然拼读APP,孩子在家可以边玩边学!
  19. Gitee码云remote: error: File: , exceeds 100.00 MB 踩坑指南
  20. Python抓取妹子图,内含福利

热门文章

  1. php中db是指什么意思,phpmyadmin的作用是什么意思_数据库
  2. mysql 记录执行的sql_MySQL监控全部执行过的sql语句
  3. 【代码随想录】栈和队列
  4. 2021Autojs全网最全几十种小游戏和自阅合集 (含源码)
  5. linux 监控进程撤销,linux 系统监控脚本
  6. 求大数阶乘(存储在数组中)
  7. php 查找数组指定元素,PHP查找与搜索数组元素方法总结
  8. 【glib】Key-value文件解析器
  9. word中的表格如何修改文字方向
  10. 【BIM入门实战】Revit模型导入到第三方软件方法汇总