Image Classification Datasets

Semi-supervised data loading: set the label of every sample that should be treated as unlabeled to -1, so that the cross-entropy loss can be configured to ignore the -1 labels:

class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)
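To see why this works, here is a minimal self-contained sketch (the tensors are made up for illustration): with ignore_index=NO_LABEL, targets equal to -1 contribute nothing to the loss, so the summed loss over a mixed batch equals the summed loss over its labeled subset alone.

import torch
import torch.nn as nn

NO_LABEL = -1
class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)

logits = torch.randn(4, 10)                         # batch of 4, 10 classes
targets = torch.tensor([3, NO_LABEL, 7, NO_LABEL])  # two unlabeled samples

loss = class_criterion(logits, targets)
# The -1 entries are skipped, so this equals the loss over the labeled subset:
labeled = targets != NO_LABEL
assert torch.allclose(loss, class_criterion(logits[labeled], targets[labeled]))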
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from functools import reduce
from operator import __or__
from torch.utils.data.sampler import Sampler
import itertools
import numpy as np


def load_data(path, args, NO_LABEL=-1):
    # Per-channel normalization statistics for each dataset.
    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'mnist':
        mean = (0.5,)
        std = (0.5,)
    elif args.dataset == 'stl10':
        assert False, 'stl10 support is not implemented'
    elif args.dataset == 'imagenet':
        assert False, 'imagenet support is not implemented'
    else:
        assert False, 'Unknown dataset: {}'.format(args.dataset)

    # Data augmentation. For CIFAR the training transform is wrapped in
    # TransformTwice, so each image yields two independently augmented views.
    if args.dataset == 'svhn':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])
    elif args.dataset == 'mnist':
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])
    else:
        train_transform = TransformTwice(transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)]))
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = datasets.CIFAR10(path, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = datasets.CIFAR100(path, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(path, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = datasets.SVHN(path, split='train', transform=train_transform, download=True)
        test_data = datasets.SVHN(path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'mnist':
        train_data = datasets.MNIST(path, train=True, transform=train_transform, download=True)
        test_data = datasets.MNIST(path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = datasets.STL10(path, split='train', transform=train_transform, download=True)
        test_data = datasets.STL10(path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'imagenet support is not implemented'
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    # Split the training set into a small labeled part and a large unlabeled
    # part. classes must match the dataset (the original code left it at the
    # default of 10, which is wrong for CIFAR-100).
    labeled_idxs, unlabeled_idxs = spilt_l_u(args.dataset, train_data, args.num_labels,
                                             classes=num_classes)

    # if args.labeled_batch_size:
    #     batch_sampler = TwoStreamBatchSampler(
    #         unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    # else:
    #     assert False, "labeled batch size {}".format(args.labeled_batch_size)

    # Mark the unlabeled samples with NO_LABEL so the loss can ignore them.
    # SVHN stores its targets in .labels; the other datasets use .targets.
    if args.dataset == 'svhn':
        train_data.labels = np.array(train_data.labels)
        train_data.labels[unlabeled_idxs] = NO_LABEL
    else:
        train_data.targets = np.array(train_data.targets)
        train_data.targets[unlabeled_idxs] = NO_LABEL

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=True)
    eval_loader = DataLoader(test_data,
                             batch_size=args.eval_batch_size,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True,
                             drop_last=False)
    return train_loader, eval_loader
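A minimal usage sketch, assuming the helpers defined below (spilt_l_u, TransformTwice) are in scope. The SimpleNamespace fields simply mirror the args attributes that load_data reads; the values are illustrative, not from the original:

from types import SimpleNamespace

args = SimpleNamespace(dataset='cifar10', num_labels=4000,
                       batch_size=128, eval_batch_size=256, workers=4)
train_loader, eval_loader = load_data('./data', args)

# For CIFAR, TransformTwice makes each batch carry two augmented views:
(view1, view2), targets = next(iter(train_loader))
print(view1.shape, view2.shape)       # torch.Size([128, 3, 32, 32]) twice
print((targets == -1).sum().item())   # number of unlabeled samples in the batch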
The split helper and the batch-sampling utilities:

def spilt_l_u(dataset, train_data, num_labels, num_val=400, classes=10):
    # num_val is accepted for compatibility but unused here.
    # Pull the raw labels out of the dataset object; the attribute name and
    # type differ between torchvision datasets. CIFAR stores targets as a
    # plain list, so convert to a numpy array before comparing element-wise.
    if dataset == 'mnist':
        labels = train_data.targets.numpy()
    elif dataset == 'svhn':
        labels = np.array(train_data.labels)
    else:
        labels = np.array(train_data.targets)

    n = int(num_labels / classes)  # labeled samples per class
    (indices,) = np.where(reduce(__or__, [labels == i for i in np.arange(classes)]))
    np.random.shuffle(indices)
    # Take the first n shuffled indices of each class as the labeled set and
    # the rest as the unlabeled set, so labels stay uniformly distributed
    # over the classes.
    indices_train = np.hstack(
        [list(filter(lambda idx: labels[idx] == i, indices))[:n] for i in range(classes)])
    indices_unlabelled = np.hstack(
        [list(filter(lambda idx: labels[idx] == i, indices))[n:] for i in range(classes)])
    indices_train = torch.from_numpy(indices_train)
    indices_unlabelled = torch.from_numpy(indices_unlabelled)
    return indices_train, indices_unlabelled


class TransformTwice:
    """Apply the same transform twice, returning two augmented views of one input."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        out1 = self.transform(inp)
        out2 = self.transform(inp)
        return out1, out2


class TwoStreamBatchSampler(Sampler):
    """Labeled + unlabeled data in a batch.

    Iterates over two sets of indices. An 'epoch' is one pass through the
    primary indices; during the epoch, the secondary indices are cycled
    through as many times as needed.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size
        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (primary_batch + secondary_batch
                for (primary_batch, secondary_batch)
                in zip(grouper(primary_iter, self.primary_batch_size),
                       grouper(secondary_iter, self.secondary_batch_size)))

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size


def iterate_once(iterable):
    return np.random.permutation(iterable)


def iterate_eternally(indices):
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())


def grouper(iterable, n):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3) --> ABC DEF
    args = [iter(iterable)] * n
    return zip(*args)
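load_data above leaves the TwoStreamBatchSampler path commented out. Here is a sketch of how it could be enabled, assuming a CIFAR-10 training set and illustrative batch sizes (note that batch_sampler replaces batch_size/shuffle/drop_last on the DataLoader):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_data = datasets.CIFAR10('./data', train=True, download=True,
                              transform=transforms.ToTensor())
labeled_idxs, unlabeled_idxs = spilt_l_u('cifar10', train_data, num_labels=4000)

# 97 unlabeled (primary) + 31 labeled (secondary) samples per batch of 128.
batch_sampler = TwoStreamBatchSampler(unlabeled_idxs.tolist(),
                                      labeled_idxs.tolist(),
                                      batch_size=128,
                                      secondary_batch_size=31)
train_loader = DataLoader(train_data,
                          batch_sampler=batch_sampler,
                          num_workers=4,
                          pin_memory=True)

One epoch is then a single pass over the unlabeled indices, with the much smaller labeled set reshuffled and recycled as needed, so every batch mixes both streams.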

