RNN 特别适合做序列类型的数据,那么 RNN 能不能想 CNN 一样用来做图像分类呢?下面我们用 mnist 手写字体的例子来展示一下如何用 RNN 做图像分类,但是这种方法并不是主流,这里我们只是作为举例。RNN 做图像分类

import sys
sys.path.append('..')import torch
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoaderfrom torchvision import transforms as tfs
from torchvision.datasets import MNIST# 定义数据
data_tf = tfs.Compose([tfs.ToTensor(),tfs.Normalize([0.5], [0.5]) # 标准化
])train_set = MNIST('./data', train=True, transform=data_tf)
test_set = MNIST('./data', train=False, transform=data_tf)train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)# 定义模型
class rnn_classify(nn.Module):def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):super(rnn_classify, self).__init__()self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers) # 使用两层 lstmself.classifier = nn.Linear(hidden_feature, num_class) # 将最后一个 rnn 的输出使用全连接得到最后的分类结果def forward(self, x):'''x 大小为 (batch, 1, 28, 28),所以我们需要将其转换成 RNN 的输入形式,即 (28, batch, 28)'''x = x.squeeze() # 去掉 (batch, 1, 28, 28) 中的 1,变成 (batch, 28, 28)x = x.permute(2, 0, 1) # 将最后一维放到第一维,变成 (28, batch, 28)out, _ = self.rnn(x) # 使用默认的隐藏状态,得到的 out 是 (28, batch, hidden_feature)out = out[-1, :, :] # 取序列中的最后一个,大小是 (batch, hidden_feature)out = self.classifier(out) # 得到分类结果return outnet = rnn_classify() criterion = nn.CrossEntropyLoss()  optimzier = torch.optim.Adadelta(net.parameters(), 1e-1)# 开始训练
from utils import train
train(net, train_data, test_data, 10, optimzier, criterion)

PyTorch 深度学习:38分钟快速入门——RNN 做图像分类相关推荐

  1. PyTorch 深度学习:37分钟快速入门——FCN 做语义分割

    语义分割是一种像素级别的处理图像方式,对比于目标检测其更加精确,能够自动从图像中划分出对象区域并识别对象区域中的类别 在 2015 年 CVPR 的一篇论文 Fully Convolutional N ...

  2. PyTorch 深度学习:36分钟快速入门——GAN

    自动编码器和变分自动编码器,不管是哪一个,都是通过计算生成图像和输入图像在每个像素点的误差来生成 loss,这一点是特别不好的,因为不同的像素点可能造成不同的视觉结果,但是可能他们的 loss 是相同 ...

  3. PyTorch 深度学习:34分钟快速入门——自动编码器

    自动编码器最开始是作为一种数据压缩方法,同时还可以在卷积网络中进行逐层预训练,但是随后更多结构复杂的网络,比如 resnet 的出现使得我们能够训练任意深度的网络,自动编码器就不再使用在这个方面,下面 ...

  4. PyTorch 深度学习:32分钟快速入门——ResNet

    ResNet 当大家还在惊叹 GoogLeNet 的 inception 结构的时候,微软亚洲研究院的研究员已经在设计更深但结构更加简单的网络 ResNet,并且凭借这个网络子在 2015 年 Ima ...

  5. PyTorch 深度学习:33分钟快速入门——VGG

    CIFAR 10¶ cifar 10 这个数据集一共有 50000 张训练集,10000 张测试集,两个数据集里面的图片都是 png 彩色图片,图片大小是 32 x 32 x 3,一共是 10 分类问 ...

  6. PyTorch 深度学习:32分钟快速入门——DenseNet

    DenseNet¶ 因为 ResNet 提出了跨层链接的思想,这直接影响了随后出现的卷积网络架构,其中最有名的就是 cvpr 2017 的 best paper,DenseNet. DenseNet ...

  7. PyTorch 深度学习:30分钟快速入门

    卷积¶ 卷积在 pytorch 中有两种方式,一种是 torch.nn.Conv2d(),一种是 torch.nn.functional.conv2d(),这两种形式本质都是使用一个卷积操作 这两种形 ...

  8. PyTorch 深度学习:35分钟快速入门——变分自动编码器

    变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成. 回忆一下,自动编码器有个问题,就是并不能任意生成图片,因为我们没有办法自己去构造隐藏向量,需要通过一张图片输入编 ...

  9. PyTorch 深度学习:31分钟快速入门——Batch Normalization

    Batch Normalization¶ 前面在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好.但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相 ...

最新文章

  1. Windows 下noinstall方式安装 mysql-5.7.5-m15-winx64
  2. python一个月能挣多少钱-零基础学python,我可以让你一个月上手做项目!
  3. Matlab实现线性回归和逻辑回归: Linear Regression Logistic Regression
  4. 重磅!《Android 全埋点技术白皮书》开源所有项目源码!
  5. C++字节序反转的实现算法(附完整源码)
  6. Docker image Introduce
  7. CSS学习笔记(更新中...)
  8. CV Papers|计算机视觉论文推荐周报20200504期
  9. 计算机网络操作系统课件,计算机网络操作系统课件(张浩军版).ppt
  10. 2021-2025年中国超声波管道监测系统行业市场供需与战略研究报告
  11. Java 调用http接口
  12. 查找出现次数 oracle,ORACLE计算某个列中出现次数最多的值
  13. 每日算法系列【LeetCode 153】寻找旋转排序数组中的最小值
  14. uds 诊断协议的bootloader开发
  15. gif动图怎么制作?手把手教你视频转gif动图
  16. edison\arduino-1.5.3-Intel.1.0.3闪退
  17. Android 应用换肤功能(白天黑夜主题切换)
  18. c语言子函数作用是什么意思,C语言编译器中常见的函数用法以及作用详解
  19. 【Powerdesigner】DFD分层数据流图的画法
  20. mac磁盘工具中磁盘显示灰色

热门文章

  1. ChaiNext:主流代币回调
  2. FAL风控培训「六大场景下,模型分数如何应用?」
  3. 轻松搞定 Shell 玩转 HiveSQL
  4. C# 列表中查找大小比较
  5. visio转换成eps
  6. python encode和decode函数说明
  7. SQLServer查询锁表
  8. 使用bat快速创建cocos2d-x模板
  9. 【Vegas原创】Can't connect to X11 window server using ':0.0' 解决方法
  10. 第6章 面向方面编程