DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(

dataset,

batch_size=1,

shuffle=False,

sampler=None,

batch_sampler=None,

num_workers=0,

collate_fn=,

pin_memory=False,

drop_last=False,

timeout=0,

worker_init_fn=None)

DataLoader在数据集上提供单进程或多进程的迭代器

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

一个测试的例子

import torch

import torch.utils.data as Data

import numpy as np

test = np.array([0,1,2,3,4,5,6,7,8,9,10,11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))

target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing,target)

batch = 3

loader = Data.DataLoader(

dataset=torch_dataset,

batch_size=batch, # 批大小

# 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少

collate_fn=lambda x:(

torch.cat(

[x[i][j].unsqueeze(0) for i in range(len(x))], 0

).unsqueeze(0) for j in range(len(x[0]))

)

)

for (i,j) in loader:

print(i)

print(j)

输出结果:

tensor([[[ 0, 1, 2],

[ 1, 2, 3],

[ 2, 3, 4]]], dtype=torch.int32)

tensor([[[ 0],

[ 1],

[ 2]]], dtype=torch.int32)

tensor([[[ 3, 4, 5],

[ 4, 5, 6],

[ 5, 6, 7]]], dtype=torch.int32)

tensor([[[ 3],

[ 4],

[ 5]]], dtype=torch.int32)

tensor([[[ 6, 7, 8],

[ 7, 8, 9],

[ 8, 9, 10]]], dtype=torch.int32)

tensor([[[ 6],

[ 7],

[ 8]]], dtype=torch.int32)

tensor([[[ 9, 10, 11]]], dtype=torch.int32)

tensor([[[ 9]]], dtype=torch.int32)

如果不要collate_fn的值,输出变成

tensor([[ 0, 1, 2],

[ 1, 2, 3],

[ 2, 3, 4]], dtype=torch.int32)

tensor([[ 0],

[ 1],

[ 2]], dtype=torch.int32)

tensor([[ 3, 4, 5],

[ 4, 5, 6],

[ 5, 6, 7]], dtype=torch.int32)

tensor([[ 3],

[ 4],

[ 5]], dtype=torch.int32)

tensor([[ 6, 7, 8],

[ 7, 8, 9],

[ 8, 9, 10]], dtype=torch.int32)

tensor([[ 6],

[ 7],

[ 8]], dtype=torch.int32)

tensor([[ 9, 10, 11]], dtype=torch.int32)

tensor([[ 9]], dtype=torch.int32)

所以collate_fn就是使结果多一维。

看看collate_fn的值是什么意思。我们把它改为如下

collate_fn=lambda x:x

并输出

for i in loader:

print(i)

得到结果

[(tensor([ 0, 1, 2], dtype=torch.int32), tensor([ 0], dtype=torch.int32)), (tensor([ 1, 2, 3], dtype=torch.int32), tensor([ 1], dtype=torch.int32)), (tensor([ 2, 3, 4], dtype=torch.int32), tensor([ 2], dtype=torch.int32))]

[(tensor([ 3, 4, 5], dtype=torch.int32), tensor([ 3], dtype=torch.int32)), (tensor([ 4, 5, 6], dtype=torch.int32), tensor([ 4], dtype=torch.int32)), (tensor([ 5, 6, 7], dtype=torch.int32), tensor([ 5], dtype=torch.int32))]

[(tensor([ 6, 7, 8], dtype=torch.int32), tensor([ 6], dtype=torch.int32)), (tensor([ 7, 8, 9], dtype=torch.int32), tensor([ 7], dtype=torch.int32)), (tensor([ 8, 9, 10], dtype=torch.int32), tensor([ 8], dtype=torch.int32))]

[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([ 9], dtype=torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

collate_fn=lambda x:(

torch.cat(

[x[i][j].unsqueeze(0) for i in range(len(x))], 0

).unsqueeze(0) for j in range(len(x[0]))

)

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成。

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

python中fn的用法_Pytorch技巧:DataLoader的collate_fn参数使用详解相关推荐

  1. python中sleep函数用法_sleep函数函数介绍与使用方法详解

    在一些竞猜的网站中,如果我们需要做一个定时执行的功能,比如有一道题,在十秒之内要完成,否则显示"您已超时",如果完成,则跳转到下一道题上面,而这中间有一个十秒的停顿,这样的功能是怎 ...

  2. 站长在线Python精讲:在Python中使用split()方法分割、使用join()方法合并字符串详解

    欢迎你来到站长在线的站长学堂学习Python知识,本文学习的是<在Python中使用split()方法分割.使用join()方法合并字符串详解>.本知识点主要内容有:在Python中使用s ...

  3. python中continue语句的作用_Pythoncontinue语句有什么作用?详解Pythoncontinue语句的用法...

    本文主要介绍python语句,Python continue 语句跳出本次循环,而break跳出整个循环.continue 语句用来告诉Python跳过当前循环的剩余语句,然后继续进行下一轮循环.co ...

  4. python中str是什么_python的str()字符串类型的方法详解

    字符串一旦创建,不可修改,一旦修改或者拼接,都会造成重新生成字符串,因为内存存数据是一个挨着一个存的,如果增加一个字符串的话,之前的老位置只有一个地方,不够,这是原理性的东西,在其他语言里面也一样 7 ...

  5. python中int转换为时间戳_python日期和时间戳互相转化操作详解

    Python中日期格式化是非常常见的操作,Python 中能用很多方式处理日期和时间,转换日期格式是一个常见的功能.Python 提供了一个 time 和 calendar 模块可以用于格式化日期和时 ...

  6. python中true是什么意思_Python解惑之True和False详解

    前言 众所周知在Python 中常用的数据类型bool(布尔)类型的实例对象(值)就两个,真和假,分别用True和False表示.在if 条件判断和while 语句中经常用到,不过在Python2.x ...

  7. python中pass语句的作用是_Python pass语句以及作用详解

    在具体开发设计中,有时大家会先构建起程序流程的总体逻辑结构,可是临时不去完成一些细节,只是在这种地区加一些注释,层面之后再加上编码,请看下面的事例: age = int( input("输入 ...

  8. python中4j什么意思_Python学习:4.数据类型以及运算符详解

    运算符 一.算数运算: 二.比较运算: 三.赋值运算 四.逻辑运算 五.成员运算 基本数据类型 一.Number(数字) Python3中支持int.float.bool.complex. 使用内置的 ...

  9. python中的complex是什么意思_Python 内置函数complex详解,pythoncomplex

    Python 内置函数complex详解,pythoncomplex 英文文档: class complex([real[, imag]]) Return a complex number with ...

最新文章

  1. 信号量的实现和应用实验报告_Java高级编程基础:原子信号量操作实现组线程执行管理...
  2. 结构体转char[]
  3. 不要再代码里频繁的new和delete
  4. pdf.js浏览中文pdf乱码的问题解决
  5. LaTeX的安装教程及问题记录
  6. Java基础 选择语句,循环结构数组
  7. python 生成式 生成器
  8. 微信小程序获取access_token报错errcode: 40125,errmsg: invalid appsecret
  9. git强制拉取最新代码
  10. 手把手系列--STM32H750移植FreeRTOS(二)--优化编译速度
  11. 数据挖掘导论实验报告01
  12. CodeForces - 106C Buns (多重背包二进制优化)
  13. 新闻管理系统(四)封装news表相关
  14. 数据结构:新冠病毒检测
  15. Python基础语法学习6
  16. 异步下载小说《诡秘之主》
  17. REVV Racing 联手 SuperPlastic,为您带来 Chunder 迷宫锦标赛
  18. 手握国企offer,33岁程序员不按常理出牌,网友炸了!
  19. antdpro使用AbortController取消请求
  20. 解决WARN: Establishing SSL connection without server‘s identity verification is not recommended警告问题~

热门文章

  1. java web中炸包,Javaweb出来炸到---HTML
  2. python将MP3转wave转成numpy
  3. C#中各种数据类型转换的方法的类
  4. 调查:台湾上班族讨厌5种年会状况 最怕老板致词长
  5. 【FastDev4Android框架开发】RecyclerView完全解析之下拉刷新与上拉加载SwipeRefreshLayout(三十一)...
  6. JS~字符串长度判断,超出进行自动截取(支持中文)
  7. Tomcat常见问题 (配置)及解决方法
  8. 基于JavaEE实现网上拍卖系统
  9. 大数据笔记2019.5.7
  10. 爬取jd商城手机类商品图片