torch.utils.data.WeightedRandomSampler: weighted random sampling for imbalanced datasets
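There are already some very good CSDN blog posts on how to use WeightedRandomSampler. This post builds on the code from the blog post "Pytorch样本比例不均衡时采用WeightedRandomSampler进行采样" ("Sampling with WeightedRandomSampler when PyTorch class proportions are imbalanced") and analyzes WeightedRandomSampler in more depth.
First, some further explanation of the official documentation: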
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
Parameters:
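weights (sequence) – a sequence of weights, not necessarily summing up to one. (There is one weight per sample, not per class. A common choice is the reciprocal of the size of the sample's class. For example, with 10 cat images and 20 dog images, each cat image gets weight 1/10 and each dog image gets 1/20. weights then has 30 entries, and the two classes are weighted 2:1, i.e. roughly 0.67 : 0.33.)
num_samples (int) – number of samples to draw. (In the same example, this would be 30, which keeps the number of draws per epoch equal to the dataset size.)
replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row. (The examples below focus on this parameter, which controls whether a sample may be drawn more than once so that the classes come out balanced.)
generator (Generator) – Generator used in sampling. (It can be ignored here. It only matters if you need reproducible sampling.)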
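The following minimal sketch makes the per-sample requirement concrete for the cat/dog example. It assumes hypothetical integer labels (0 = cat, 1 = dog) stored in a tensor named `labels`. The names `class_weight` and `sample_weight` are illustrative. The sketch also shows what goes wrong if the two per-class values are passed directly as `weights`.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical labels for the cat/dog example: 10 cats (label 0), 20 dogs (label 1)
labels = torch.tensor([0] * 10 + [1] * 20)

# Per-class weight = 1 / class count -> [0.1, 0.05]
class_weight = 1.0 / torch.bincount(labels).float()

# One weight per sample: look up each sample's class weight (30 entries)
sample_weight = class_weight[labels]

# num_samples = 30: one epoch draws as many indices as there are samples
sampler = WeightedRandomSampler(sample_weight, num_samples=len(sample_weight), replacement=True)
indices = list(sampler)
print(len(indices))  # 30

# Pitfall: a per-class list such as [0.67, 0.33] is interpreted as weights
# for only two samples, so only indices 0 and 1 can ever be drawn
wrong = WeightedRandomSampler([0.67, 0.33], num_samples=30, replacement=True)
print(set(wrong) <= {0, 1})  # True
```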
Now the code:
```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Create dummy data with a 10 / 50 / 60 class imbalance
class_counts = torch.tensor([10, 50, 60])
num_data_points = int(class_counts.sum())
data_dim = 5
bs = 10
data = torch.randn(num_data_points, data_dim)
for i in range(data.shape[0]):
    # Store the row index in the first feature, so printing it reveals the class:
    # rows 0-9 are class 0, rows 10-59 class 1, rows 60-119 class 2
    data[i, 0] = i
target = torch.cat((
    torch.zeros(int(class_counts[0]), dtype=torch.long),
    torch.ones(int(class_counts[1]), dtype=torch.long),
    torch.ones(int(class_counts[2]), dtype=torch.long) * 2,
))
print('target train 0/1/2: {}/{}/{}'.format(
    (target == 0).sum(), (target == 1).sum(), (target == 2).sum()))

# Compute sample weights: each sample gets its own weight, 1 / (size of its class)
class_sample_count = torch.bincount(target)
weight = 1. / class_sample_count.float()
samples_weight = weight[target]

# Create sampler, dataset, loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
train_dataset = TensorDataset(data, target)
train_loader = DataLoader(train_dataset, batch_size=bs, num_workers=0, sampler=sampler)
# Plain shuffling instead of the sampler (used in the last experiment below):
# train_loader = DataLoader(train_dataset, batch_size=bs, num_workers=0, shuffle=True)

# Iterate the DataLoader and check the class balance of each batch
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1/2: {}/{}/{}".format(
        i, (y == 0).sum(), (y == 1).sum(), (y == 2).sum()))
    x_n = [el[0].tolist() for el in x]  # row indices of the samples in this batch
    print(sorted(x_n))
```
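In this example, class 0 has 10 samples, class 1 has 50 and class 2 has 60. The first feature of every sample is set to its row index in the tensor. This means that whenever the DataLoader returns a sample, we can tell which class it belongs to.
With replacement=True, the code above prints: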
target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 5/2/3
[1.0, 2.0, 4.0, 6.0, 8.0, 45.0, 45.0, 95.0, 101.0, 108.0]
batch index 1, 0/1/2: 1/5/4
[4.0, 30.0, 30.0, 39.0, 39.0, 54.0, 65.0, 67.0, 96.0, 97.0]
batch index 2, 0/1/2: 3/3/4
[4.0, 7.0, 9.0, 22.0, 43.0, 59.0, 90.0, 104.0, 111.0, 119.0]
batch index 3, 0/1/2: 4/3/3
[1.0, 3.0, 4.0, 8.0, 22.0, 49.0, 52.0, 64.0, 81.0, 89.0]
batch index 4, 0/1/2: 2/5/3
[3.0, 4.0, 10.0, 25.0, 42.0, 44.0, 49.0, 74.0, 100.0, 104.0]
batch index 5, 0/1/2: 1/5/4
[7.0, 22.0, 31.0, 32.0, 34.0, 37.0, 83.0, 113.0, 113.0, 115.0]
batch index 6, 0/1/2: 5/2/3
[2.0, 2.0, 3.0, 8.0, 9.0, 13.0, 20.0, 61.0, 75.0, 97.0]
batch index 7, 0/1/2: 3/3/4
[3.0, 7.0, 9.0, 31.0, 31.0, 38.0, 70.0, 71.0, 75.0, 91.0]
batch index 8, 0/1/2: 4/1/5
[2.0, 3.0, 4.0, 7.0, 23.0, 67.0, 71.0, 74.0, 95.0, 117.0]
batch index 9, 0/1/2: 3/3/4
[2.0, 3.0, 4.0, 17.0, 36.0, 56.0, 103.0, 104.0, 110.0, 115.0]
batch index 10, 0/1/2: 4/3/3
[3.0, 6.0, 7.0, 8.0, 18.0, 39.0, 43.0, 66.0, 95.0, 105.0]
batch index 11, 0/1/2: 2/1/7
[0.0, 2.0, 22.0, 77.0, 81.0, 100.0, 103.0, 103.0, 106.0, 111.0]
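The class sizes in this example are very unequal, yet each batch comes out roughly balanced. This is because every class receives the same total weight (10 × 1/10 = 50 × 1/50 = 60 × 1/60 = 1), so each draw picks each class with probability about 1/3. Individual batches still fluctuate; for example, batch 11 is 2/1/7. Because sampling is with replacement, the same sample can appear twice in one batch (e.g. 45 in batch 0). The class-0 samples are also drawn again and again across batches: in this run, 37 of the 120 draws come from class 0, which only has 10 samples.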
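The "equal total weight per class" argument can be checked numerically. This sketch rebuilds the same 10/50/60 labels (variable names are illustrative) and sums the weights per class with `torch.bincount(..., weights=...)`. It then draws many indices from a seeded `WeightedRandomSampler` and compares the empirical class frequencies with 1/3. The number of draws and the seed are arbitrary choices.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Rebuild the 10 / 50 / 60 labels and the per-sample weights 1 / (class size)
target = torch.cat([torch.full((n,), c, dtype=torch.long)
                    for c, n in enumerate([10, 50, 60])])
samples_weight = (1.0 / torch.bincount(target).float())[target]

# Total weight per class: 10 * 1/10, 50 * 1/50, 60 * 1/60 -> 1.0 each
class_total = torch.bincount(target, weights=samples_weight)
share = class_total / class_total.sum()
print(share)  # roughly 1/3 per class

# Empirical check: many draws with replacement from a seeded sampler
sampler = WeightedRandomSampler(samples_weight, 120_000, replacement=True,
                                generator=torch.Generator().manual_seed(0))
draws = torch.tensor(list(sampler))
freq = torch.bincount(target[draws]).float() / draws.numel()
print(freq)  # each entry close to 0.333
```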
Next, we set replacement=False.
The output is:
target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 4/5/1
[5.0, 6.0, 7.0, 9.0, 16.0, 22.0, 28.0, 30.0, 35.0, 69.0]
batch index 1, 0/1/2: 2/4/4
[0.0, 1.0, 10.0, 20.0, 26.0, 38.0, 82.0, 90.0, 111.0, 119.0]
batch index 2, 0/1/2: 1/3/6
[4.0, 17.0, 19.0, 23.0, 60.0, 63.0, 64.0, 68.0, 79.0, 92.0]
batch index 3, 0/1/2: 2/2/6
[3.0, 8.0, 40.0, 48.0, 89.0, 96.0, 100.0, 102.0, 104.0, 115.0]
batch index 4, 0/1/2: 1/3/6
[2.0, 32.0, 37.0, 55.0, 84.0, 88.0, 94.0, 106.0, 109.0, 110.0]
batch index 5, 0/1/2: 0/5/5
[12.0, 25.0, 29.0, 33.0, 34.0, 72.0, 78.0, 98.0, 107.0, 113.0]
batch index 6, 0/1/2: 0/3/7
[42.0, 43.0, 49.0, 70.0, 83.0, 85.0, 93.0, 95.0, 97.0, 117.0]
batch index 7, 0/1/2: 0/5/5
[21.0, 27.0, 44.0, 57.0, 59.0, 67.0, 74.0, 80.0, 86.0, 114.0]
batch index 8, 0/1/2: 0/5/5
[15.0, 31.0, 47.0, 50.0, 53.0, 75.0, 77.0, 103.0, 105.0, 112.0]
batch index 9, 0/1/2: 0/4/6
[18.0, 45.0, 46.0, 54.0, 65.0, 71.0, 73.0, 76.0, 99.0, 108.0]
batch index 10, 0/1/2: 0/6/4
[11.0, 13.0, 14.0, 24.0, 51.0, 56.0, 61.0, 66.0, 87.0, 101.0]
batch index 11, 0/1/2: 0/5/5
[36.0, 39.0, 41.0, 52.0, 58.0, 62.0, 81.0, 91.0, 116.0, 118.0]
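Now each sample is drawn exactly once, both within a batch and across the whole epoch. The early batches are still fairly balanced because the high-weight class-0 samples tend to be drawn first. Once class 0 is used up (after batch 4), however, it cannot be drawn again, so the remaining batches contain only classes 1 and 2. Their sizes are close, so those batches still look roughly balanced between the two. Note that over the full epoch, the class counts are exactly the original 10/50/60. With replacement=False and num_samples equal to the dataset size, the weights only change the order in which samples appear, not how often they appear.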
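A quick way to confirm the permutation property is shown in this sketch. It uses the same hypothetical 10/50/60 labels and a seeded generator (the seed is arbitrary).

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Same 10 / 50 / 60 labels and per-sample weights 1 / (class size)
target = torch.cat([torch.full((n,), c, dtype=torch.long)
                    for c, n in enumerate([10, 50, 60])])
samples_weight = (1.0 / torch.bincount(target).float())[target]

# num_samples equals the dataset size, drawn without replacement
sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=False,
                                generator=torch.Generator().manual_seed(0))
indices = list(sampler)

# Every index appears exactly once: the epoch is a weight-biased permutation
print(sorted(indices) == list(range(len(target))))  # True
# Hence the class counts over the epoch are exactly the original ones
print(torch.bincount(target[torch.tensor(indices)]).tolist())  # [10, 50, 60]
```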
Next, instead of using WeightedRandomSampler, we simply shuffle the samples and look at the result:
train_loader = DataLoader(train_dataset, batch_size=bs, num_workers=0, shuffle=True)
The sampling result is:
target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 0/5/5
[11.0, 17.0, 34.0, 46.0, 51.0, 67.0, 70.0, 78.0, 109.0, 111.0]
batch index 1, 0/1/2: 0/6/4
[10.0, 13.0, 19.0, 29.0, 38.0, 53.0, 66.0, 85.0, 88.0, 97.0]
batch index 2, 0/1/2: 1/4/5
[2.0, 21.0, 30.0, 41.0, 48.0, 65.0, 73.0, 81.0, 102.0, 116.0]
batch index 3, 0/1/2: 2/3/5
[4.0, 5.0, 14.0, 26.0, 35.0, 64.0, 89.0, 96.0, 98.0, 115.0]
batch index 4, 0/1/2: 2/2/6
[0.0, 3.0, 16.0, 52.0, 82.0, 83.0, 92.0, 106.0, 110.0, 113.0]
batch index 5, 0/1/2: 2/4/4
[6.0, 8.0, 18.0, 40.0, 57.0, 59.0, 63.0, 101.0, 108.0, 112.0]
batch index 6, 0/1/2: 0/6/4
[12.0, 27.0, 28.0, 37.0, 39.0, 56.0, 75.0, 77.0, 87.0, 103.0]
batch index 7, 0/1/2: 0/2/8
[15.0, 25.0, 61.0, 62.0, 76.0, 93.0, 99.0, 100.0, 104.0, 119.0]
batch index 8, 0/1/2: 3/4/3
[1.0, 7.0, 9.0, 36.0, 43.0, 45.0, 55.0, 71.0, 80.0, 86.0]
batch index 9, 0/1/2: 0/4/6
[22.0, 24.0, 44.0, 50.0, 69.0, 91.0, 105.0, 114.0, 117.0, 118.0]
batch index 10, 0/1/2: 0/5/5
[23.0, 31.0, 32.0, 42.0, 47.0, 60.0, 68.0, 74.0, 79.0, 107.0]
batch index 11, 0/1/2: 0/5/5
[20.0, 33.0, 49.0, 54.0, 58.0, 72.0, 84.0, 90.0, 94.0, 95.0]
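The samples now come out in a uniformly random order, and the batches are less balanced than in the replacement=False run. As before, every sample appears exactly once per epoch. However, the class-0 samples are spread over the whole epoch instead of being drawn preferentially in the first batches, so the early batches no longer over-represent the small class (batches 0 and 1 here contain no class-0 sample at all).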
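A single run is weak evidence, so the sketch below averages the number of class-0 samples in the first batch over many epochs. It compares `WeightedRandomSampler(replacement=False)` with `RandomSampler`, which is what `DataLoader(shuffle=True)` uses when no sampler is given. The trial count and seed are arbitrary choices. With these class sizes, a shuffled first batch of 10 contains 10 × 10/120 ≈ 0.83 class-0 samples on average, and the weighted sampler should give clearly more.

```python
import torch
from torch.utils.data import RandomSampler, WeightedRandomSampler

# Same 10 / 50 / 60 labels and inverse-class-size weights as above
target = torch.cat([torch.full((n,), c, dtype=torch.long)
                    for c, n in enumerate([10, 50, 60])])
samples_weight = (1.0 / torch.bincount(target).float())[target]
bs, trials = 10, 2000
g = torch.Generator().manual_seed(0)


def mean_class0_in_first_batch(make_sampler):
    """Average number of class-0 samples among the first `bs` indices of an epoch."""
    total = 0
    for _ in range(trials):
        first = list(make_sampler())[:bs]
        total += int((target[torch.tensor(first)] == 0).sum())
    return total / trials


weighted = mean_class0_in_first_batch(
    lambda: WeightedRandomSampler(samples_weight, len(target), replacement=False, generator=g))
shuffled = mean_class0_in_first_batch(
    lambda: RandomSampler(range(len(target)), generator=g))
print(f"weighted (replacement=False): {weighted:.2f}, shuffle: {shuffled:.2f}")
```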
In real training, adjust these parameters to suit your own needs (the usual "alchemy" of model tuning).
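Note: this approach is mostly used for classification, where each training sample has a single label. In segmentation, one sample contains many labels (one per pixel), which makes this method awkward to apply. For segmentation, adding class weights to the loss function is recommended instead, e.g. nn.CrossEntropyLoss(weight=weight).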
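For segmentation, here is a minimal sketch of the weighted loss. It assumes hypothetical per-class pixel counts and uses one common inverse-frequency weighting, n_total / (C · n_c); this formula is an illustrative choice, not a prescribed one. `nn.CrossEntropyLoss` accepts `(N, C, H, W)` logits with `(N, H, W)` integer targets, and `weight` must be a 1-D tensor with one entry per class.

```python
import torch
import torch.nn as nn

num_classes = 3
# Hypothetical pixel counts per class, summed over all training masks
pixel_counts = torch.tensor([9000.0, 800.0, 200.0])
# Inverse-frequency class weights n_total / (C * n_c): rare classes get larger weights
class_weight = pixel_counts.sum() / (num_classes * pixel_counts)

# weight must be a 1-D tensor with one entry per class
criterion = nn.CrossEntropyLoss(weight=class_weight)

# Dummy segmentation batch: logits (N, C, H, W) and integer masks (N, H, W)
torch.manual_seed(0)
logits = torch.randn(2, num_classes, 4, 4)
masks = torch.randint(0, num_classes, (2, 4, 4))
loss = criterion(logits, masks)
print(loss.item())
```

With the default reduction="mean", each pixel's loss is multiplied by the weight of its target class, and the sum is divided by the total weight rather than by the pixel count.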