RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 1, 5, 5]
Contents
- 1. Background
- 2. Error message
- 3. Code
- 4. Cause analysis
- 5. Solution
- 6. Full code
- 7. References
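1. Background
Today I was training a model with PyTorch. The dataset was loaded and preprocessed with PyTorch's built-in utilities, the network was a custom CNN, and running it produced the error shown in the title.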
2. Error message
The error is as follows:
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 1, 5, 5], but got 2-dimensional input of size [32, 784] instead
3. Code
First, here is my custom CNN:
import torch.nn as nn


class MNIST_Model(nn.Module):
    def __init__(self, n_in):
        super(MNIST_Model, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=n_in, out_channels=32, kernel_size=(5, 5), padding=2, stride=1),
            nn.ReLU(),
        )
        self.maxp1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5), padding=0, stride=1),
            nn.ReLU(),
        )
        self.maxp2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=64 * 5 * 5, out_features=200),  # MNIST: 28x28 input -> 64 x 5 x 5 feature maps
            nn.ReLU(),
        )
        # No ReLU on the output layer: nn.CrossEntropyLoss expects raw scores (logits)
        self.fc2 = nn.Sequential(
            nn.Linear(in_features=200, out_features=10),
        )

    def forward(self, x):
        # x must be 4-D: [batch, channels, height, width], e.g. [32, 1, 28, 28]
        x = self.conv1(x)
        x = self.maxp1(x)
        x = self.conv2(x)
        x = self.maxp2(x)
        # Flatten only here, after the conv layers: [batch, 64, 5, 5] -> [batch, 1600]
        x = x.contiguous().view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
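The in_features=64 * 5 * 5 of fc1 only holds if the 28 × 28 MNIST image shrinks to 5 × 5 after the two conv/pool stages. A quick check with the standard output-size formula, out = floor((in + 2·padding − kernel) / stride) + 1; the helper names conv_out and pool_out are made up for this sketch, and pool_out assumes MaxPool2d's defaults (stride equal to kernel_size, no padding).

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output size of a Conv2d / MaxPool2d along one spatial dimension (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """MaxPool2d with its default stride (= kernel_size) and no padding."""
    return conv_out(size, kernel, padding=0, stride=kernel)


size = 28                                    # MNIST images are 28 x 28
size = conv_out(size, kernel=5, padding=2)   # conv1 -> 28
size = pool_out(size, kernel=2)              # maxp1 -> 14
size = conv_out(size, kernel=5, padding=0)   # conv2 -> 10
size = pool_out(size, kernel=2)              # maxp2 -> 5
flat_features = 64 * size * size             # conv2 has 64 output channels
print(size, flat_features)                   # 5 1600
```

If the padding or kernel sizes are changed, rerunning this arithmetic gives the new value for fc1's in_features.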
And here is the training code:
# Instantiate the network and move it to the device (GPU if available, otherwise CPU)
model = model.MNIST_Model(1)
net = model.to(device)
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
# momentum: momentum factor; each update reuses a fraction of the previous one, which smooths SGD
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Start training: first create lists for the per-epoch loss and accuracy
losses = []
acces = []
# For evaluation
eval_losses = []
eval_acces = []

for epoch in range(nums_epoches):
    # Reset the running totals at the start of each epoch
    train_loss = 0
    train_acc = 0
    # Put the model in training mode
    model.train()
    # Learning-rate decay: multiply lr by 0.1 every 5 epochs (this also fires at epoch 0)
    if epoch % 5 == 0:
        optimizer.param_groups[0]['lr'] *= 0.1
    for img, label in train_loader:
        # Forward pass: out holds 10 raw scores (logits) per image, one per digit class
        out = model(img)
        # out has shape [batch_size, 10] (here [32, 10]); label is 1-D with batch_size entries
        loss = criterion(out, label)
        # Backward pass
        # optimizer.zero_grad() resets the gradients, i.e. sets d(loss)/d(weight) back to 0
        optimizer.zero_grad()
        loss.backward()
        # Update all parameters once backward() has computed the gradients
        optimizer.step()
        # Accumulate the loss
        train_loss += loss.item()
        # Classification accuracy: take the index of the highest score
        _, pred = out.max(1)
        # Count correctly classified samples
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        train_acc += acc
    losses.append(train_loss / len(train_loader))
    acces.append(train_acc / len(train_loader))

    eval_loss = 0
    eval_acc = 0
    model.eval()
    # Evaluation only: no gradients and no optimizer steps on the test set
    with torch.no_grad():
        for img, label in test_loader:
            # Flattens [32, 1, 28, 28] into [32, 784], which nn.Conv2d cannot accept
            img = img.view(img.size(0), -1)
            out = model(img)
            loss = criterion(out, label)
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            eval_acc += acc
    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))
    print('epoch:{},Train Loss:{:.4f},Train Acc:{:.4f},Test Loss:{:.4f},Test Acc:{:.4f}'.format(epoch, train_loss / len(train_loader), train_acc / len(train_loader), eval_loss / len(test_loader), eval_acc / len(test_loader)))
4. Cause analysis
Locating the error:
Traceback (most recent call last):
  File "train.py", line 73, in <module>
    out = model(img)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gzdx/wyf/PARAD/model.py", line 48, in forward
    x = self.conv1(x)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/gzdx/anaconda3/envs/Torch/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 396, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 1, 5, 5], but got 2-dimensional input of size [32, 784] instead
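The message says that the tensor passed to the CNN has the wrong number of dimensions. The weight of conv1 has shape [32, 1, 5, 5] (out_channels, in_channels, kernel height, kernel width), and nn.Conv2d expects a 4-D input of shape [N, C, H, W]. Instead it received a 2-D tensor of size [32, 784], i.e. 32 images each flattened into 28 × 28 = 784 values. The DataLoader with transforms.ToTensor() yields batches of shape [32, 1, 28, 28], so the batch must have been flattened before reaching the model; the evaluation loop in the code above does exactly that with img = img.view(img.size(0), -1), a step that only a fully connected network needs. The error is raised at this call in train.py:

  File "train.py", line 73, in <module>
    out = model(img)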
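A minimal reproduction, assuming only torch is installed: a Conv2d configured like conv1 accepts a 4-D batch but rejects the flattened 2-D version. The random tensor stands in for a real MNIST batch, and the exact error wording differs between PyTorch versions (recent releases mention "3D (unbatched) or 4D (batched) input"), so only the exception type is relied on here. Reshaping with view(-1, 1, 28, 28) is shown as an alternative way to undo a flatten.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same configuration as conv1 in MNIST_Model with n_in=1
conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2, stride=1)
# Weight layout: [out_channels, in_channels, kernel_h, kernel_w]
print(tuple(conv1.weight.shape))        # (32, 1, 5, 5)

batch = torch.randn(32, 1, 28, 28)      # shaped like one DataLoader batch
print(tuple(conv1(batch).shape))        # (32, 32, 28, 28)

flat = batch.view(batch.size(0), -1)    # what img.view(img.size(0), -1) produces
print(tuple(flat.shape))                # (32, 784)

try:
    conv1(flat)
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)         # wording depends on the PyTorch version

restored = flat.view(-1, 1, 28, 28)     # undo the flatten: back to [N, C, H, W]
print(tuple(conv1(restored).shape))     # (32, 32, 28, 28)
```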
5. Solution
Many different fixes for this problem are suggested online. I took an idea from someone else's answer, adapted it to my code, and the error went away:
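for i, data in enumerate(train_loader):
    # Forward pass: feed the image batch to the model
    # out holds 10 raw scores (logits) per image, one per digit class
    # Unpack the batch and move it to the same device as the model
    inputs, labels = data[0].to(device), data[1].to(device)
    # inputs: torch.Size([32, 1, 28, 28]), passed to the model without flattening
    out = model(inputs)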
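The fix is simple: at the start of the training loop, unpack each batch this way and pass inputs to the model unchanged. The batch stays 4-D ([32, 1, 28, 28]), which is what conv1 expects, so the error no longer occurs; calling .to(device) also puts the data on the same device as the model. The test loop needs the same change, i.e. the img.view(img.size(0), -1) flatten has to go.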
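A self-contained sketch of the corrected loop that runs without downloading MNIST. The random TensorDataset stands in for mnist.MNIST, and the small nn.Sequential network is a hypothetical stand-in for MNIST_Model; what it illustrates is that the batch stays 4-D all the way into the first convolution, while flattening happens inside the model (nn.Flatten) after the conv layers.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Random stand-in for MNIST: 64 single-channel 28x28 images with labels 0-9
images = torch.randn(64, 1, 28, 28)
targets = torch.randint(0, 10, (64,))
train_loader = DataLoader(TensorDataset(images, targets), batch_size=32, shuffle=True)

# Hypothetical small CNN; flattening happens inside the model, after the conv layer
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=5, padding=2),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),                 # [N, 8, 14, 14] -> [N, 1568]
    nn.Linear(8 * 14 * 14, 10),
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

model.train()
for inputs, labels in train_loader:
    # Keep the batch 4-D ([32, 1, 28, 28]); only move it to the model's device
    inputs, labels = inputs.to(device), labels.to(device)
    out = model(inputs)           # [32, 10] raw class scores
    loss = criterion(out, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(tuple(inputs.shape), tuple(out.shape), round(loss.item(), 4))
```

With the real data, the train_loader built from mnist.MNIST and transforms.ToTensor() yields batches of the same [32, 1, 28, 28] shape.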
6. Full code
import os
import numpy as np
import model
import torch
# Import PyTorch's built-in MNIST dataset
from torchvision.datasets import mnist
# Import the preprocessing module
from torchvision import transforms
from torch.utils.data import DataLoader
# Import the neural-network utilities
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Hyperparameters used below
train_batch_size = 32
test_batch_size = 32
# Learning rate and number of epochs (learning_rate is not used below; the optimizer uses lr)
learning_rate = 0.01
nums_epoches = 50
# Parameters for the optimizer
lr = 0.1
momentum = 0.5

# Define the preprocessing pipeline with Compose
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# Store the data in a "data" folder inside the project directory
# (download=True fetches MNIST on the first run and reuses the local copy afterwards)
train_dataset = mnist.MNIST('./data', train=True, transform=transform, target_transform=None, download=True)
test_dataset = mnist.MNIST('./data', train=False, transform=transform, target_transform=None, download=True)

# Data loaders combine a dataset and a sampler and provide a single- or multi-process iterator over the dataset
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate the network and move it to the device (GPU if available, otherwise CPU)
model = model.MNIST_Model(1)
net = model.to(device)
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
# momentum: momentum factor; each update reuses a fraction of the previous one, which smooths SGD
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Start training: first create lists for the per-epoch loss and accuracy
losses = []
acces = []
# For evaluation
eval_losses = []
eval_acces = []

for epoch in range(nums_epoches):
    # Reset the running totals at the start of each epoch
    train_loss = 0
    train_acc = 0
    # Put the model in training mode
    model.train()
    # Learning-rate decay: multiply lr by 0.1 every 5 epochs (this also fires at epoch 0)
    if epoch % 5 == 0:
        optimizer.param_groups[0]['lr'] *= 0.1
    for i, data in enumerate(train_loader):
        # Keep the batch 4-D ([32, 1, 28, 28]) and move it to the same device as the model
        inputs, labels = data[0].to(device), data[1].to(device)
        # Forward pass: out holds 10 raw scores (logits) per image, one per digit class
        out = model(inputs)
        # out has shape [batch_size, 10] (here [32, 10]); labels is 1-D with batch_size entries
        loss = criterion(out, labels)
        # Backward pass
        # optimizer.zero_grad() resets the gradients, i.e. sets d(loss)/d(weight) back to 0
        optimizer.zero_grad()
        loss.backward()
        # Update all parameters once backward() has computed the gradients
        optimizer.step()
        # Accumulate the loss
        train_loss += loss.item()
        # Classification accuracy: take the index of the highest score
        _, pred = out.max(1)
        # Count correctly classified samples
        num_correct = (pred == labels).sum().item()
        acc = num_correct / inputs.shape[0]
        train_acc += acc
    losses.append(train_loss / len(train_loader))
    acces.append(train_acc / len(train_loader))
    print('Finished Training')
    # Save the model (create ./model first, otherwise torch.save fails)
    PATH = './model/mnist_net.pth'
    os.makedirs(os.path.dirname(PATH), exist_ok=True)
    torch.save(net.state_dict(), PATH)

    eval_loss = 0
    eval_acc = 0
    model.eval()
    # Evaluation only: no gradients and no optimizer steps on the test set
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            inputs, labels = data[0].to(device), data[1].to(device)
            out = model(inputs)
            loss = criterion(out, labels)
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == labels).sum().item()
            acc = num_correct / inputs.shape[0]
            eval_acc += acc
    eval_losses.append(eval_loss / len(test_loader))
    eval_acces.append(eval_acc / len(test_loader))
    print('epoch:{},Train Loss:{:.4f},Train Acc:{:.4f},Test Loss:{:.4f},Test Acc:{:.4f}'.format(epoch, train_loss / len(train_loader), train_acc / len(train_loader), eval_loss / len(test_loader), eval_acc / len(test_loader)))
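An evaluation pass should not call loss.backward() or optimizer.step(); otherwise the model is trained on the test set. A minimal sketch of a side-effect-free evaluation, assuming the same (inputs, labels) batches; evaluate is a hypothetical helper name, and unlike the per-batch average used in the article it weights each batch by its size, so a smaller final batch (10,000 test images with batch size 32 leaves 16 in the last one) is not over-counted. The tiny linear model and random data in the usage part are stand-ins.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, criterion, device):
    """Return (mean loss, accuracy) over `loader` without updating the model."""
    model.eval()                              # inference behaviour for dropout / batch norm
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():                     # no autograd graph, no gradients
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            out = model(inputs)               # [batch, 10] raw scores
            # criterion returns the batch mean, so scale it back to a per-batch sum
            total_loss += criterion(out, labels).item() * inputs.size(0)
            pred = out.argmax(dim=1)          # predicted class, as in `_, pred = out.max(1)`
            total_correct += (pred == labels).sum().item()
            total_seen += inputs.size(0)
    return total_loss / total_seen, total_correct / total_seen


# Usage with random stand-in data (50 samples -> batches of 32 and 18)
torch.manual_seed(0)
device = torch.device("cpu")
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
test_loader = DataLoader(
    TensorDataset(torch.randn(50, 1, 28, 28), torch.randint(0, 10, (50,))),
    batch_size=32,
)

before = copy.deepcopy(net.state_dict())
test_loss, test_acc = evaluate(net, test_loader, nn.CrossEntropyLoss(), device)
unchanged = all(torch.equal(before[k], v) for k, v in net.state_dict().items())
print("Test Loss:{:.4f}, Test Acc:{:.4f}, weights unchanged: {}".format(test_loss, test_acc, unchanged))
```

Comparing the state_dict before and after the call is a cheap way to confirm that evaluation left the weights untouched.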
7. References
1. PyTorch study notes: building a CNN to recognize MNIST
2. Handwritten digit (MNIST) recognition with a CNN in PyTorch