Pytorch各种取样器sample
测试了pytorch的三种取样器用法。
一:概念
Sample:
取样器是在某一个数据集合上,按照某种策略进行取样。常见的策略包括顺序取样,随机取样(个样本等概率),随机取样(赋予个样本不同的概率)。以上三个策略都有放回和不放回两种方式。
TensorDataset:
对多个数据列表进行简单包装。就是用一个更大的list将多个不同类型的list数据进行简单包装。代码如下:
class TensorDataset(Dataset):r"""Dataset wrapping tensors.Each sample will be retrieved by indexing tensors along the first dimension.Arguments:*tensors (Tensor): tensors that have the same size of the first dimension."""def __init__(self, *tensors):assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)self.tensors = tensorsdef __getitem__(self, index):return tuple(tensor[index] for tensor in self.tensors)def __len__(self):return self.tensors[0].size(0)
二参数
1.SequentialSampler()
顺序采样,只有一个参数dataset。返回以一个与数据集等长的迭代器。
2.RandomSampler()
- data_source (Dataset) – dataset to sample from
- replacement (bool) – samples are drawn with replacement if True, default=``False`` 放回/不放回采样
- num_samples (python:int) – number of samples to draw, default=`len(dataset)`. This argument is supposed to be specified only when replacement is True. 采样的次数,采集几次即迭代器中有几个样本。
几个测试例子:
data1 = torch.tensor(list(range(20,50)),dtype=torch.long)
sample1 = RandomSampler(data1,replacement=False)
for i,d in enumerate(sample1):print(i,d)
运行结果:
0 27
1 25
2 8
3 20
4 28
5 12
6 26
7 18
8 13
9 21
10 9
11 22
12 17
13 6
14 0
15 7
16 14
17 24
18 10
19 19
20 2
21 29
22 16
23 5
24 3
25 11
26 1
27 4
28 15
29 23
如果采样不放回的采样,不能指定num_samples。也就是num_samples必须等于默认值len(dataset)。随机不放回的取完集合。
3. WeightedRandomSampler
- weights (sequence) – a sequence of weights, not necessary summing up to one赋予每个样本权重。代表取到该样例的概率。数据不均衡时可以用来控制不同类别样本的采样权重
- num_samples (python:int) – number of samples to draw
- replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
三 使用方式
1.返回值。
注意以上每个取样器返回的都是样本在dataset中的索引,并不是样本本身。
看上面的例子,dataset是数值为30-50的列表。而遍历sample时,返回的是0-30的数值,代表采样样例的索引
2.取样器。
每次执行遍历取样器,取样器就会按照设定的方式进行一次取样。
data = torch.tensor(list(range(20)),dtype=torch.long)
sample = WeightedRandomSampler(list(range(20)),num_samples=10,replacement=False)for _ in range(5):print("******************************************")for i,d in enumerate(sample):print(i,d)
误解:samle是取样一次后存的结果,是一个值固定的迭代器。每次遍历的值一样。
实际结果:
******************************************
0 18
1 9
2 14
3 17
4 15
5 7
6 3
7 16
8 11
9 10
******************************************
0 16
1 9
2 18
3 8
4 4
5 19
6 17
7 11
8 15
9 10
******************************************
0 16
1 5
2 17
3 15
4 10
5 13
6 11
7 18
8 14
9 8
******************************************
0 17
1 18
2 16
3 13
4 3
5 11
6 19
7 14
8 12
9 9
******************************************
0 5
1 16
2 19
3 9
4 10
5 7
6 11
7 12
8 13
9 4
多次遍历sample,每次的值各不相同。也就是说每一次遍历都进行了一次取样。
3.与DataLoader一起使用
data = torch.tensor(list(range(20)),dtype=torch.long)
sample = WeightedRandomSampler(list(range(20)),num_samples=10,replacement=False)
daloloader = torch.utils.data.DataLoader(data,batch_size=2,sampler=sample)
for i,d in enumerate(daloloader):print(i,d)
结果:
0 tensor([9, 6])
1 tensor([18, 16])
2 tensor([ 7, 15])
3 tensor([13, 12])
4 tensor([ 5, 11])
这里执行的流程是。
首先sample在数据集dataset上进行取样。dataloader按batch_size的大小每次读取批量的数据(采样后的数据)。
例如:数据集合一共有20个样例,首先用取样器取出10个样例。Dataloader会在取样后的数据每次读取批量的样例。如果batch_size是2。那么dataloader的len是5。
Pytorch各种取样器sample相关推荐
- python sample函数取样_Pytorch各种取样器sample
测试了pytorch的三种取样器用法. 一:概念 Sample: 取样器是在某一个数据集合上,按照某种策略进行取样.常见的策略包括顺序取样,随机取样(个样本等概率),随机取样(赋予个样本不同的概率). ...
- 39_上下采样、MaxPool2d、AvgPool2d、ReLU案例、二维最大池化层和平均池化层、填充和步幅、多通道
1.34.PyTorch Down/up sample (pytorch上下采样) 1.34.1. 首先介绍下采样 1.34.1.1. MaxPool2d案例 1.34.1.2. AvgPool2d案 ...
- 【JMeter】各种逻辑控制器(Logic Controller)
文章目录 一.JMeter 逻辑控制器 二.逻辑控制器分类 1.简单控制器(Simple Controller) 2.循环控制器(Loop Controller) 3.仅一次控制器(Once Only ...
- Pytorch中的多项分布multinomial.Multinomial().sample()解析
在读<动手学深度学习 Pytorch>,疑惑于: fair_probs = torch.ones([6]) / 6 multinomial.Multinomial(1, fair_prob ...
- 【P25】JMeter 取样器超时(Sample Timeout)
文章目录 一.取样器超时(Sample Timeout)参数说明 二.准备工作 三.测试计划设计 一.取样器超时(Sample Timeout)参数说明 可以对采器设置最大超时时间 右键 >&g ...
- pytorch默认初始化_“最全PyTorch分布式教程”来了!
前言 本文对使用pytorch进行分布式训练(单机多卡)的过程进行了详细的介绍,附加实际代码,希望可以给正在看的你提供帮助.本文分三个部分展开,分别是: 先验知识 使用过程框架 代码解析 若想学习分布 ...
- pytorch常用代码
20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...
- dataframe sample 采样,抽样
20220324 https://blog.csdn.net/DSTJWJW/article/details/90667570 不重复随机抽样 20211223 # 读取数据集 test_data_a ...
- 深度学总结:skip-gram pytorch实现
文章目录 skip-gram pytorch 朴素实现 网络结构 训练过程:使用nn.NLLLoss() batch的准备,为unsupervised,准备数据获取(center,contex)的pa ...
最新文章
- 开发者,只有被裁,没有退休
- 通过神经图稳定对脑机接口的即插即用控制,四肢瘫痪患者可以轻松控制电脑光标...
- 解决mysql 1040错误Too many connections的方法
- 中国银屑病患者中银屑病关节炎的患病率和特征
- Struts2和Struts1的不同
- tomcat(6)生命周期
- php 将颜色透明度,css中如何使颜色透明度
- react入门(1)之阮一峰react教程
- CMD命令大全(已更新)
- 麦克纳姆轮平台坐标系说明
- OpenGL纹理贴图流程
- 范文杰 201421410010 作业2
- 自己动手实现神经网络分词模型
- 【转】python eval
- u盘linux运行速度慢,linux准确测量U盘读写速度
- 魔方解法 -- Rubic cube
- window11离线安装android子系统步骤
- 图形界面介绍Create Route Blockage
- 专注B2B跨境支付的背后,XTransfer的风控基础设施是如何炼成的?
- 安卓系统怎么连接服务器数据库,安卓端如何与服务器端数据库连接
热门文章
- lol澳洲服务器如何注册账号,LOL手游澳服怎么注册 云顶之弈手游澳服安装注册方法[多图]...
- 固定表头、打开excel for Mac后自动跳转到当前日期所在列并高亮显示
- 【新手教程】第一课:寻券记就什么
- 349两个数组的交集(遗留问题)
- 小学同学的儿子都开始上小学了,而我还在。。
- BricsCAD 19 for Mac(CAD建模软件)
- 计算机桌面死机的原因是,如果屏幕冻结,该怎么办?导致计算机死机的常见原因和解决方案....
- android向系统日历添加日程事件(实现闹铃效果,且在app被杀仍能完成)
- 苹果手机最傻Ⅹ的地方:
- 无剑100SOCwujian100挂UART外设之③硬件挂UART