
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
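weights (sequence) – a sequence of weights, not necessarily summing to one. There is one weight per sample, and a common choice is the reciprocal of the size of that sample's class. For example, with 10 cat images and 20 dog images, each cat image gets weight 1/10 and each dog image 1/20. Both classes then carry the same total weight and are drawn equally often. Normalized over the two classes, 1/10 : 1/20 is 0.67 : 0.33.
num_samples (int) – number of samples to draw (30 in the example above).
replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row. (The examples below focus on this parameter: whether a sample may be drawn more than once, which is what lets minority classes be oversampled so the batches come out balanced.)
generator (Generator) – Generator used in sampling. (Can be ignored here.)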
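The following is a minimal pure-Python sketch of how per-sample inverse-frequency weights could be built for the cat/dog example. The labels list and the helper name inverse_frequency_weights are hypothetical. In PyTorch, the resulting list would be what gets passed as weights to WeightedRandomSampler.

```python
from collections import Counter

def inverse_frequency_weights(labels):
    """Return one weight per sample: 1 / (number of samples in that sample's class)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]

# Hypothetical labels: 10 cat images followed by 20 dog images
labels = ["cat"] * 10 + ["dog"] * 20
weights = inverse_frequency_weights(labels)

# One weight per sample, so len(weights) equals the dataset size (30)
print(len(weights))

# Each class's weights sum to 1, so each class is picked with probability 1/2 per draw
per_class_total = {
    c: sum(w for w, lab in zip(weights, labels) if lab == c) for c in ("cat", "dog")
}
print(per_class_total)
```

Because every class sums to the same total weight, class size no longer affects how often a class is drawn.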


import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
# Create dummy data with a class imbalance of 10 : 50 : 60
class_counts = torch.tensor([10, 50, 60])
numDataPoints = int(class_counts.sum())
data_dim = 5
bs = 10
data = torch.randn(numDataPoints, data_dim)
# Set each sample's first feature to its row index, so the printed value shows its class:
# rows 0-9 are class 0, rows 10-59 class 1, rows 60-119 class 2
for i in range(data.shape[0]):
    data[i, 0] = i
target = torch.cat((torch.zeros(int(class_counts[0]), dtype=torch.long),
                    torch.ones(int(class_counts[1]), dtype=torch.long),
                    torch.ones(int(class_counts[2]), dtype=torch.long) * 2))
print('target train 0/1/2: {}/{}/{}'.format((target == 0).sum(), (target == 1).sum(), (target == 2).sum()))
# Compute per-sample weights (each sample gets the reciprocal of its class count)
class_sample_count = torch.tensor([(target == t).sum() for t in torch.unique(target, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = weight[target]
# Create sampler, dataset, loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
train_dataset = TensorDataset(data, target)
train_loader = DataLoader(train_dataset, batch_size=bs, num_workers=0, sampler=sampler)
# train_loader = DataLoader(
#     train_dataset, batch_size=bs, num_workers=0, shuffle=True)
# Iterate over the DataLoader and check the class balance of each batch
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1/2: {}/{}/{}".format(i, (y == 0).sum(), (y == 1).sum(), (y == 2).sum()))
    x_n = [el[0].tolist() for el in x]
    print(sorted(x_n))
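A little arithmetic shows why these weights should give balanced batches. The pure-Python sketch below uses a hypothetical helper, class_draw_probabilities. It computes the chance that one weighted draw lands in each class, given the 10 / 50 / 60 class sizes and weights of 1 / class size. It then scales those chances to the expected count per batch of 10.

```python
def class_draw_probabilities(class_counts):
    """Probability that one weighted draw picks each class, when every
    sample's weight is 1 / (size of its class)."""
    # Total weight of class c = class_counts[c] * (1 / class_counts[c]) = 1
    class_totals = [n * (1.0 / n) for n in class_counts]
    total = sum(class_totals)
    return [t / total for t in class_totals]

class_counts = [10, 50, 60]
batch_size = 10
probs = class_draw_probabilities(class_counts)
expected_per_batch = [batch_size * p for p in probs]
print(probs)               # each class: about 1/3
print(expected_per_batch)  # each class: about 3.33 samples per batch of 10
```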
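With replacement=True, as in the script, the output is shown below. Each class now makes up about a third of the 120 draws (37/36/47 in this run) instead of following the 10 : 50 : 60 ratio of the data. This happens because the class-0 rows are drawn over and over. Row 4.0 appears in 7 of the 12 batches, and a row can even repeat within a single batch (45.0 twice in batch 0).

target train 0/1/2: 10/50/60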
batch index 0, 0/1/2: 5/2/3
[1.0, 2.0, 4.0, 6.0, 8.0, 45.0, 45.0, 95.0, 101.0, 108.0]
batch index 1, 0/1/2: 1/5/4
[4.0, 30.0, 30.0, 39.0, 39.0, 54.0, 65.0, 67.0, 96.0, 97.0]
batch index 2, 0/1/2: 3/3/4
[4.0, 7.0, 9.0, 22.0, 43.0, 59.0, 90.0, 104.0, 111.0, 119.0]
batch index 3, 0/1/2: 4/3/3
[1.0, 3.0, 4.0, 8.0, 22.0, 49.0, 52.0, 64.0, 81.0, 89.0]
batch index 4, 0/1/2: 2/5/3
[3.0, 4.0, 10.0, 25.0, 42.0, 44.0, 49.0, 74.0, 100.0, 104.0]
batch index 5, 0/1/2: 1/5/4
[7.0, 22.0, 31.0, 32.0, 34.0, 37.0, 83.0, 113.0, 113.0, 115.0]
batch index 6, 0/1/2: 5/2/3
[2.0, 2.0, 3.0, 8.0, 9.0, 13.0, 20.0, 61.0, 75.0, 97.0]
batch index 7, 0/1/2: 3/3/4
[3.0, 7.0, 9.0, 31.0, 31.0, 38.0, 70.0, 71.0, 75.0, 91.0]
batch index 8, 0/1/2: 4/1/5
[2.0, 3.0, 4.0, 7.0, 23.0, 67.0, 71.0, 74.0, 95.0, 117.0]
batch index 9, 0/1/2: 3/3/4
[2.0, 3.0, 4.0, 17.0, 36.0, 56.0, 103.0, 104.0, 110.0, 115.0]
batch index 10, 0/1/2: 4/3/3
[3.0, 6.0, 7.0, 8.0, 18.0, 39.0, 43.0, 66.0, 95.0, 105.0]
batch index 11, 0/1/2: 2/1/7
[0.0, 2.0, 22.0, 77.0, 81.0, 100.0, 103.0, 103.0, 106.0, 111.0]
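The repetition in this output follows from the per-sample draw probabilities. The pure-Python sketch below (the helper per_sample_stats is hypothetical) assumes inverse-frequency weights, 120 draws with replacement, and class sizes of 10 / 50 / 60. For one sample of each class, it computes three quantities:

- the probability that a single draw picks that sample;
- the expected number of times the sample is drawn per epoch;
- the probability that the sample is never drawn during the epoch.

```python
def per_sample_stats(class_counts, num_samples):
    """For inverse-frequency weights drawn with replacement, return per class:
    (probability one draw picks a given sample of that class,
     expected times that sample is drawn in num_samples draws,
     probability that sample is never drawn)."""
    total_weight = len(class_counts)  # each class's weights sum to 1
    stats = []
    for n in class_counts:
        p = (1.0 / n) / total_weight
        stats.append((p, num_samples * p, (1.0 - p) ** num_samples))
    return stats

stats = per_sample_stats([10, 50, 60], num_samples=120)
for cls, (p, expected, never) in enumerate(stats):
    print(f"class {cls}: p={p:.4f}, expected draws={expected:.2f}, P(never drawn)={never:.2f}")
```

Each class-0 row is expected to be drawn 4 times per epoch. A class-1 or class-2 row, by contrast, has roughly a 45% to 51% chance of not being drawn at all in a given epoch. That is the price of oversampling with replacement.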
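Setting replacement=False instead:

sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=False)

changes the picture. Every row is now drawn exactly once per epoch (12 batches × 10 = 120 rows), so nothing is oversampled. The heavily weighted class-0 rows are merely drawn early: all ten are used up by batch 4, and batches 5 to 11 contain no class 0 at all.

target train 0/1/2: 10/50/60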
batch index 0, 0/1/2: 4/5/1
[5.0, 6.0, 7.0, 9.0, 16.0, 22.0, 28.0, 30.0, 35.0, 69.0]
batch index 1, 0/1/2: 2/4/4
[0.0, 1.0, 10.0, 20.0, 26.0, 38.0, 82.0, 90.0, 111.0, 119.0]
batch index 2, 0/1/2: 1/3/6
[4.0, 17.0, 19.0, 23.0, 60.0, 63.0, 64.0, 68.0, 79.0, 92.0]
batch index 3, 0/1/2: 2/2/6
[3.0, 8.0, 40.0, 48.0, 89.0, 96.0, 100.0, 102.0, 104.0, 115.0]
batch index 4, 0/1/2: 1/3/6
[2.0, 32.0, 37.0, 55.0, 84.0, 88.0, 94.0, 106.0, 109.0, 110.0]
batch index 5, 0/1/2: 0/5/5
[12.0, 25.0, 29.0, 33.0, 34.0, 72.0, 78.0, 98.0, 107.0, 113.0]
batch index 6, 0/1/2: 0/3/7
[42.0, 43.0, 49.0, 70.0, 83.0, 85.0, 93.0, 95.0, 97.0, 117.0]
batch index 7, 0/1/2: 0/5/5
[21.0, 27.0, 44.0, 57.0, 59.0, 67.0, 74.0, 80.0, 86.0, 114.0]
batch index 8, 0/1/2: 0/5/5
[15.0, 31.0, 47.0, 50.0, 53.0, 75.0, 77.0, 103.0, 105.0, 112.0]
batch index 9, 0/1/2: 0/4/6
[18.0, 45.0, 46.0, 54.0, 65.0, 71.0, 73.0, 76.0, 99.0, 108.0]
batch index 10, 0/1/2: 0/6/4
[11.0, 13.0, 14.0, 24.0, 51.0, 56.0, 61.0, 66.0, 87.0, 101.0]
batch index 11, 0/1/2: 0/5/5
[36.0, 39.0, 41.0, 52.0, 58.0, 62.0, 81.0, 91.0, 116.0, 118.0]
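In PyTorch, WeightedRandomSampler draws its indices with torch.multinomial. Without replacement, this conceptually behaves like successive weighted draws from the indices not yet taken. The sketch below is a pure-Python model of that behaviour, built on random.choices. It is not PyTorch's implementation, and the helper name is hypothetical. It shows why every index appears exactly once while the high-weight indices tend to come first. It also follows that in this mode num_samples cannot exceed the number of samples.

```python
import random

def weighted_order_without_replacement(weights, rng):
    """Draw every index exactly once; at each step an index is chosen with
    probability proportional to its weight among the indices not yet drawn."""
    remaining = list(range(len(weights)))
    order = []
    while remaining:
        pick = rng.choices(remaining, weights=[weights[i] for i in remaining], k=1)[0]
        remaining.remove(pick)
        order.append(pick)
    return order

# Same setup as above: rows 0-9 class 0, 10-59 class 1, 60-119 class 2
class_counts = [10, 50, 60]
labels = [c for c, n in enumerate(class_counts) for _ in range(n)]
weights = [1.0 / class_counts[c] for c in labels]

order = weighted_order_without_replacement(weights, random.Random(0))
batches = [order[i:i + 10] for i in range(0, len(order), 10)]
for b, batch in enumerate(batches):
    # Per-batch class counts, in the same 0/1/2 layout as the output above
    print(b, [sum(labels[i] == c for i in batch) for c in range(3)])
```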



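For comparison, here is the result of dropping the sampler and using plain shuffling. DataLoader does not accept a sampler together with shuffle=True.

train_loader = DataLoader(train_dataset, batch_size=bs, num_workers=0, shuffle=True)

This gives batches whose class mix simply follows the 10 : 50 : 60 ratio of the data, and class 0 is missing from 7 of the 12 batches: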


target train 0/1/2: 10/50/60
batch index 0, 0/1/2: 0/5/5
[11.0, 17.0, 34.0, 46.0, 51.0, 67.0, 70.0, 78.0, 109.0, 111.0]
batch index 1, 0/1/2: 0/6/4
[10.0, 13.0, 19.0, 29.0, 38.0, 53.0, 66.0, 85.0, 88.0, 97.0]
batch index 2, 0/1/2: 1/4/5
[2.0, 21.0, 30.0, 41.0, 48.0, 65.0, 73.0, 81.0, 102.0, 116.0]
batch index 3, 0/1/2: 2/3/5
[4.0, 5.0, 14.0, 26.0, 35.0, 64.0, 89.0, 96.0, 98.0, 115.0]
batch index 4, 0/1/2: 2/2/6
[0.0, 3.0, 16.0, 52.0, 82.0, 83.0, 92.0, 106.0, 110.0, 113.0]
batch index 5, 0/1/2: 2/4/4
[6.0, 8.0, 18.0, 40.0, 57.0, 59.0, 63.0, 101.0, 108.0, 112.0]
batch index 6, 0/1/2: 0/6/4
[12.0, 27.0, 28.0, 37.0, 39.0, 56.0, 75.0, 77.0, 87.0, 103.0]
batch index 7, 0/1/2: 0/2/8
[15.0, 25.0, 61.0, 62.0, 76.0, 93.0, 99.0, 100.0, 104.0, 119.0]
batch index 8, 0/1/2: 3/4/3
[1.0, 7.0, 9.0, 36.0, 43.0, 45.0, 55.0, 71.0, 80.0, 86.0]
batch index 9, 0/1/2: 0/4/6
[22.0, 24.0, 44.0, 50.0, 69.0, 91.0, 105.0, 114.0, 117.0, 118.0]
batch index 10, 0/1/2: 0/5/5
[23.0, 31.0, 32.0, 42.0, 47.0, 60.0, 68.0, 74.0, 79.0, 107.0]
batch index 11, 0/1/2: 0/5/5
[20.0, 33.0, 49.0, 54.0, 58.0, 72.0, 84.0, 90.0, 94.0, 95.0]
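The shuffled baseline can be checked the same way. With plain shuffling, each batch of 10 is a uniformly random subset of the 120 rows. The expected count of class c per batch is therefore 10 · n_c / 120. The chance that a batch contains no class-0 row at all is hypergeometric, C(110, 10) / C(120, 10). Below is a pure-Python sketch; the helper names are hypothetical.

```python
from math import comb

def expected_counts_shuffled(class_counts, batch_size):
    """Expected number of samples of each class in one shuffled batch."""
    total = sum(class_counts)
    return [batch_size * n / total for n in class_counts]

def prob_batch_missing_class(class_counts, cls, batch_size):
    """Probability that a uniformly random batch contains no sample of class cls."""
    total = sum(class_counts)
    return comb(total - class_counts[cls], batch_size) / comb(total, batch_size)

class_counts = [10, 50, 60]
expected = expected_counts_shuffled(class_counts, batch_size=10)
p_no_class0 = prob_batch_missing_class(class_counts, cls=0, batch_size=10)
print(expected)      # about 0.83 / 4.17 / 5.0 per batch
print(p_no_class0)
```

Under this model, roughly 4 in 10 shuffled batches are expected to have no class-0 sample. That is why a sampler is needed when each batch should see the minority class.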



