采样策略可分为以下情况:

case0:Over sampling &Under sampling ,即对类别多的进行降采样,对类别少的进行重采样
case1: Over sampling 对类别少的进行重采样,采样后的每类样本数与最多的那一类一致
case2:Under sampling 对类别多的进行降采样,采样后的每类样本数与最少的那一类一值
# 计算权重概率代码
# lables是每张图像的类别list [0,0,1,1,0,1,1,0,0,1]classes, class_sample_count = np.unique(labels, return_counts=True)# classes = np.unique(labels)weights = np.zeros(len(labels))for c in classes:freq_c = np.sum(np.array(labels) == c)  # 该类的数量weights[np.array(labels) == c] = len(labels) / freq_c # 该类的概率# weights[np.array(labels) == c] = 1. / freq_c weights = list(weights)

使用代码

#  有点类似欠采样,接近1:1, 少的多采样,多的少采样
train_nums = len(train_dataset)  # 所有类别的数量# 过采样, 每类数量过采样到最大类别的数量上
class_sample_count = np.array(list(train_dataset.class_sample_count_dict.values()))  #
class_sample_count_max = class_sample_count.max()  # 类别中数量最多的数量
class_nums = len(list(train_dataset.class_sample_count_dict.keys()))  # 类别数
train_nums = int(class_sample_count_max * class_nums)  # 总的数量sampler_weights = train_dataset.weights
train_sampler = torch.utils.data.WeightedRandomSampler(weights=sampler_weights, num_samples=train_nums,replacement=True)# 使用train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),num_workers=args.workers, pin_memory=True, sampler=train_sampler)# 验证是否进行类别平衡label_nums = [0, 0]  # 二分类for batch_idx, (data, target) in enumerate(train_loader):for target_i in target:lable_id = target_i.item()label_nums[lable_id] += 1print("dddd", lable_id)
得到label_nums : [8137, 7966],近似为1:1 ,之前是[ 5865 10238],接近 1:1.7
https://www.cnblogs.com/huadongw/p/6159408.html

torch.utils.data.WeightedRandomSampler采样相关推荐

  1. torch.utils.data.WeightedRandomSampler样本不均衡情况下带权重随机采样

    关于WeightedRandomSampler的用法csdn上有一些很棒的博客.本文参考博客Pytorch样本比例不均衡时采用WeightedRandomSampler进行采样的代码对Weighted ...

  2. pytorch源码解析2——数据处理torch.utils.data

    迭代器 理解 Python 的迭代器是解读 PyTorch 中 torch.utils.data 模块的关键. 在 Dataset, Sampler 和 DataLoader 这三个类中都会用到 py ...

  3. PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

    目录 0 前言 1 Dataset 1.1 Map-style dataset 1.2 Iterable-style dataset 1.3 其他 dataset 2 Sampler 3 DataLo ...

  4. PyTorch—torch.utils.data.DataLoader 数据加载类

    文章目录 DataLoader(object)类: _DataLoaderIter(object)类 __next__函数 pin_memory_batch() _get_batch函数 _proce ...

  5. Pytorch 中的数据类型 torch.utils.data.DataLoader 参数详解

    DataLoader是PyTorch中的一种数据类型,它定义了如何读取数据方式.详情也可参考本博主的另一篇关于torch.utils.data.DataLoader(https://blog.csdn ...

  6. 2021.08.24学习内容torch.utils.data.DataLoader以及CUDA与GPU的关系

    pytorch数据加载: ①totchvision 的包,含有支持加载类似Imagenet,CIFAR10,MNIST 等公共数据集的数据加载模块 torchvision.datasets impor ...

  7. PyTorch 1.0 中文文档:torch.utils.data

    译者:BXuan694 class torch.utils.data.Dataset 表示数据集的抽象类. 所有用到的数据集都必须是其子类.这些子类都必须重写以下方法:__len__:定义了数据集的规 ...

  8. pytorch torch.utils.data.TensorDataset

    应用 import torch import torch.utils.data as Datax = torch.linspace(1, 10, 10) y = torch.linspace(10, ...

  9. pytorch torch.utils.data.Dataset

    应用 from torch.utils.data import DataLoader, Dataset import torchclass TensorDataset(Dataset):# Tenso ...

最新文章

  1. java中JUnit单元测试的使用方法
  2. 尝试修改LabelImg,将以对顶角画框改成以对角线相交点向四周画框
  3. php mdecrypt generic,mdecrypt_generic
  4. Shell 变量--shell教程
  5. 设备翻转时viewController调用的方法
  6. 微软2020开源回顾:止不住的挨骂,停不下的贡献
  7. python变量的创建过程(内存地址变化)
  8. linux python版本升级和系统更新_Linux 下升级python和安装pip
  9. c标签判断true false jsp_北京尚学堂卓越班252天[第042天]——Jsp
  10. SVN日常提交工作时需要注意的事项
  11. android点击按钮修改文本,如何在android中每3秒动态更改按钮文本?
  12. 数字频率系数测试软件,sia smaartlive7
  13. 数字电路技术基础-1-补码
  14. VMware下如何虚拟软盘启动
  15. 加速进化,浪潮存储正在梦想成真
  16. 洛谷P1007独木桥题解
  17. 一款英国折叠车如何在中国城市流行?
  18. GSM Channel Mode Modify和Channel Mode Modify Acknowledge信令
  19. PT_二维随机变量:正态分布的可加性/一维随机变量函数与正态分布
  20. 事件触发过程(事件流)

热门文章

  1. A Brief History of Humankind — 01 the cognitive revolution
  2. 判断用户flash是否安装了flash以及flash的版本
  3. 【布局优化】基于粒子群求解物流选址matlab源码
  4. TeamTalk部署详细过程(跳过各种坑)
  5. 大数据之Hadoop3.x 运行环境搭建(手把手搭建集群)
  6. ros python 控制手柄数据发布频率
  7. 如何快速判断一个文件是否为病毒
  8. 下载typora beta版本
  9. 制造业ERP怎么创新与转型(阿朱说)
  10. Clickhouse 踩坑之旅 ---- MergeTree不合并分区的问题