RuntimeError: each element in list of batch should be of equal size

示例代码:

import os
import re
from torch.utils.data import Dataset, DataLoaderdata_base_path = r'./aclImdb/'#  1.定义token的方法
def tokenize(test):filters = ['!','"','#','$','%','&','\(','\)','\*','\+',',','-','\.','/',':',';','<','=','>','\?','@','\[','\\','\]','^','_','`','\{','\|','\}','~','\t','\n','\x97','\x96','”','“',]text = re.sub("<.*?>", " ", test, flags=re.S)text = re.sub("|".join(filters), " ", test, flags=re.S)return [i.strip() for i in text.split()]#  2.准备dataset
class ImdbDataset(Dataset):def __init__(self, mode):super().__init__()if mode == "train":text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]else:text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]self.total_file_path_list = []for i in text_path:self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])def __getitem__(self, item):cur_path = self.total_file_path_list[item]cur_filename = os.path.basename(cur_path)label = int(cur_filename.split("_")[-1].split(".")[0]) - 1  # 处理标题,获取标签label,转化为从[0-9]text = tokenize(open(cur_path).read().strip())  # 直接按照空格进行分词return label, textdef __len__(self):return len(self.total_file_path_list)#  3.实例化,准别dataloader
dataset = ImdbDataset(mode="train")
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)#  4.观察数输出结果
for idx, (label, text) in enumerate(dataloader):print("idx:", idx)print("label:", label)print("text:", text)break

运行结果:

报错原因:

dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True),发现是这行代码导致的错误,如果把batch_size=2改为batch_size=1时就不再报错了,运行结果如下:

但是如果想让batch_size=2时,这个错误该如何解决呢?

解决方法如下:

出现问题的原因在于Dataloader中的参数collate_fn

collate_fn的默认值为torch自定义的default_collate,collate_fn的作用就是对每个batch进行处理,而默认的default_collate处理出错。

解决问题的思路:

  • 手段1:考虑先把数据转化为数字序列,观察其结果是否符合要求,之前使用DataLoader并未出现类似错误
  • 手段2:考虑自定义一个collate_fn,观察结果

这里使用方式2,自定义一个collate_fn,然后观察结果:

def collate_fn(batch):#  batch是一个列表,其中是一个一个的元组,每个元组是dataset中_getitem__的结果batch = list(zip(*batch))labels = torch.tensor(batch[0], dtype=torch.int32)texts = batch[1]del batchreturn labels, texts

全部代码:

import os
import re
import torch
from torch.utils.data import Dataset, DataLoaderdata_base_path = r'./aclImdb/'#  1.定义token的方法
def tokenize(test):filters = ['!','"','#','$','%','&','\(','\)','\*','\+',',','-','\.','/',':',';','<','=','>','\?','@','\[','\\','\]','^','_','`','\{','\|','\}','~','\t','\n','\x97','\x96','”','“',]text = re.sub("<.*?>", " ", test, flags=re.S)text = re.sub("|".join(filters), " ", test, flags=re.S)return [i.strip() for i in text.split()]#  2.准备dataset
class ImdbDataset(Dataset):def __init__(self, mode):super().__init__()if mode == "train":text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]else:text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]self.total_file_path_list = []for i in text_path:self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])def __getitem__(self, item):cur_path = self.total_file_path_list[item]cur_filename = os.path.basename(cur_path)label = int(cur_filename.split("_")[-1].split(".")[0]) - 1  # 处理标题,获取标签label,转化为从[0-9]text = tokenize(open(cur_path).read().strip())  # 直接按照空格进行分词return label, textdef __len__(self):return len(self.total_file_path_list)def collate_fn(batch):#  batch是一个列表,其中是一个一个的元组,每个元组是dataset中_getitem__的结果batch = list(zip(*batch))labels = torch.tensor(batch[0], dtype=torch.int32)texts = batch[1]del batchreturn labels, texts#  3.实例化,准别dataloader
dataset = ImdbDataset(mode="train")
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)#  4.观察数输出结果
for idx, (label, text) in enumerate(dataloader):print("idx:", idx)print("label:", label)print("text:", text)break

运行效果:

RuntimeError: each element in list of batch should be of equal size相关推荐

  1. RuntimeError: stack expects each tensor to be equal size, but got xxx at entry 0 at entry 1

    今日做模型训练,Pytorch在加载数据时遇到如下错误: Epoch 1/800: 31%|███ | 4/13 [00:03<00:07, 1.27img/s, loss (batch)=1. ...

  2. RuntimeError: stack expects each tensor to be equal size

    在调试densenet进行分类任务的代码时,在图像预处理的过程中遇到下列错误: RuntimeError: stack expects each tensor to be equal size, bu ...

  3. DataLoader问题解决:RuntimeError: stack expects each tensor to be equal size, but got [3, 200, 200]entry1

    最近,在数据集处理并载入DataLoader进行训练的时候出现了问题: RuntimeError: stack expects each tensor to be equal size, but go ...

  4. 记录解决RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 27 but got size

    问题描述 在做目标检测服务过程中,将yolov7模型通过flask打包成预测服务API,此次训练的图像输入大小是1280,输入预测图片是如果图像大于1280则预测成功,小于1280则报RuntimeE ...

  5. RuntimeError: stack expects each tensor to be equal size, but got [8] at entry 0 and [2] at entry 2

      最近在调试pytorch代码的时候遇到如下问题,由于他报错的地方不是在我们自己写的代码,而是在pytorch的包里,所以一开始就一头雾水.   在查阅了资料以后http://www.zzvips. ...

  6. 阅读源码-理解torch.utils.data、torch.utils.data.Dataset、torch.utils.data.DataLoader的工作方式

    文章目录 目标 Dataset DataLoader 应用 Dataset DataLoader 测试 知识点 Python splitlines()方法 python filter()函数 暂时先写 ...

  7. pytorch __getitem__ 返回值

    在pytorch中若是使用自定义数据集,需要定义Dataset类,并覆盖父类的__len__和__getitem__函数 举个例子,返回常规的数据对x, y 也可以是多个x,y 比如小样本学习中需要q ...

  8. NLP第三周(中文分词,新词发现,tfidf)(1)

    正向最大匹配 其主要是目的是将一句话分成进行词语的划分,相当于看看这句话由哪些词语组成,最完美的解决方案是,我会准备一个词库,然后我输入进去一句话,刚好我用我词库里面的词语把这句话分成一个一个词,一个 ...

  9. QML Image Element

    QML Image Element The Image element displays an image in a declarative user interface More... Image元 ...

最新文章

  1. 使用牛刀云开发微信小程序(问题集锦)
  2. C#两大知名Redis客户端连接哨兵集群的姿势
  3. python基础(正则表达式)
  4. 常见的布局实现,以及响应式布局技巧。
  5. CCF 201403-2 窗口
  6. jquery.serialize
  7. Servlet中的请求转发
  8. 关于字节对齐(关于align)
  9. python 微信小程序制作教程_微信小程序从零开始开发步骤(一)
  10. 苹果电脑怎么读取ntfs磁盘?有哪些可以读取苹果电脑硬盘的软件?
  11. dtu连接虚拟服务器,DTU连接HTTP网页
  12. HTML选择器的学习
  13. 软文营销有什么效果,主要作用是什么?
  14. Freemaker判断是否为空
  15. opencv-------高斯滤波
  16. 推测的删除锁(Speculative Lock Elision):实现高并发多线程执行
  17. Kali Linux渗透测试 128 拒绝服务--TearDrop 攻击
  18. rem、em、px、rpx、vw、vh、%等
  19. 侃谈移动端音视频发展与现状
  20. ad软件 pcb如何走线过孔_PCB走线和过孔的过流能力

热门文章

  1. Python 捕获警告
  2. Uber 前无人驾驶工程师告诉你,国内无人驾驶之路还要走多久?
  3. “数学不好,干啥都不行!”骨灰级程序员:其实你们都是瞎努力
  4. 《评人工智能如何走向新阶段》后记
  5. ICCV 2019 | 加一个任务路由让数百个任务同时跑起来,怎么做到?
  6. 投稿2877篇,EMNLP 2019公布4篇最佳论文
  7. 如何学习SVM?怎么改进实现SVM算法程序?答案来了
  8. 性能比GPU高100倍!华人教授研发全球首个可编程忆阻器AI计算机
  9. “搞垮” 微博服务器?每天上亿条用户推送是如何做到的
  10. 1024程序员节,你是我们要找的那条锦鲤吗?