Contents

  • 1. Problem introduction
  • 2. Runtime error
  • 3. Code
  • 4. Analyzing the cause
  • 5. Solution
  • 6. Complete code
  • 7. References

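1. Problem introduction

Today I was training a model with PyTorch. The dataset was loaded and preprocessed with PyTorch's built-in utilities, and the network was a custom CNN. When I ran the training script, the small error shown in the title appeared.

2. Runtime error

The error is as follows: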

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 1, 5, 5], but got 2-dimensional input of size [32, 784] instead

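3. Code

First, here is my custom CNN:

import torch.nn as nn


class MNIST_Model(nn.Module):
    def __init__(self, n_in):
        super(MNIST_Model, self).__init__()
        # 28x28 input -> 28x28 (padding=2 keeps the spatial size)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=n_in, out_channels=32, kernel_size=(5, 5), padding=2, stride=1),
        )
        # 28x28 -> 14x14
        self.maxp1 = nn.MaxPool2d(kernel_size=(2, 2))
        # 14x14 -> 10x10
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5), padding=0, stride=1),
        )
        # 10x10 -> 5x5
        self.maxp2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=64 * 5 * 5, out_features=200)  # Mnist
        )
        # Note: CrossEntropyLoss is usually fed raw logits; the final ReLU is kept from the original.
        self.fc2 = nn.Sequential(
            nn.Linear(in_features=200, out_features=10),
            nn.ReLU()
        )

    def forward(self, x):
        # x is expected to have shape [batch, n_in, 28, 28]
        x = self.conv1(x)
        x = self.maxp1(x)
        x = self.conv2(x)
        x = self.maxp2(x)
        # Flatten to [batch, 64 * 5 * 5] for the fully connected layers
        x = x.contiguous().view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x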
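To see why fc1 uses in_features=64 * 5 * 5, it can help to trace a dummy batch through the convolution and pooling layers. The following is a minimal sketch, not part of the original script: it rebuilds the same layer hyperparameters outside the class and assumes a random batch of two 28x28 single-channel images.

```python
import torch
import torch.nn as nn

# Same layer hyperparameters as MNIST_Model, rebuilt here so the sketch is self-contained.
layers = [
    ("conv1", nn.Conv2d(1, 32, kernel_size=5, padding=2, stride=1)),
    ("maxp1", nn.MaxPool2d(kernel_size=2)),
    ("conv2", nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1)),
    ("maxp2", nn.MaxPool2d(kernel_size=2)),
]

# Dummy batch standing in for MNIST images (assumption: 2 images, 1 channel, 28x28).
x = torch.randn(2, 1, 28, 28)
shapes = {"input": tuple(x.shape)}
for name, layer in layers:
    x = layer(x)
    shapes[name] = tuple(x.shape)

# The same flattening step that forward() applies before fc1.
flat = x.view(x.size(0), -1)
shapes["flatten"] = tuple(flat.shape)

for name, shape in shapes.items():
    print(f"{name:8s} {shape}")
```

The flattened width is 64 * 5 * 5 = 1600, which matches the in_features of fc1. Note that every layer before the flatten works on 4-dimensional tensors.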

Next is the code that trains the model:

# Instantiate the network (only the CPU is considered here)
model = model.MNIST_Model(1)
net = model.to(device)
# Define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
# momentum: what is the momentum factor for?
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Start training. First define the lists that store the loss and the accuracy
losses = []
acces = []
# Used for testing
eval_losses = []
eval_acces = []

for epoch in range(nums_epoches):
    # Reset the running totals at the start of every epoch
    train_loss = 0
    train_acc = 0
    # Put the model in training mode
    model.train()
    # Dynamic learning rate
    if epoch % 5 == 0:
        optimizer.param_groups[0]['lr'] *= 0.1
    for img, label in train_loader:
        # Forward pass: feed the image batch into the model.
        # out has 10 values per image, one score per digit class
        out = model(img)
        # out is [batch_size, 10]; label is 1-D with batch_size entries
        loss = criterion(out, label)
        # Backward pass
        # optimizer.zero_grad() resets the gradients, i.e. the derivatives of the loss w.r.t. the weights
        optimizer.zero_grad()
        loss.backward()
        # step() updates all parameters; call it once backward() has computed the gradients
        optimizer.step()
        # Record the loss
        train_loss += loss.item()
        # Classification accuracy: take the index of the highest score
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        # Fraction of correctly labelled images in this batch
        acc = num_correct / img.shape[0]
        train_acc += acc
    losses.append(train_loss / len(train_loader))
    acces.append(train_acc / len(train_loader))

    eval_loss = 0
    eval_acc = 0
    model.eval()
    # Evaluation only: no gradients and no optimizer step on the test data
    with torch.no_grad():
        for img, label in test_loader:
            # Flattens each batch from [32, 1, 28, 28] to [32, 784]
            img = img.view(img.size(0), -1)
            out = model(img)
            loss = criterion(out, label)
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            eval_acc += acc
    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))
    print('epoch:{},Train Loss:{:.4f},Train Acc:{:.4f},Test Loss:{:.4f},Test Acc:{:.4f}'.format(
        epoch, train_loss / len(train_loader), train_acc / len(train_loader),
        eval_loss / len(test_loader), eval_acc / len(test_loader)))

4. Analyzing the cause

Locating where the error occurs:

Traceback (most recent call last):
  File "train.py", line 73, in <module>
    out = model(img)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gzdx/wyf/PARAD/model.py", line 48, in forward
    x = self.conv1(x)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 396, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 1, 5, 5], but got 2-dimensional input of size [32, 784] instead

The message says that the shape of the data passed into the CNN does not match what the network expects. The first layer is an nn.Conv2d whose weight has shape [32, 1, 5, 5], so it expects a 4-dimensional input of shape [batch, channels, height, width]. Instead it received a 2-dimensional tensor of size [32, 784]: a batch of 32 images, each flattened to 28 × 28 = 784 values. In the training script above, the only place that produces such a tensor is img = img.view(img.size(0), -1) in the test loop, which flattens each batch before it is passed to the model. The call that triggers the error is:

  File "train.py", line 73, in <module>
    out = model(img)
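The mismatch can be reproduced in isolation. The sketch below is not the author's script: it builds a single nn.Conv2d with the same hyperparameters as conv1 and feeds it a random batch, once in its 4-dimensional form and once flattened. The exact wording of the error message depends on the installed PyTorch version, so the sketch only checks that a RuntimeError is raised.

```python
import torch
import torch.nn as nn

# A single layer with the same hyperparameters as conv1 in MNIST_Model.
conv = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2)

# Random stand-in for one MNIST batch (assumption): [batch, channels, height, width].
batch = torch.randn(32, 1, 28, 28)
ok = conv(batch)  # works: the input is 4-dimensional

# Flattening each image to a vector gives a 2-dimensional tensor.
flat = batch.view(batch.size(0), -1)

try:
    conv(flat)  # fails: Conv2d cannot interpret [32, 784] as images
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

print("weight:", tuple(conv.weight.shape))
print("4-D input -> output:", tuple(ok.shape))
print("flattened input:", tuple(flat.shape))
```

The weight shape [32, 1, 5, 5] and the flattened size [32, 784] are the two shapes quoted in the error message.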

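5. Solution

Many different fixes for this problem are suggested online. I took an idea from other people's posts and adapted it, and the error went away:

for i, data in enumerate(train_loader):
    # Forward pass: feed the image batch into the model.
    # out has 10 values per image, one score per digit class
    inputs, labels = data[0].to(device), data[1].to(device)
    # inputs torch.Size([32, 1, 28, 28])
    out = model(inputs)

The fix is simple: pass each batch to the model exactly as the DataLoader yields it, with shape [32, 1, 28, 28], and do not flatten it with view(img.size(0), -1) before the forward pass. The model already flattens its feature maps internally before the fully connected layers, so the input to model(...) must stay 4-dimensional.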
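If the data arrives already flattened (for example from a loader written for a fully connected network), the image shape can be restored before the forward pass. The following is a rough sketch under stated assumptions: as_image_batch is a hypothetical helper, not a PyTorch function, and a random TensorDataset stands in for MNIST so that nothing needs to be downloaded.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def as_image_batch(x, channels=1, height=28, width=28):
    """Hypothetical helper: return x with shape [N, C, H, W]."""
    if x.dim() == 4:
        return x  # already an image batch
    if x.dim() == 2 and x.size(1) == channels * height * width:
        # Undo a view(N, -1) flattening
        return x.view(x.size(0), channels, height, width)
    raise ValueError(f"cannot interpret shape {tuple(x.shape)} as an image batch")


# Random stand-in for MNIST (assumption): 64 single-channel 28x28 images with labels 0..9.
images = torch.randn(64, 1, 28, 28)
targets = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(images, targets), batch_size=32)

inputs, labels = next(iter(loader))
print("from the DataLoader:", tuple(inputs.shape))

flattened = inputs.view(inputs.size(0), -1)
print("after view(N, -1):  ", tuple(flattened.shape))

restored = as_image_batch(flattened)
print("restored:           ", tuple(restored.shape))
```

In the script from this post no such helper is needed, because the DataLoader already yields 4-dimensional batches; removing the flattening step is enough.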

6. Complete code

import os
import numpy as np
import model
import torch
# PyTorch's built-in MNIST dataset
from torchvision.datasets import mnist
# Preprocessing module
from torchvision import transforms
from torch.utils.data import DataLoader
# Neural network tools
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Hyperparameters used below
train_batch_size = 32
test_batch_size = 32
# Learning rate and number of epochs (learning_rate is not used below; the optimizer uses lr)
learning_rate = 0.01
nums_epoches = 50
# Parameters used by the optimizer
lr = 0.1
momentum = 0.5

# Use Compose to define the preprocessing pipeline
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# Load the data from a data folder inside the project folder
# (download=False assumes the files are already there)
train_dataset = mnist.MNIST('./data', train=True, transform=transform, target_transform=None, download=False)
test_dataset = mnist.MNIST('./data', train=False, transform=transform, target_transform=None, download=False)
# Data loaders: combine a dataset and a sampler, and provide single- or multi-process iteration over the dataset
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate the network and move it to the selected device
model = model.MNIST_Model(1)
net = model.to(device)
# Define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
# momentum: what is the momentum factor for?
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Start training. First define the lists that store the loss and the accuracy
losses = []
acces = []
# Used for testing
eval_losses = []
eval_acces = []

for epoch in range(nums_epoches):
    # Reset the running totals at the start of every epoch
    train_loss = 0
    train_acc = 0
    # Put the model in training mode
    model.train()
    # Dynamic learning rate
    if epoch % 5 == 0:
        optimizer.param_groups[0]['lr'] *= 0.1
    for i, data in enumerate(train_loader):
        # Forward pass: feed the image batch into the model.
        # out has 10 values per image, one score per digit class
        inputs, labels = data[0].to(device), data[1].to(device)
        out = model(inputs)
        # out is [batch_size, 10]; labels is 1-D with batch_size entries
        loss = criterion(out, labels)
        # Backward pass
        # optimizer.zero_grad() resets the gradients, i.e. the derivatives of the loss w.r.t. the weights
        optimizer.zero_grad()
        loss.backward()
        # step() updates all parameters; call it once backward() has computed the gradients
        optimizer.step()
        # Record the loss
        train_loss += loss.item()
        # Classification accuracy: take the index of the highest score
        _, pred = out.max(1)
        num_correct = (pred == labels).sum().item()
        # Fraction of correctly labelled images in this batch
        acc = num_correct / inputs.shape[0]
        train_acc += acc
    losses.append(train_loss / len(train_loader))
    acces.append(train_acc / len(train_loader))

    eval_loss = 0
    eval_acc = 0
    model.eval()
    # Evaluation only: no gradients and no optimizer step on the test data
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            inputs, labels = data[0].to(device), data[1].to(device)
            out = model(inputs)
            loss = criterion(out, labels)
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == labels).sum().item()
            acc = num_correct / inputs.shape[0]
            eval_acc += acc
    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))
    print('epoch:{},Train Loss:{:.4f},Train Acc:{:.4f},Test Loss:{:.4f},Test Acc:{:.4f}'.format(
        epoch, train_loss / len(train_loader), train_acc / len(train_loader),
        eval_loss / len(test_loader), eval_acc / len(test_loader)))

print('Finished Training')
# Save the model (create the target folder first if it does not exist)
PATH = './model/mnist_net.pth'
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(net.state_dict(), PATH)
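One detail of the listing that is easy to overlook is the dynamic learning rate: the rate is multiplied by 0.1 whenever epoch % 5 == 0, and that condition is already true at epoch 0. The sketch below replays only that arithmetic in plain Python, using the values lr = 0.1 and nums_epoches = 50 from the listing; the history list is introduced here purely for illustration.

```python
import math

lr = 0.1            # initial learning rate passed to optim.SGD in the listing
nums_epoches = 50   # number of epochs in the listing

history = []        # learning rate in effect during each epoch (illustrative only)
for epoch in range(nums_epoches):
    # Same rule as in the training loop: decay every 5 epochs, starting at epoch 0
    if epoch % 5 == 0:
        lr *= 0.1
    history.append(lr)

print("epoch 0: ", history[0])
print("epoch 5: ", history[5])
print("epoch 49:", history[-1])
```

The first epoch therefore runs at roughly 0.01 rather than 0.1, and after ten decays the rate is on the order of 1e-11, at which point the weights barely change. Whether this is intended depends on the experiment; a milder schedule, for instance via torch.optim.lr_scheduler.StepLR, is one possible alternative.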

7. References

1. PyTorch study notes: building a CNN to recognize MNIST

2. Implementing handwritten digit (MNIST) recognition with a CNN in the PyTorch framework
