测试了pytorch的三种取样器用法。

一:概念

Sample:

取样器是在某一个数据集合上,按照某种策略进行取样。常见的策略包括顺序取样,随机取样(个样本等概率),随机取样(赋予个样本不同的概率)。以上三个策略都有放回和不放回两种方式。

TensorDataset:

对多个数据列表进行简单包装。就是用一个更大的list将多个不同类型的list数据进行简单包装。代码如下:

class TensorDataset(Dataset):r"""Dataset wrapping tensors.Each sample will be retrieved by indexing tensors along the first dimension.Arguments:*tensors (Tensor): tensors that have the same size of the first dimension."""def __init__(self, *tensors):assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)self.tensors = tensorsdef __getitem__(self, index):return tuple(tensor[index] for tensor in self.tensors)def __len__(self):return self.tensors[0].size(0)

二参数

1.SequentialSampler()

顺序采样,只有一个参数dataset。返回以一个与数据集等长的迭代器。

2.RandomSampler()

  • data_source (Dataset) – dataset to sample from
  • replacement (bool) – samples are drawn with replacement if True, default=``False`` 放回/不放回采样
  • num_samples  (python:int) – number of samples to draw, default=`len(dataset)`. This argument is supposed to be specified only when replacement is True. 采样的次数,采集几次即迭代器中有几个样本。

几个测试例子:

data1 = torch.tensor(list(range(20,50)),dtype=torch.long)
sample1 = RandomSampler(data1,replacement=False)
for i,d in enumerate(sample1):print(i,d)

运行结果:

0 27
1 25
2 8
3 20
4 28
5 12
6 26
7 18
8 13
9 21
10 9
11 22
12 17
13 6
14 0
15 7
16 14
17 24
18 10
19 19
20 2
21 29
22 16
23 5
24 3
25 11
26 1
27 4
28 15
29 23

如果采样不放回的采样,不能指定num_samples。也就是num_samples必须等于默认值len(dataset)。随机不放回的取完集合。

3. WeightedRandomSampler

  • weights (sequence) – a sequence of weights, not necessary summing up to one赋予每个样本权重。代表取到该样例的概率。数据不均衡时可以用来控制不同类别样本的采样权重
  • num_samples (python:int) – number of samples to draw
  • replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.

三 使用方式

1.返回值。

注意以上每个取样器返回的都是样本在dataset中的索引,并不是样本本身。

看上面的例子,dataset是数值为30-50的列表。而遍历sample时,返回的是0-30的数值,代表采样样例的索引

2.取样器。

每次执行遍历取样器,取样器就会按照设定的方式进行一次取样。

data = torch.tensor(list(range(20)),dtype=torch.long)
sample = WeightedRandomSampler(list(range(20)),num_samples=10,replacement=False)for _ in range(5):print("******************************************")for i,d in enumerate(sample):print(i,d)

误解:samle是取样一次后存的结果,是一个值固定的迭代器。每次遍历的值一样。

实际结果:

******************************************
0 18
1 9
2 14
3 17
4 15
5 7
6 3
7 16
8 11
9 10
******************************************
0 16
1 9
2 18
3 8
4 4
5 19
6 17
7 11
8 15
9 10
******************************************
0 16
1 5
2 17
3 15
4 10
5 13
6 11
7 18
8 14
9 8
******************************************
0 17
1 18
2 16
3 13
4 3
5 11
6 19
7 14
8 12
9 9
******************************************
0 5
1 16
2 19
3 9
4 10
5 7
6 11
7 12
8 13
9 4

多次遍历sample,每次的值各不相同。也就是说每一次遍历都进行了一次取样。

3.与DataLoader一起使用

data = torch.tensor(list(range(20)),dtype=torch.long)
sample = WeightedRandomSampler(list(range(20)),num_samples=10,replacement=False)
daloloader = torch.utils.data.DataLoader(data,batch_size=2,sampler=sample)
for i,d in enumerate(daloloader):print(i,d)

结果:

0 tensor([9, 6])
1 tensor([18, 16])
2 tensor([ 7, 15])
3 tensor([13, 12])
4 tensor([ 5, 11])

这里执行的流程是。

首先sample在数据集dataset上进行取样。dataloader按batch_size的大小每次读取批量的数据(采样后的数据)。

例如:数据集合一共有20个样例,首先用取样器取出10个样例。Dataloader会在取样后的数据每次读取批量的样例。如果batch_size是2。那么dataloader的len是5。

Pytorch各种取样器sample相关推荐

  1. python sample函数取样_Pytorch各种取样器sample

    测试了pytorch的三种取样器用法. 一:概念 Sample: 取样器是在某一个数据集合上,按照某种策略进行取样.常见的策略包括顺序取样,随机取样(个样本等概率),随机取样(赋予个样本不同的概率). ...

  2. 39_上下采样、MaxPool2d、AvgPool2d、ReLU案例、二维最大池化层和平均池化层、填充和步幅、多通道

    1.34.PyTorch Down/up sample (pytorch上下采样) 1.34.1. 首先介绍下采样 1.34.1.1. MaxPool2d案例 1.34.1.2. AvgPool2d案 ...

  3. 【JMeter】各种逻辑控制器(Logic Controller)

    文章目录 一.JMeter 逻辑控制器 二.逻辑控制器分类 1.简单控制器(Simple Controller) 2.循环控制器(Loop Controller) 3.仅一次控制器(Once Only ...

  4. Pytorch中的多项分布multinomial.Multinomial().sample()解析

    在读<动手学深度学习 Pytorch>,疑惑于: fair_probs = torch.ones([6]) / 6 multinomial.Multinomial(1, fair_prob ...

  5. 【P25】JMeter 取样器超时(Sample Timeout)

    文章目录 一.取样器超时(Sample Timeout)参数说明 二.准备工作 三.测试计划设计 一.取样器超时(Sample Timeout)参数说明 可以对采器设置最大超时时间 右键 >&g ...

  6. pytorch默认初始化_“最全PyTorch分布式教程”来了!

    前言 本文对使用pytorch进行分布式训练(单机多卡)的过程进行了详细的介绍,附加实际代码,希望可以给正在看的你提供帮助.本文分三个部分展开,分别是: 先验知识 使用过程框架 代码解析 若想学习分布 ...

  7. pytorch常用代码

    20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...

  8. dataframe sample 采样,抽样

    20220324 https://blog.csdn.net/DSTJWJW/article/details/90667570 不重复随机抽样 20211223 # 读取数据集 test_data_a ...

  9. 深度学总结:skip-gram pytorch实现

    文章目录 skip-gram pytorch 朴素实现 网络结构 训练过程:使用nn.NLLLoss() batch的准备,为unsupervised,准备数据获取(center,contex)的pa ...

最新文章

  1. 开发者,只有被裁,没有退休
  2. 通过神经图稳定对脑机接口的即插即用控制,四肢瘫痪患者可以轻松控制电脑光标...
  3. 解决mysql 1040错误Too many connections的方法
  4. 中国银屑病患者中银屑病关节炎的患病率和特征
  5. Struts2和Struts1的不同
  6. tomcat(6)生命周期
  7. php 将颜色透明度,css中如何使颜色透明度
  8. react入门(1)之阮一峰react教程
  9. CMD命令大全(已更新)
  10. 麦克纳姆轮平台坐标系说明
  11. OpenGL纹理贴图流程
  12. 范文杰 201421410010 作业2
  13. 自己动手实现神经网络分词模型
  14. 【转】python eval
  15. u盘linux运行速度慢,linux准确测量U盘读写速度
  16. 魔方解法 -- Rubic cube
  17. window11离线安装android子系统步骤
  18. 图形界面介绍Create Route Blockage
  19. 专注B2B跨境支付的背后,XTransfer的风控基础设施是如何炼成的?
  20. 安卓系统怎么连接服务器数据库,安卓端如何与服务器端数据库连接

热门文章

  1. lol澳洲服务器如何注册账号,LOL手游澳服怎么注册 云顶之弈手游澳服安装注册方法[多图]...
  2. 固定表头、打开excel for Mac后自动跳转到当前日期所在列并高亮显示
  3. 【新手教程】第一课:寻券记就什么
  4. 349两个数组的交集(遗留问题)
  5. 小学同学的儿子都开始上小学了,而我还在。。
  6. BricsCAD 19 for Mac(CAD建模软件)
  7. 计算机桌面死机的原因是,如果屏幕冻结,该怎么办?导致计算机死机的常见原因和解决方案....
  8. android向系统日历添加日程事件(实现闹铃效果,且在app被杀仍能完成)
  9. 苹果手机最傻Ⅹ的地方:
  10. 无剑100SOCwujian100挂UART外设之③硬件挂UART