关于WeightedRandomSampler的用法csdn上有一些很棒的博客。本文参考博客Pytorch样本比例不均衡时采用WeightedRandomSampler进行采样的代码对WeightedRandomSampler做进一步的分析。
首先从对官网给出的注释做进一步解释:

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
Parameters:
weights (sequence) – a sequence of weights, not necessary summing up to one(样本数量的倒数,如猫狗图片的如果有10张和20张,weights可设置为[0.67, 0.33])
num_samples (int) – number of samples to draw(还是用上述例子,此处为30)
replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.(后续例子重点讲这个参数。即是否重复取出样本以确保采样的均衡)
generator (Generator) – Generator used in sampling.(不用管)

接下来上代码:

import torch
from    torch.utils.data import DataLoader
from    torch.utils.data import WeightedRandomSampler# Create dummy data with class imbalance 99 to 1
class_counts = torch.tensor([10, 50, 60])
numDataPoints = class_counts.sum()
data_dim = 5
bs = 10
data = torch.randn(numDataPoints, data_dim)
for i in range(data.shape[0]):data[i, 0] = i#把样本的第一个值赋值为样本的行号,这样输出这个行号就知道样本是哪一类了#输出:  0~9行为第0类 ;10~59为第1类 ; 60~129行为第2类target = torch.cat((torch.zeros(class_counts[0], dtype=torch.long),torch.ones(class_counts[1], dtype=torch.long),torch.ones(class_counts[2], dtype=torch.long) * 2))print('target train 0/1/2: {}/{}/{}'.format((target == 0).sum(), (target == 1).sum(), (target == 2).sum()))# Compute samples weight (each sample should get its own weight)
class_sample_count = torch.tensor([(target == t).sum() for t in torch.unique(target, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in target])# Create sampler, dataset, loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
train_dataset = torch.utils.data.TensorDataset(data, target)
train_loader = DataLoader(train_dataset, batch_size=bs, num_workers=0, sampler=sampler)
# train_loader = DataLoader(
#     train_dataset, batch_size=bs, num_workers=0, shuffle=True)# Iterate DataLoader and check class balance for each batch
for i, (x, y) in enumerate(train_loader):print("batch index {}, 0/1/2: {}/{}/{}".format(i, (y == 0).sum(), (y == 1).sum(), (y == 2).sum()))x_n = [el[0].tolist() for el in x]print(sorted(x_n))

这个例子中,我们给第一类10个样本,第二类50个样本,第三类60个样本。同时我们给把每个样本的第0个值设为其所在tensor的行号,这样在Dataloader输出时就知道是输出的哪一类的样本了。
此时

replacement=True

上述代码运行结果如下:

target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 5/2/3
[1.0, 2.0, 4.0, 6.0, 8.0, 45.0, 45.0, 95.0, 101.0, 108.0]
batch index 1, 0/1/2: 1/5/4
[4.0, 30.0, 30.0, 39.0, 39.0, 54.0, 65.0, 67.0, 96.0, 97.0]
batch index 2, 0/1/2: 3/3/4
[4.0, 7.0, 9.0, 22.0, 43.0, 59.0, 90.0, 104.0, 111.0, 119.0]
batch index 3, 0/1/2: 4/3/3
[1.0, 3.0, 4.0, 8.0, 22.0, 49.0, 52.0, 64.0, 81.0, 89.0]
batch index 4, 0/1/2: 2/5/3
[3.0, 4.0, 10.0, 25.0, 42.0, 44.0, 49.0, 74.0, 100.0, 104.0]
batch index 5, 0/1/2: 1/5/4
[7.0, 22.0, 31.0, 32.0, 34.0, 37.0, 83.0, 113.0, 113.0, 115.0]
batch index 6, 0/1/2: 5/2/3
[2.0, 2.0, 3.0, 8.0, 9.0, 13.0, 20.0, 61.0, 75.0, 97.0]
batch index 7, 0/1/2: 3/3/4
[3.0, 7.0, 9.0, 31.0, 31.0, 38.0, 70.0, 71.0, 75.0, 91.0]
batch index 8, 0/1/2: 4/1/5
[2.0, 3.0, 4.0, 7.0, 23.0, 67.0, 71.0, 74.0, 95.0, 117.0]
batch index 9, 0/1/2: 3/3/4
[2.0, 3.0, 4.0, 17.0, 36.0, 56.0, 103.0, 104.0, 110.0, 115.0]
batch index 10, 0/1/2: 4/3/3
[3.0, 6.0, 7.0, 8.0, 18.0, 39.0, 43.0, 66.0, 95.0, 105.0]
batch index 11, 0/1/2: 2/1/7
[0.0, 2.0, 22.0, 77.0, 81.0, 100.0, 103.0, 103.0, 106.0, 111.0]

本例中样本数量差距比较悬殊,可以看到取出的每个batch样本数量尽可能接近均衡。同时对每一个batch而言,可能取出重复的样本,在不同的batch内,对于第一类,也被重复取出了。

接下来我们设置

replacement=False

运行结果如下:

target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 4/5/1
[5.0, 6.0, 7.0, 9.0, 16.0, 22.0, 28.0, 30.0, 35.0, 69.0]
batch index 1, 0/1/2: 2/4/4
[0.0, 1.0, 10.0, 20.0, 26.0, 38.0, 82.0, 90.0, 111.0, 119.0]
batch index 2, 0/1/2: 1/3/6
[4.0, 17.0, 19.0, 23.0, 60.0, 63.0, 64.0, 68.0, 79.0, 92.0]
batch index 3, 0/1/2: 2/2/6
[3.0, 8.0, 40.0, 48.0, 89.0, 96.0, 100.0, 102.0, 104.0, 115.0]
batch index 4, 0/1/2: 1/3/6
[2.0, 32.0, 37.0, 55.0, 84.0, 88.0, 94.0, 106.0, 109.0, 110.0]
batch index 5, 0/1/2: 0/5/5
[12.0, 25.0, 29.0, 33.0, 34.0, 72.0, 78.0, 98.0, 107.0, 113.0]
batch index 6, 0/1/2: 0/3/7
[42.0, 43.0, 49.0, 70.0, 83.0, 85.0, 93.0, 95.0, 97.0, 117.0]
batch index 7, 0/1/2: 0/5/5
[21.0, 27.0, 44.0, 57.0, 59.0, 67.0, 74.0, 80.0, 86.0, 114.0]
batch index 8, 0/1/2: 0/5/5
[15.0, 31.0, 47.0, 50.0, 53.0, 75.0, 77.0, 103.0, 105.0, 112.0]
batch index 9, 0/1/2: 0/4/6
[18.0, 45.0, 46.0, 54.0, 65.0, 71.0, 73.0, 76.0, 99.0, 108.0]
batch index 10, 0/1/2: 0/6/4
[11.0, 13.0, 14.0, 24.0, 51.0, 56.0, 61.0, 66.0, 87.0, 101.0]
batch index 11, 0/1/2: 0/5/5
[36.0, 39.0, 41.0, 52.0, 58.0, 62.0, 81.0, 91.0, 116.0, 118.0]

可以看到对同一个数据,只会取一次。无论在一个batch中还是在整个epoch中。由于不能重复取,在前几个batch中,样本还比较接近均衡,但是当第一类被取完了之后,就没有办法再取了。而第二类和第三类样本数量相差较小,因此在整体范围内接近均衡。

接下来我们不用WeightedRandomSampler,而是随机打乱样本,看看采样结果:

 train_loader = DataLoader(train_dataset, batch_size=bs, num_workers=0, shuffle=True)

此时的采样结果如下:

target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 0/5/5
[11.0, 17.0, 34.0, 46.0, 51.0, 67.0, 70.0, 78.0, 109.0, 111.0]
batch index 1, 0/1/2: 0/6/4
[10.0, 13.0, 19.0, 29.0, 38.0, 53.0, 66.0, 85.0, 88.0, 97.0]
batch index 2, 0/1/2: 1/4/5
[2.0, 21.0, 30.0, 41.0, 48.0, 65.0, 73.0, 81.0, 102.0, 116.0]
batch index 3, 0/1/2: 2/3/5
[4.0, 5.0, 14.0, 26.0, 35.0, 64.0, 89.0, 96.0, 98.0, 115.0]
batch index 4, 0/1/2: 2/2/6
[0.0, 3.0, 16.0, 52.0, 82.0, 83.0, 92.0, 106.0, 110.0, 113.0]
batch index 5, 0/1/2: 2/4/4
[6.0, 8.0, 18.0, 40.0, 57.0, 59.0, 63.0, 101.0, 108.0, 112.0]
batch index 6, 0/1/2: 0/6/4
[12.0, 27.0, 28.0, 37.0, 39.0, 56.0, 75.0, 77.0, 87.0, 103.0]
batch index 7, 0/1/2: 0/2/8
[15.0, 25.0, 61.0, 62.0, 76.0, 93.0, 99.0, 100.0, 104.0, 119.0]
batch index 8, 0/1/2: 3/4/3
[1.0, 7.0, 9.0, 36.0, 43.0, 45.0, 55.0, 71.0, 80.0, 86.0]
batch index 9, 0/1/2: 0/4/6
[22.0, 24.0, 44.0, 50.0, 69.0, 91.0, 105.0, 114.0, 117.0, 118.0]
batch index 10, 0/1/2: 0/5/5
[23.0, 31.0, 32.0, 42.0, 47.0, 60.0, 68.0, 74.0, 79.0, 107.0]
batch index 11, 0/1/2: 0/5/5
[20.0, 33.0, 49.0, 54.0, 58.0, 72.0, 84.0, 90.0, 94.0, 95.0]

可以看到样本被随机取出,在一个batch中,相比于replacement=False样本的不均衡程度更大。
在实际训练过程中,各位小伙伴可以根据自己需求,灵活调整参数(liandan)。

注:该方法多用于分类问题。即一个训练样本对应一个标签。对于分割问题,一个样本中有很多标签,用该方法就不太方便。分割问题推荐给损失函数添加权重,如nn.CrossEntropyLoss(weight=weight)。

torch.utils.data.WeightedRandomSampler样本不均衡情况下带权重随机采样相关推荐

  1. torch.utils.data.WeightedRandomSampler采样

    采样策略可分为以下情况: case0:Over sampling &Under sampling ,即对类别多的进行降采样,对类别少的进行重采样 case1: Over sampling 对类 ...

  2. pytorch源码解析2——数据处理torch.utils.data

    迭代器 理解 Python 的迭代器是解读 PyTorch 中 torch.utils.data 模块的关键. 在 Dataset, Sampler 和 DataLoader 这三个类中都会用到 py ...

  3. PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

    目录 0 前言 1 Dataset 1.1 Map-style dataset 1.2 Iterable-style dataset 1.3 其他 dataset 2 Sampler 3 DataLo ...

  4. PyTorch—torch.utils.data.DataLoader 数据加载类

    文章目录 DataLoader(object)类: _DataLoaderIter(object)类 __next__函数 pin_memory_batch() _get_batch函数 _proce ...

  5. pytorch torch.utils.data.TensorDataset

    应用 import torch import torch.utils.data as Datax = torch.linspace(1, 10, 10) y = torch.linspace(10, ...

  6. torch.utils.data.DataLoader 详解

    class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, ...

  7. 【Torch】Dataloader torch.utils.data.DataLoader全面详实概念理解

    目录 1.torch.utils.data.DataLoader概念介绍 2.torch.utils.data.DataLoader参数介绍 3 案例体会 DataLoader:[batch_size ...

  8. 阅读源码-理解torch.utils.data、torch.utils.data.Dataset、torch.utils.data.DataLoader的工作方式

    文章目录 目标 Dataset DataLoader 应用 Dataset DataLoader 测试 知识点 Python splitlines()方法 python filter()函数 暂时先写 ...

  9. 5 torch.utils.data (Dataset,TensorDataset,DataLoader)

    文章目录 一.DataLoader(数据预处理) 1.DataLoader :(构建可迭代的数据装载器) 2.输出:DataLoader 的输出包含:数据和标签 二.TensorDataset(数据预 ...

最新文章

  1. linux系统发送信号的系统调用是,linux系统编程之信号:信号发送函数sigqueue和信号安装函数sigaction...
  2. 逐飞 RT1064 库 GCC (VSCode) 移植踩坑
  3. php 子类名,php的继承方法获取子类名
  4. NB驱动建立MQTT连接和断开MQTT连接的代码实现
  5. Interview:算法岗位面试—10.15上午—上海某公司算法岗位(偏图像算法,制造行业)技术面试考点之AI算法与实际场景结合产生商业价值的头脑风暴
  6. 大咖聊数据,视频抢先看
  7. P4062 [Code+#1]Yazid 的新生舞会(区间绝对众数+分治/树状数组维护高维前缀和)
  8. TDD容易被忽略的五大前提
  9. vba 中sql like用法
  10. 【Android Developers Training】 8. 定义Action Bar风格
  11. 怎样调整input框背景颜色_还在用百度搜索PPT背景图?7个高大上的图片网站,个个都是高清免费无版权!...
  12. 专业的在线考试系统-快考题,支持自制题库/在线试卷答题
  13. 第一天:2个法则,你的第一桶金可以这么来
  14. 7.交易开拓者-公式进阶(一)
  15. 三月模拟题——炉石传说
  16. 文件或目录损坏且无法读取的解决办法
  17. kali WiFi密码破解分享
  18. 百菜不如白菜 娃娃菜更营养吗
  19. 安卓很抱歉已停止运行
  20. 小米路由器R4AC 小米路由器4A百兆版 原厂BootLoader和eeprom

热门文章

  1. 笔记本高分辨软件兼容问题,字体太小或模糊
  2. c语言mallor使用方法,温州医学院仁济临床医学概论选择题整理
  3. 海思芯片HI35xx NNIE踩坑录
  4. PDF拆分技巧——如何在线拆分PDF
  5. 操作系统之短作业优先实现代码
  6. mysql题目练习的答案
  7. GHUB LUA脚本 压枪脚本 推荐APEX用
  8. 多路视频数据实时采集系统设计与实现
  9. sap采购申请自动转采购订单_采购订单_参考第三方销售生成的采购申请
  10. Android Google Face API 增强现实教程