RuntimeError: each element in list of batch should be of equal size
RuntimeError: each element in list of batch should be of equal size
示例代码:
import os
import re
from torch.utils.data import Dataset, DataLoaderdata_base_path = r'./aclImdb/'# 1.定义token的方法
def tokenize(test):filters = ['!','"','#','$','%','&','\(','\)','\*','\+',',','-','\.','/',':',';','<','=','>','\?','@','\[','\\','\]','^','_','`','\{','\|','\}','~','\t','\n','\x97','\x96','”','“',]text = re.sub("<.*?>", " ", test, flags=re.S)text = re.sub("|".join(filters), " ", test, flags=re.S)return [i.strip() for i in text.split()]# 2.准备dataset
class ImdbDataset(Dataset):def __init__(self, mode):super().__init__()if mode == "train":text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]else:text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]self.total_file_path_list = []for i in text_path:self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])def __getitem__(self, item):cur_path = self.total_file_path_list[item]cur_filename = os.path.basename(cur_path)label = int(cur_filename.split("_")[-1].split(".")[0]) - 1 # 处理标题,获取标签label,转化为从[0-9]text = tokenize(open(cur_path).read().strip()) # 直接按照空格进行分词return label, textdef __len__(self):return len(self.total_file_path_list)# 3.实例化,准别dataloader
dataset = ImdbDataset(mode="train")
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)# 4.观察数输出结果
for idx, (label, text) in enumerate(dataloader):print("idx:", idx)print("label:", label)print("text:", text)break
运行结果:
报错原因:
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True),发现是这行代码导致的错误,如果把batch_size=2改为batch_size=1时就不再报错了,运行结果如下:
但是如果想让batch_size=2时,这个错误该如何解决呢?
解决方法如下:
出现问题的原因在于Dataloader
中的参数collate_fn
collate_fn
的默认值为torch自定义的default_collate
,collate_fn
的作用就是对每个batch进行处理,而默认的default_collate
处理出错。
解决问题的思路:
- 手段1:考虑先把数据转化为数字序列,观察其结果是否符合要求,之前使用DataLoader并未出现类似错误
- 手段2:考虑自定义一个
collate_fn
,观察结果
这里使用方式2,自定义一个collate_fn
,然后观察结果:
def collate_fn(batch):# batch是一个列表,其中是一个一个的元组,每个元组是dataset中_getitem__的结果batch = list(zip(*batch))labels = torch.tensor(batch[0], dtype=torch.int32)texts = batch[1]del batchreturn labels, texts
全部代码:
import os
import re
import torch
from torch.utils.data import Dataset, DataLoaderdata_base_path = r'./aclImdb/'# 1.定义token的方法
def tokenize(test):filters = ['!','"','#','$','%','&','\(','\)','\*','\+',',','-','\.','/',':',';','<','=','>','\?','@','\[','\\','\]','^','_','`','\{','\|','\}','~','\t','\n','\x97','\x96','”','“',]text = re.sub("<.*?>", " ", test, flags=re.S)text = re.sub("|".join(filters), " ", test, flags=re.S)return [i.strip() for i in text.split()]# 2.准备dataset
class ImdbDataset(Dataset):def __init__(self, mode):super().__init__()if mode == "train":text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]else:text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]self.total_file_path_list = []for i in text_path:self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])def __getitem__(self, item):cur_path = self.total_file_path_list[item]cur_filename = os.path.basename(cur_path)label = int(cur_filename.split("_")[-1].split(".")[0]) - 1 # 处理标题,获取标签label,转化为从[0-9]text = tokenize(open(cur_path).read().strip()) # 直接按照空格进行分词return label, textdef __len__(self):return len(self.total_file_path_list)def collate_fn(batch):# batch是一个列表,其中是一个一个的元组,每个元组是dataset中_getitem__的结果batch = list(zip(*batch))labels = torch.tensor(batch[0], dtype=torch.int32)texts = batch[1]del batchreturn labels, texts# 3.实例化,准别dataloader
dataset = ImdbDataset(mode="train")
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)# 4.观察数输出结果
for idx, (label, text) in enumerate(dataloader):print("idx:", idx)print("label:", label)print("text:", text)break
运行效果:
RuntimeError: each element in list of batch should be of equal size相关推荐
- RuntimeError: stack expects each tensor to be equal size, but got xxx at entry 0 at entry 1
今日做模型训练,Pytorch在加载数据时遇到如下错误: Epoch 1/800: 31%|███ | 4/13 [00:03<00:07, 1.27img/s, loss (batch)=1. ...
- RuntimeError: stack expects each tensor to be equal size
在调试densenet进行分类任务的代码时,在图像预处理的过程中遇到下列错误: RuntimeError: stack expects each tensor to be equal size, bu ...
- DataLoader问题解决:RuntimeError: stack expects each tensor to be equal size, but got [3, 200, 200]entry1
最近,在数据集处理并载入DataLoader进行训练的时候出现了问题: RuntimeError: stack expects each tensor to be equal size, but go ...
- 记录解决RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 27 but got size
问题描述 在做目标检测服务过程中,将yolov7模型通过flask打包成预测服务API,此次训练的图像输入大小是1280,输入预测图片是如果图像大于1280则预测成功,小于1280则报RuntimeE ...
- RuntimeError: stack expects each tensor to be equal size, but got [8] at entry 0 and [2] at entry 2
最近在调试pytorch代码的时候遇到如下问题,由于他报错的地方不是在我们自己写的代码,而是在pytorch的包里,所以一开始就一头雾水. 在查阅了资料以后http://www.zzvips. ...
- 阅读源码-理解torch.utils.data、torch.utils.data.Dataset、torch.utils.data.DataLoader的工作方式
文章目录 目标 Dataset DataLoader 应用 Dataset DataLoader 测试 知识点 Python splitlines()方法 python filter()函数 暂时先写 ...
- pytorch __getitem__ 返回值
在pytorch中若是使用自定义数据集,需要定义Dataset类,并覆盖父类的__len__和__getitem__函数 举个例子,返回常规的数据对x, y 也可以是多个x,y 比如小样本学习中需要q ...
- NLP第三周(中文分词,新词发现,tfidf)(1)
正向最大匹配 其主要是目的是将一句话分成进行词语的划分,相当于看看这句话由哪些词语组成,最完美的解决方案是,我会准备一个词库,然后我输入进去一句话,刚好我用我词库里面的词语把这句话分成一个一个词,一个 ...
- QML Image Element
QML Image Element The Image element displays an image in a declarative user interface More... Image元 ...
最新文章
- 使用牛刀云开发微信小程序(问题集锦)
- C#两大知名Redis客户端连接哨兵集群的姿势
- python基础(正则表达式)
- 常见的布局实现,以及响应式布局技巧。
- CCF 201403-2 窗口
- jquery.serialize
- Servlet中的请求转发
- 关于字节对齐(关于align)
- python 微信小程序制作教程_微信小程序从零开始开发步骤(一)
- 苹果电脑怎么读取ntfs磁盘?有哪些可以读取苹果电脑硬盘的软件?
- dtu连接虚拟服务器,DTU连接HTTP网页
- HTML选择器的学习
- 软文营销有什么效果,主要作用是什么?
- Freemaker判断是否为空
- opencv-------高斯滤波
- 推测的删除锁(Speculative Lock Elision):实现高并发多线程执行
- Kali Linux渗透测试 128 拒绝服务--TearDrop 攻击
- rem、em、px、rpx、vw、vh、%等
- 侃谈移动端音视频发展与现状
- ad软件 pcb如何走线过孔_PCB走线和过孔的过流能力
热门文章
- Python 捕获警告
- Uber 前无人驾驶工程师告诉你,国内无人驾驶之路还要走多久?
- “数学不好,干啥都不行!”骨灰级程序员:其实你们都是瞎努力
- 《评人工智能如何走向新阶段》后记
- ICCV 2019 | 加一个任务路由让数百个任务同时跑起来,怎么做到?
- 投稿2877篇,EMNLP 2019公布4篇最佳论文
- 如何学习SVM?怎么改进实现SVM算法程序?答案来了
- 性能比GPU高100倍!华人教授研发全球首个可编程忆阻器AI计算机
- “搞垮” 微博服务器?每天上亿条用户推送是如何做到的
- 1024程序员节,你是我们要找的那条锦鲤吗?