文章目录

  • 题目
  • 问题
    • CrossEntropy
    • 'bool' object is not iterable
    • 常见函数作用
  • 代码
  • 运行结果
  • 总结

题目

'''
Description: rnn--重新温习实现MNIST手写体识别
Autor: 365JHWZGo
Date: 2021-12-15 17:24:19
LastEditors: 365JHWZGo
LastEditTime: 2021-12-15 20:15:39
'''

问题

上一次写rnn手写体识别时,我用了batch_first=True,这次没有使用,重新理解了rnn中的维度变化

CrossEntropy

公式:torch.nn.CrossEntropyLoss()

在本例题中,我写的是

loss = loss_func(pre_out,label)

根据上述参数的要求
pre_out的size=(BATCH_SIZE,10),10也是类别数
label的size=(BATCH_SIZE,)

‘bool’ object is not iterable

这个问题出现在

accuracy = sum(pre_target == test_label.data.numpy())/2000.

这表示pre_target和test_label.data.numpy()的维度不统一,需要检查一下其维度大小,应该为(2000,)

一般出错在pre_target没有降维,没降之前的维度为(2000,1),直接用squeeze()降维

常见函数作用

函数名 作用
squeeze 移除数组中维度为1的维度
max output = torch.max(input, dim)
input是softmax函数输出的一个tensor dim是max函数索引的维度0/10是每列的最大值,1是每行的最大值
函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引
softmax dim:指明维度,dim=0表示按列计算;dim=1表示按行计算
torch将结果归一化
view 将维度展平

代码

import os
import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
import torch.autograd.variable as Variabletorch.manual_seed(1)# 超参数
BATCH_SIZE = 64
EPOCH = 1
LR = 0.01
DOWNLOAD_MNIST = False
TIME_STEP = 28
INPUT_SIZE = 28
HIDDEN_SIZE = 64# 判断MNIST数据集是否已经下载
if not os.path.exists('./mnist') or not os.listdir('./mnist'):DOWNLOAD_MNIST = True# 得到train_dataset
train_dataset = torchvision.datasets.MNIST(root='./mnist',train=True,transform=torchvision.transforms.ToTensor(),download=DOWNLOAD_MNIST
)# 得到test_dataset
test_dataset = torchvision.datasets.MNIST(root='./mnist',train=False,transform=torchvision.transforms.ToTensor()
)# 得到train_loader
train_loader = Data.DataLoader(dataset=train_dataset,shuffle=True,num_workers=2,batch_size=BATCH_SIZE
)# 得到test_data
test_data = test_dataset.test_data[:2000]/255.
# 得到test_label
test_label = test_dataset.test_labels[:2000]# 创建RNN类class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()# lstm=(INPUT_SIZE,HIDDEN_SIZE, NUM_LAYER)self.lstm = nn.LSTM(input_size=INPUT_SIZE,hidden_size=HIDDEN_SIZE,num_layers=1)self.linear = nn.Linear(HIDDEN_SIZE, 10)def forward(self, x):# r_output=(TIME_STEP,BATCH_SIZE,HIDDEN_SIZE)# hn=(NUM_LAYER,BATCH_SIZE,HIDDEN_SIZE)# cn=(NUM_LAYER,BATCH_SIZE,HIDDEN_SIZE)r_output, (hn, cn) = self.lstm(x, None)clsify0to9 = self.linear(r_output[-1])return clsify0to9if __name__ == '__main__':# 创建RNN实例rnn = RNN()# 创建优化器optim = torch.optim.Adam(rnn.parameters(), lr=LR)# 创建损失函数loss_func = nn.CrossEntropyLoss()# 训练for epoch in range(EPOCH):for i,(data,label) in enumerate(train_loader):# data=(BATCH_SIZE,CHANNELS,TIME_STEP,INPUT_SIZE)# label=(BATCH_SIZE)data = Variable(data.view(-1,TIME_STEP,INPUT_SIZE).transpose(0,1))label = Variable(label)# 使用rnn预测# rnn的输入维度为(TIME_STEP,BATCH_SIZE,INPUT_SIZE),所以需要展平为三个维度,并且第一个和第二个维度需要转变# rnn的输出维度为(BATCH_SIZE,10)pre_out = rnn(data)# 计算损失loss = loss_func(pre_out,label)# 优化optim.zero_grad()loss.backward()optim.step()if i % 100 == 0:# pre_test_label=(2000,10)# test_data.shape=[2000, 28, 28]# rnn的输入维度为(TIME_STEP,BATCH_SIZE,INPUT_SIZE),所以第一个和第二个维度需要转变pre_test_label = rnn(test_data.transpose(0,1))# input是softmax函数输出的一个tensor# dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值# softmax dimpre_target = torch.max(torch.softmax(pre_test_label,1),dim=1)[1].data.numpy().squeeze()# pre_target需要降维accuracy = sum(pre_target == test_label.data.numpy())/2000.print(f'epoch:{epoch} accuracy:{accuracy}')

运行结果

总结

话说温故而知新,可以为师矣。
话真不假,我今天重学之后,受益匪浅,希望接下来几天,将注意力机制融入其中。

rnn--重新温习实现MNIST手写体识别相关推荐

  1. TensorRT(3)-C++ API使用:mnist手写体识别

    本节将介绍如何使用tensorRT C++ API 进行网络模型创建. 1 使用C++ API 进行 tensorRT 模型创建 还是通过 tensorRT官方给的一个例程来学习. 还是mnist手写 ...

  2. python模拟手写笔迹_pytorch实现MNIST手写体识别

    本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下 实验环境 pytorch 1.4 Windows 10 python 3.7 cuda 10.1(我笔记 ...

  3. TensorRT(2)-基本使用:mnist手写体识别

    结合 tensorRT官方给出的一个例程,介绍tensorRT的使用. 这个例程是mnist手写体识别.例程位于目录: /usr/src/tensorrt/samples/sampleMNIST 文件 ...

  4. R︱Softmax Regression建模 (MNIST 手写体识别和文档多分类应用)

    本文转载自经管之家论坛, R语言中的Softmax Regression建模 (MNIST 手写体识别和文档多分类应用) R中的softmaxreg包,发自2016-09-09,链接:https:// ...

  5. 【人工智能项目】MNIST手写体识别实验及分析

    [人工智能项目]MNIST数据集实验报告 这是之前接的小作业,现在分享出来,给大家以学习!!! [人工智能项目]MNIST手写体识别实验及分析 1.实验内容简述 1.1 实验环境 本实验采用的软硬件实 ...

  6. 2021年人工神经网络第四次作业 - 第二题MNIST手写体识别

    简 介: ※MNIST数据集合是深度学习基础训练数据集合.改数据集合可以使用稠密前馈神经网络训练,也可以使用CNN.本文采用了单隐层BP网络和LeNet网络对于MNIST数据集合进行测试.实验结果标明 ...

  7. python神经网络案例——CNN卷积神经网络实现mnist手写体识别

    分享一个朋友的人工智能教程.零基础!通俗易懂!风趣幽默!还带黄段子!大家可以看看是否对自己有帮助:点击打开 全栈工程师开发手册 (作者:栾鹏) python教程全解 CNN卷积神经网络的理论教程参考 ...

  8. python神经网络案例——FC全连接神经网络实现mnist手写体识别

    全栈工程师开发手册 (作者:栾鹏) python教程全解 FC全连接神经网络的理论教程参考 http://blog.csdn.net/luanpeng825485697/article/details ...

  9. mnist手写体识别中用到的TensorFlow API总结

    声明:本文通过CNN实现mnist例子总结了TensorFlow 1.12的相关API.代码来源于<Learning TensorFlow>这本书,API查阅了TensorFlow官网AP ...

  10. 基于keras的mnist手写体识别程序

    大家好 我是来自河北大学 心电组的一名研一的学生,本篇文章是我对mnist识别学习的认识和分享. 本文主要用来给想要用keras搭建网络识别mnist的同学一个引导. 有错误的地方请大家指正 我会虚心 ...

最新文章

  1. 使用堆内内存HeapByteBuffer的注意事项
  2. 给eth0增加一个IP
  3. 使用 .toLocaleString() 轻松实现多国语言价格数字格式化
  4. ConcurrentLinkedQueue的实现原理和源码分析
  5. SAP CRM呼叫中心里confirm按钮的实现逻辑
  6. sqlite字段是否存在_【漏洞预警】Linux内核存在本地提权漏洞(CVE20198912)
  7. 精通Quartz-入门-Job
  8. ip、url威胁情报库(开源)
  9. oracle监听的动态注册和静态注册
  10. mfc怎么获取进程的线程数_2020年大厂喜欢这样问线程安全,这些知识点我整理好了
  11. 通过调用外部exe的方法实现c#调用java
  12. LeetCode(160): Intersection of Two Linked Lists
  13. cc1: all warnings being treated as errors
  14. XenCenter 创建 New VM
  15. hp 1020 无线打印服务器,HP1020plus无线打印
  16. ios android 录音格式,Audio模块录音格式汇总(aac、mp3)
  17. 三毛的老家:4月中旬了还在中雪!
  18. 矩阵相乘求导(转载)
  19. 【产品】固定成本、可变成本、沉没成本和机会成本
  20. Shaolin(map||set)

热门文章

  1. 免安装版的Mysql安装与配置——详细教程
  2. linux ansys内存不够,ANSYS 硬件配置建议
  3. maxscale mysql 主从_orchestrator+maxscale+mysql5.7GTID主从切换测试过程
  4. C++ gflags示例
  5. c语言对fpga编程,利用C语言对FPGA计算解决方案进行编程方法介绍
  6. aplay 源码分析
  7. 达梦数据库、表字段创建索引或删除索引,增加表字段、修改字段类型或长度、修改注释sql语句
  8. 声卡接口 LINE_IN、MIC_IN、LINE_OUT
  9. 台式机dp接口_常见视频接口图示及说明
  10. 【IT项目管理】第1章 走进IT项目管理