文章目录

  • 前言
  • dataset
  • dataloader之collate_fn
  • 应用情形

前言

import torch.utils.data as tud

collate_fn:即用于collate的function,用于整理数据的函数。
说到整理数据,你当然也要会用tud.Dataset,因为这个你定义好后,才会产生数据嘛,产生了数据我们才能整理数据嘛,而整理数据我们使用collate_fn

dataset

我们必须先看看tud.Dataset如何使用,以一个例子为例:

class mydataset(tud.Dataset):def __init__(self,data):self.data=datadef __len__(self):#必须重写return len(self.data)def __getitem__(self,idx):#必须重写return self.data[idx]
#构造训练数据
a=np.random.rand(4,3)#4个数据,每一个数据是一个向量。
print(a)

#制作dataset
dataset=mydataset(a)
len(dataset)#调用了你上面定义的def __len__()那个函数
#4
dataset[0]#调用了你上面定义的def __getitem__()那个函数,传入的idx=0,也就是取第0个数据。
#array([0.56998216, 0.72663738, 0.3706266 ])

dataloader之collate_fn

dataloader=tud.DataLoader(dataset,batch_size=2)

batch_size=2即一个batch里面会有2个数据。我们以第1个batch为例,tud.DataLoader会根据dataset取出前2个数据,然后弄成一个列表,如下:

batch=[dataset[0],dataset[1]]
batch

[array([0.56998216, 0.72663738, 0.3706266 ]),
array([0.3403586 , 0.13931333, 0.71030221])]

然后将上面这个batch作为参数交给collate_fn这个函数进行进一步整理数据,然后得到real_batch,作为返回值。如果你不指定这个函数是什么,那么会调用pytorch内部的collate_fn

也就是说,我们如果自己要指定这个函数,collate_fn应该定义成下面这个样子。

def my_collate(batch):#batch上面说过,是dataloader传进来的。***#你自己定义怎么整理数据。下面会说。real_batch=***return real_batch

那么pytorch内部默认的collate_fn函数长什么样呢?我们先观察下面的例子:

it=iter(dataloader)
nex=next(it)#我们展示第一个batch经过collate_fn之后的输出结果
print(nex)

tensor([[0.5700, 0.7266, 0.3706],
[0.3404, 0.1393, 0.7103]], dtype=torch.float64)

上面这个返回的结果就是real_batch。也就是collate_fn函数的返回值!!也就是说collate_fn将batch变成了上面的real_batch。

我们重新写一遍

batch:
[array([0.56998216, 0.72663738, 0.3706266 ]),
array([0.3403586 , 0.13931333, 0.71030221])]
real_batch:
tensor([[0.5700, 0.7266, 0.3706],
[0.3404, 0.1393, 0.7103]], dtype=torch.float64)

将batch变成上述real_batch很容易呀,就是把一个列表,变成了矩阵,我们也会!我们下面就来自己写一个collate_fn实现这个功能。

def my_collate(batch):real_batch=np.array(batch)real_batch=torch.from_numpy(real_batch)return real_batch
dataloader2=tud.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
it=iter(dataloader2)
nex=next(it)#我们展示第一个batch经过collate_fn之后的输出结果
print(nex)

tensor([[0.5700, 0.7266, 0.3706],
[0.3404, 0.1393, 0.7103]], dtype=torch.float64)

这不就和默认的collate_fn的输出结果一样了嘛!

应用情形

通常,我们并不需要使用这个函数,因为pytorch内部有一个默认的。但是,如果你的数据不规整,使用默认的会报错。例如下面的数据。
假设我们还是4个输入,但是维度不固定的。之前我们是每一个数据的维度都为3。

a=[[1,2],[3,4,5],[1],[3,4,9]]
dataset=mydataset(a,b)
dataloader=tud.DataLoader(dataset,batch_size=2)
it=iter(dataloader)
nex=next(it)
nex

使用默认的collate_fn,直接报错,要求相同维度。

这个时候,我们可以使用自己的collate_fn,避免报错。

不过话说回来,我个人感受是:

在这里避免报错好像也没有什么用,因为大多数的神经网络都是定长输入的,而且很多的操作也要求相同维度才能相加或相乘,所以:这里不报错,后面还是报错。如果后面解决这个问题的方法是:在不足维度上进行补0操作,那么我们为什么不在建立dataset之前先补好呢?所以,collate_fn这个东西的应用场景还是有限的。不过,明白其原理总是好事。


完结撒花


pytorch之深入理解collate_fn相关推荐

  1. 利用pytorch来深入理解CELoss、BCELoss和NLLLoss之间的关系

    利用pytorch来深入理解CELoss.BCELoss和NLLLoss之间的关系 损失函数为为计算预测值与真实值之间差异的函数,损失函数越小,预测值与真实值间的差异越小,证明网络效果越好.对于神经网 ...

  2. 基于pytorch实现图像分类——理解自动求导、计算图、静态图、动态图、pytorch入门

    1. pytorch入门 什么是PYTORCH? 这是一个基于Python的科学计算软件包,针对两组受众: 替代NumPy以使用GPU的功能 提供最大灵活性和速度的深度学习研究平台 1.1 开发环境 ...

  3. Pytorch之Dataloader参数collate_fn研究

    前言 之前看了不到pytorch代码,对Dataloader的大部分参数都比较了解,今天看代码时,发现了一个参数collate_fn ,之前论文代码没怎么见过,也就自动忽略了,今天既然遇到了,就突然来 ...

  4. Pytorch中如何理解RNN LSTM GRU的input(重点理解seq_len / time_steps)

    在建立时序模型时,若使用keras,我们在Input的时候就会在shape内设置好sequence_length(后面简称seq_len),接着便可以在自定义的data_generator内进行个性化 ...

  5. pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

    文章目录 前言 1. reshape() 2. view() ① 1 阶变高阶 1 阶变 2 阶 1 阶变 3 阶 1 阶变 4 阶 1 阶变 m 阶 ② 2 阶变 m 阶 ③ 3 阶变 m 阶 ④ ...

  6. PyTorch之深入理解list、ModuleList和Sequential。

    文章目录 list Sequential ModuleList 总结 import torch import torch.nn as nn list class a(nn.Module):def __ ...

  7. Pytorch之深入理解torch.nn.Parameter()

    先看一段代码: import torch import torch.nn as nn a=torch.tensor([1,2],dtype=torch.float32) print(a) print( ...

  8. lstm 输入数据维度_理解Pytorch中LSTM的输入输出参数含义

    本文不会介绍LSTM的原理,具体可看如下两篇文章 Understanding LSTM Networks DeepLearning.ai学习笔记(五)序列模型 -- week1 循环序列模型 1.举个 ...

  9. pytorch 深入理解 tensor.scatter_ ()用法

    pytorch 深入理解 tensor.scatter_ ()用法 在 pytorch 库下理解 torch.tensor.scatter()的用法.作者在网上搜索了很多方法,最后还是觉得自己写一篇更 ...

最新文章

  1. VS2015占内存大吗?_手游越来越占内存,80%的手机安装一个大游戏就满了,厂商肉搏...
  2. golang 切片排序
  3. 苹果cms10的php.ini目录列表,[苹果cmsV10]常见问题整理官方版
  4. 技术分享|单元测试推广与实战-在全新的DDD架构上进行单元测试
  5. gitee提交代码_git 版本控制,github和gitee
  6. java8 lambda表达式实现自定义用户组件,Don't Repeat Yourself
  7. java遮罩层_页面遮罩层 - javaalex的个人空间 - OSCHINA - 中文开源技术交流社区
  8. python基础知识-Python语言基础知识总结
  9. Mac Air USB接口 失效/不起作用 的修复方式
  10. 大连IT产业解析(1布局篇)
  11. mantis 邮件配置 linux,Linux系统 mantis 1.0.6工单系统配置安装
  12. macOS 13 Ventura系统自动开机在哪设置?
  13. 阿里大数据之路 总述
  14. 简单脱壳教程笔记(2)---手脱UPX壳(1)
  15. Vue学习笔记三(组件间通信)
  16. 差分数组分析详解+例题
  17. 连技术大拿都偷偷在用的偷懒神器Lombok
  18. Android“应用未安装”的解决办法
  19. 每日一句英语,看我能够坚持多久
  20. 10.4.3 编程实例-太阳系动画

热门文章

  1. 虎虎生威且看今朝 | 数据派优秀志愿者风采展
  2. 从引力波探测到RNA测序,AI如何加速科学发现
  3. 为什么 Pi 会出现在正态分布的方程中?
  4. 如何撰写好一篇论文?密歇根Andrew教授这篇《撰写高影响力论文指南》为你细致讲解论文写作,附视频与pdf...
  5. 学习人必看!空军老兵自学编程,仅隔一年成为国土安全部的数据库分析师
  6. 独家 | 2020年22个广泛使用的数据科学与机器学习工具(附链接)
  7. WAIC开发者日倒计时两天,收藏好这份完整日程
  8. 公示 | 首届中国智能心电大赛初赛结果
  9. 找到反例!博士后数学家推翻困扰数学界80多年的单位猜想
  10. 逆天了:Nature一篇论文57000位作者,更厉害的是,大多数作者都是游戏玩家