MNIST Example

Definition

class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = [
        'http://yann.lecun.com/exdb/mnist/',
        'https://ossci-datasets.s3.amazonaws.com/mnist/',
    ]

    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
    ]

    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train  # training set or test set

        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        self.data, self.targets = self._load_data()

    def _check_legacy_exist(self):
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file))
            for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        # This is for BC only. We no longer cache the data in a custom binary,
        # but simply read from the raw data directly.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = "{}{}".format(mirror, filename)
                try:
                    print("Downloading {}".format(url))
                    download_and_extract_archive(
                        url, download_root=self.raw_folder,
                        filename=filename,
                        md5=md5
                    )
                except URLError as error:
                    print("Failed to download (trying next):\n{}".format(error))
                    continue
                finally:
                    print()
                break
            else:
                raise RuntimeError("Error downloading {}".format(filename))

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")

FashionMNIST (FMNIST), KMNIST, and QMNIST can be read directly in the same way; they are all available in torchvision.datasets.

These datasets can be loaded as follows (shown here for MNIST):

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size_test, shuffle=True)
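FashionMNIST, KMNIST, and QMNIST share the same constructor signature, so swapping the dataset class is enough. A minimal sketch (the batch size and the bare ToTensor transform are illustrative choices, not official settings):

import torch
import torchvision

fmnist_train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.FashionMNIST('./data/', train=True, download=True,
                                      transform=torchvision.transforms.ToTensor()),
    batch_size=64, shuffle=True)

kmnist_train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.KMNIST('./data/', train=True, download=True,
                                transform=torchvision.transforms.ToTensor()),
    batch_size=64, shuffle=True)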

CIFAR

class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data: Any = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self) -> None:
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    def _check_integrity(self) -> bool:
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self) -> None:
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")

CIFAR100 works in exactly the same way.
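As a quick sketch, CIFAR100 can be wrapped in DataLoaders just like MNIST above (the batch size and bare ToTensor transform are illustrative; add Normalize with your preferred statistics):

import torch
import torchvision

cifar100_train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('./data/', train=True, download=True,
                                  transform=torchvision.transforms.ToTensor()),
    batch_size=128, shuffle=True)

cifar100_test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('./data/', train=False, download=True,
                                  transform=torchvision.transforms.ToTensor()),
    batch_size=128, shuffle=False)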

Predefined datasets

Starting from the torchvision releases that ship with PyTorch 1.8, the other predefined datasets include the following (a loading sketch follows the list):

  • Caltech101
  • Caltech256 (caltech.py)
  • STL10 (stl10.py)
  • SVHN (svhn.py)
  • CelebA (celeba.py)
  • INaturalist (inaturalist.py)
  • Omniglot (omniglot.py)
  • Places365 (places365.py)

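These all follow the familiar torchvision pattern; note that several of them (STL10, SVHN, CelebA, Places365) take a split argument instead of a train flag. A minimal sketch with illustrative root paths:

import torchvision

svhn_train = torchvision.datasets.SVHN('./data/svhn', split='train', download=True,
                                       transform=torchvision.transforms.ToTensor())
stl10_train = torchvision.datasets.STL10('./data/stl10', split='train', download=True,
                                         transform=torchvision.transforms.ToTensor())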
The following require you to download the full dataset yourself (see the sketch after the list):

  • LSUNClass
  • ImageNet
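For these, torchvision only parses archives you have already obtained. A sketch assuming the ILSVRC2012 archives have been placed under the root directory by hand (paths are placeholders):

import torchvision

# torchvision.datasets.ImageNet expects ILSVRC2012_img_train.tar / ILSVRC2012_img_val.tar
# and the devkit archive to already sit in the root directory.
imagenet_train = torchvision.datasets.ImageNet('./data/imagenet', split='train',
                                               transform=torchvision.transforms.ToTensor())

# Any dataset organised as root/class_name/xxx.jpg can also be read with ImageFolder.
custom_set = torchvision.datasets.ImageFolder('./data/imagenet/train',
                                              transform=torchvision.transforms.ToTensor())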

Additional datasets

Food101

from pathlib import Path
import json
from typing import Any, Tuple, Callable, Optional
import torch
import PIL.Image

from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
from torchvision.datasets.vision import VisionDatasetclass Food101(VisionDataset):"""`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.The Food-101 is a challenging data set of 101 food categories, with 101'000 images.For each class, 250 manually reviewed test images are provided as well as 750 training images.On purpose, the training images were not cleaned, and thus still contain some amount of noise.This comes mostly in the form of intense colors and sometimes wrong labels. All images wererescaled to have a maximum side length of 512 pixels.Args:root (string): Root directory of the dataset.split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.transform (callable, optional): A function/transform that  takes in an PIL image and returns a transformedversion. E.g, ``transforms.RandomCrop``.target_transform (callable, optional): A function/transform that takes in the target and transforms it.download (bool, optional): If True, downloads the dataset from the internet andputs it in root directory. If dataset is already downloaded, it is notdownloaded again. Default is False."""_URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"_MD5 = "85eeb15f3717b99a5da872d97d918f87"def __init__(self,root: str,split: str = "train",transform: Optional[Callable] = None,target_transform: Optional[Callable] = None,download: bool = False,) -> None:super().__init__(root, transform=transform, target_transform=target_transform)self._split = verify_str_arg(split, "split", ("train", "test"))self._base_folder = Path(self.root) / "food-101"self._meta_folder = self._base_folder / "meta"self._images_folder = self._base_folder / "images"self.class_names_str = ['Apple pie', 'Baby back ribs', 'Baklava', 'Beef carpaccio', 'Beef tartare', 'Beet salad', 'Beignets', 'Bibimbap', 'Bread pudding', 'Breakfast burrito', 'Bruschetta', 'Caesar salad', 'Cannoli', 'Caprese salad', 'Carrot cake', 'Ceviche', 'Cheesecake', 'Cheese plate', 'Chicken curry', 'Chicken quesadilla', 'Chicken wings', 'Chocolate cake', 'Chocolate mousse', 'Churros', 'Clam chowder', 'Club sandwich', 'Crab cakes', 'Creme brulee', 'Croque madame', 'Cup cakes', 'Deviled eggs', 'Donuts', 'Dumplings', 'Edamame', 'Eggs benedict', 'Escargots', 'Falafel', 'Filet mignon', 'Fish and chips', 'Foie gras', 'French fries', 'French onion soup', 'French toast', 'Fried calamari', 'Fried rice', 'Frozen yogurt', 'Garlic bread', 'Gnocchi', 'Greek salad', 'Grilled cheese sandwich', 'Grilled salmon', 'Guacamole', 'Gyoza', 'Hamburger', 'Hot and sour soup', 'Hot dog', 'Huevos rancheros', 'Hummus', 'Ice cream', 'Lasagna', 'Lobster bisque', 'Lobster roll sandwich', 'Macaroni and cheese', 'Macarons', 'Miso soup', 'Mussels', 'Nachos', 'Omelette', 'Onion rings', 'Oysters', 'Pad thai', 'Paella', 'Pancakes', 'Panna cotta', 'Peking duck', 'Pho', 'Pizza', 'Pork chop', 'Poutine', 'Prime rib', 'Pulled pork sandwich', 'Ramen', 'Ravioli', 'Red velvet cake', 'Risotto', 'Samosa', 'Sashimi', 'Scallops', 'Seaweed salad', 'Shrimp and grits', 'Spaghetti bolognese', 'Spaghetti carbonara', 'Spring rolls', 'Steak', 'Strawberry shortcake', 'Sushi', 'Tacos', 'Takoyaki', 'Tiramisu', 'Tuna tartare', 'Waffles']if download:self._download()if not self._check_exists():raise RuntimeError("Dataset not found. 
You can use download=True to download it")self._labels = []self._image_files = []with open(self._meta_folder / f"{split}.json") as f:metadata = json.loads(f.read())self.classes = sorted(metadata.keys())self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))for class_label, im_rel_paths in metadata.items():self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)self._image_files += [self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths]def __len__(self) -> int:return len(self._image_files)def __getitem__(self, idx) -> Tuple[Any, Any]:image_file, label = self._image_files[idx], self._labels[idx]image = PIL.Image.open(image_file).convert("RGB")if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, labeldef extra_repr(self) -> str:return f"split={self._split}"def _check_exists(self) -> bool:return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))def _download(self) -> None:if self._check_exists():returndownload_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)def examine_count(counter, name = "train"):print(f"in the {name} set")for label in counter:print(label, counter[label])if __name__ == "__main__":label_names = []with open('debug/food101_labels.txt') as f:for name in f:label_names.append(name.strip())print(label_names)train_set = Food101(root = "/nobackup/dataset_myf", split = "train", download = True)test_set = Food101(root = "/nobackup/dataset_myf", split = "test")print(f"train set len {len(train_set)}")print(f"test set len {len(test_set)}")from collections import Countertrain_label_count = Counter(train_set._labels)test_label_count = Counter(test_set._labels)# examine_count(train_label_count, name = "train")# examine_count(test_label_count, name = "test")kwargs = {'num_workers': 4, 'pin_memory': True}train_loader = torch.utils.data.DataLoader(train_set ,batch_size=16, shuffle=True, **kwargs)val_loader = torch.utils.data.DataLoader(test_set,batch_size=16, shuffle=False, **kwargs)

Flower102

from pathlib import Path
from typing import Any, Tuple, Callable, Optional
import torch
import PIL.Image

from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
from torchvision.datasets.vision import VisionDatasetclass Flowers102(VisionDataset):"""`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset... warning::This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. Theflowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists ofbetween 40 and 258 images.The images have large scale, pose and light variations. In addition, there are categories thathave large variations within the category, and several very similar categories.Args:root (string): Root directory of the dataset.split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.transform (callable, optional): A function/transform that takes in an PIL image and returns atransformed version. E.g, ``transforms.RandomCrop``.target_transform (callable, optional): A function/transform that takes in the target and transforms it.download (bool, optional): If true, downloads the dataset from the internet andputs it in root directory. If dataset is already downloaded, it is notdownloaded again."""_download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"_file_dict = {  # filename, md5"image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),"label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),"setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),}_splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}# https://gist.github.com/JosephKJ/94c7728ed1a8e0cd87fe6a029769cde1label_names = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion', 'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus', 'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple', 'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen ', 'watercress', 'canna lily', 'hippeastrum ', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily']def __init__(self,root: str,split: str = "train",transform: Optional[Callable] = None,target_transform: Optional[Callable] = None,download: bool = False,) -> None:super().__init__(root, transform=transform, target_transform=target_transform)self._split = 
verify_str_arg(split, "split", ("train", "val", "test"))self._base_folder = Path(self.root) / "flowers-102"self._images_folder = self._base_folder / "jpg"if download:self.download()if not self._check_integrity():raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")from scipy.io import loadmatset_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)image_ids = set_ids[self._splits_map[self._split]].tolist()labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)image_id_to_label = dict(enumerate(labels["labels"].tolist(), 1))self._labels = []self._image_files = []for image_id in image_ids:self._labels.append(image_id_to_label[image_id])self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")self.class_names_str = self.label_namesdef __len__(self) -> int:return len(self._image_files)def __getitem__(self, idx) -> Tuple[Any, Any]:image_file, label = self._image_files[idx], self._labels[idx]image = PIL.Image.open(image_file).convert("RGB")if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, labeldef extra_repr(self) -> str:return f"split={self._split}"def _check_integrity(self):if not (self._images_folder.exists() and self._images_folder.is_dir()):return Falsefor id in ["label", "setid"]:filename, md5 = self._file_dict[id]if not check_integrity(str(self._base_folder / filename), md5):return Falsereturn Truedef download(self):if self._check_integrity():returndownload_and_extract_archive(f"{self._download_url_prefix}{self._file_dict['image'][0]}",str(self._base_folder),md5=self._file_dict["image"][1],)for id in ["label", "setid"]:filename, md5 = self._file_dict[id]download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)def examine_count(counter, name = "train"):print(f"in the {name} set")for label in counter:print(label, counter[label])if __name__ == "__main__":# label_names = []# with open('debug/flowers102_labels.txt') as f:#     for name in f:#         label_names.append(name.strip()[1:-1])# print(label_names)train_set = Flowers102(root = "/nobackup/dataset_myf", split = "train", download = True)val_set = Flowers102(root = "/nobackup/dataset_myf", split = "val")test_set = Flowers102(root = "/nobackup/dataset_myf", split = "test")from collections import Countertrain_label_count = Counter(train_set._labels)val_label_count = Counter(val_set._labels)test_label_count = Counter(test_set._labels)examine_count(train_label_count, name = "train")examine_count(val_label_count, name = "val")examine_count(test_label_count, name = "test")kwargs = {'num_workers': 4, 'pin_memory': True}train_loader = torch.utils.data.DataLoader(train_set ,batch_size=16, shuffle=True, **kwargs)val_loader = torch.utils.data.DataLoader(val_set,batch_size=16, shuffle=False, **kwargs)

Car196

import pathlib
from typing import Callable, Optional, Any, Tuple

from PIL import Image
import torch

from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
from torchvision.datasets.vision import VisionDatasetclass StanfordCars(VisionDataset):"""`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ DatasetThe Cars dataset contains 16,185 images of 196 classes of cars. The data issplit into 8,144 training images and 8,041 testing images, where each classhas been split roughly in a 50-50 split.. note::This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.Args:root (string): Root directory of datasetsplit (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.transform (callable, optional): A function/transform that  takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``target_transform (callable, optional): A function/transform that takes in thetarget and transforms it.download (bool, optional): If True, downloads the dataset from the internet andputs it in root directory. If dataset is already downloaded, it is notdownloaded again."""def __init__(self,root: str,split: str = "train",transform: Optional[Callable] = None,target_transform: Optional[Callable] = None,download: bool = False,) -> None:try:import scipy.io as sioexcept ImportError:raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")super().__init__(root, transform=transform, target_transform=target_transform)self._split = verify_str_arg(split, "split", ("train", "test"))self._base_folder = pathlib.Path(root) / "stanford_cars"devkit = self._base_folder / "devkit"if self._split == "train":self._annotations_mat_path = devkit / "cars_train_annos.mat"self._images_base_path = self._base_folder / "cars_train"else:self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"self._images_base_path = self._base_folder / "cars_test"if download:self.download()if not self._check_exists():raise RuntimeError("Dataset not found. 
You can use download=True to download it")self._samples = [(str(self._images_base_path / annotation["fname"]),annotation["class"] - 1,  # Original target mapping  starts from 1, hence -1)for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]]self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}self.class_names_str = self.classesdef __len__(self) -> int:return len(self._samples)def __getitem__(self, idx: int) -> Tuple[Any, Any]:"""Returns pil_image and class_id for given index"""image_path, target = self._samples[idx]pil_image = Image.open(image_path).convert("RGB")if self.transform is not None:pil_image = self.transform(pil_image)if self.target_transform is not None:target = self.target_transform(target)return pil_image, targetdef download(self) -> None:if self._check_exists():returndownload_and_extract_archive(url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",download_root=str(self._base_folder),md5="c3b158d763b6e2245038c8ad08e45376",)if self._split == "train":download_and_extract_archive(url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",download_root=str(self._base_folder),md5="065e5b463ae28d29e77c1b4b166cfe61",)else:download_and_extract_archive(url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",download_root=str(self._base_folder),md5="4ce7ebf6a94d07f1952d94dd34c4d501",)download_url(url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",root=str(self._base_folder),md5="b0a2b23655a3edd16d84508592a98d10",)def _check_exists(self) -> bool:if not (self._base_folder / "devkit").is_dir():return Falsereturn self._annotations_mat_path.exists() and self._images_base_path.is_dir()def examine_count(counter, name = "train"):print(f"in the {name} set")for label in counter:print(label, counter[label])if __name__ == "__main__":train_set = StanfordCars(root = "/nobackup/dataset_myf", split = "train", download = True)test_set = StanfordCars(root = "/nobackup/dataset_myf", split = "test", download = True)print(f"train set len {len(train_set)}")print(f"test set len {len(test_set)}")from collections import Countertrain_label_count = Counter([label for img, label in train_set._samples])test_label_count = Counter([label for img, label in test_set._samples])examine_count(train_label_count, name = "train")examine_count(test_label_count, name = "test")kwargs = {'num_workers': 4, 'pin_memory': True}train_loader = torch.utils.data.DataLoader(train_set ,batch_size=16, shuffle=True, **kwargs)val_loader = torch.utils.data.DataLoader(test_set,batch_size=16, shuffle=False, **kwargs)

CUB200

import numpy as np
# read the data
import matplotlib.image
import os
from PIL import Image
from torchvision import transforms
import torch


class CUB():
    def __init__(self, root, is_train=True, data_len=None, transform=None, target_transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform
        self.target_transform = target_transform
        img_txt_file = open(os.path.join(self.root, 'images.txt'))
        label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
        train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
        # image file name index
        img_name_list = []
        for line in img_txt_file:
            # the last character of each line is a newline
            img_name_list.append(line[:-1].split(' ')[-1])
        # label index; subtract 1 so that labels start from 0
        label_list = []
        for line in label_txt_file:
            label_list.append(int(line[:-1].split(' ')[-1]) - 1)
        # train/test split flags: 1 = training set, 0 = test set
        train_test_list = []
        for line in train_val_file:
            train_test_list.append(int(line[:-1].split(' ')[-1]))
        # zip() pairs each file name / label with its split flag;
        # entries with flag 1 go to the training set, the rest to the test set
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        train_label_list = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
        test_label_list = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
        if self.is_train:
            # matplotlib.image.imread returns the image as a numpy array
            self.train_img = [matplotlib.image.imread(os.path.join(self.root, 'images', train_file))
                              for train_file in train_file_list[:data_len]]
            # training-set labels
            self.train_label = train_label_list
        if not self.is_train:
            self.test_img = [matplotlib.image.imread(os.path.join(self.root, 'images', test_file))
                             for test_file in test_file_list[:data_len]]
            self.test_label = test_label_list

    # data augmentation is applied via the transform passed in
    def __getitem__(self, index):
        # training set
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
        # test set
        else:
            img, target = self.test_img[index], self.test_label[index]
        if len(img.shape) == 2:
            # convert grayscale images to three channels
            img = np.stack([img] * 3, 2)
        # convert to a PIL RGB image
        img = Image.fromarray(img, mode='RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


if __name__ == '__main__':
    '''
    dataset = CUB(root='./CUB_200_2011')
    for data in dataset:
        print(data[0].size(), data[1])
    '''
    # read the dataset through PyTorch's DataLoader
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = CUB(root='../dataset/CUB_200_2011', is_train=True, transform=transform_train)
    print(len(dataset))
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0, drop_last=True)
    print(len(trainloader))

Aircraft

import numpy as np
import os
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torchvision.datasets.utils import extract_archiveclass Aircraft(VisionDataset):"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.Args:root (string): Root directory of the dataset.train (bool, optional): If True, creates dataset from training set, otherwisecreates from test set.class_type (string, optional): choose from ('variant', 'family', 'manufacturer').transform (callable, optional): A function/transform that  takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``target_transform (callable, optional): A function/transform that takes in thetarget and transforms it.download (bool, optional): If true, downloads the dataset from the internet andputs it in root directory. If dataset is already downloaded, it is notdownloaded again."""url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'class_types = ('variant', 'family', 'manufacturer')splits = ('train', 'val', 'trainval', 'test')img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')def __init__(self, root, train=True, class_type='variant', transform=None,target_transform=None, download=False):super(Aircraft, self).__init__(root, transform=transform, target_transform=target_transform)split = 'trainval' if train else 'test'if split not in self.splits:raise ValueError('Split "{}" not found. Valid splits are: {}'.format(split, ', '.join(self.splits),))if class_type not in self.class_types:raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(class_type, ', '.join(self.class_types),))self.class_type = class_typeself.split = splitself.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data','images_%s_%s.txt' % (self.class_type, self.split))if download:self.download()(image_ids, targets, classes, class_to_idx) = self.find_classes()samples = self.make_dataset(image_ids, targets)self.loader = default_loaderself.samples = samplesself.classes = classesself.class_to_idx = class_to_idxdef __getitem__(self, index):path, target = self.samples[index]sample = self.loader(path)if self.transform is not None:sample = self.transform(sample)if self.target_transform is not None:target = self.target_transform(target)return sample, targetdef __len__(self):return len(self.samples)def _check_exists(self):return os.path.exists(os.path.join(self.root, self.img_folder)) and \os.path.exists(self.classes_file)def download(self):if self._check_exists():return# prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gzprint('Downloading %s...' % self.url)tar_name = self.url.rpartition('/')[-1]download_url(self.url, root=self.root, filename=tar_name)tar_path = os.path.join(self.root, tar_name)print('Extracting %s...' 
% tar_path)extract_archive(tar_path)print('Done!')def find_classes(self):# read classes file, separating out image IDs and class namesimage_ids = []targets = []with open(self.classes_file, 'r') as f:for line in f:split_line = line.split(' ')image_ids.append(split_line[0])targets.append(' '.join(split_line[1:]))# index class namesclasses = np.unique(targets)class_to_idx = {classes[i]: i for i in range(len(classes))}targets = [class_to_idx[c] for c in targets]return image_ids, targets, classes, class_to_idxdef make_dataset(self, image_ids, targets):assert (len(image_ids) == len(targets))images = []for i in range(len(image_ids)):item = (os.path.join(self.root, self.img_folder,'%s.jpg' % image_ids[i]), targets[i])images.append(item)return imagesif __name__ == '__main__':train_dataset = Aircraft('./aircraft', train=True, download=False)test_dataset = Aircraft('./aircraft', train=False, download=False)

PermutedMNIST

import numpy as np
import torch
import torchvision


class PermutedMNISTDataLoader(torchvision.datasets.MNIST):
    def __init__(self, source='data/mnist_data', train=True, shuffle_seed=None):
        super(PermutedMNISTDataLoader, self).__init__(source, train, download=True)
        self.train = train
        self.num_data = 0
        if self.train:
            self.permuted_train_data = torch.stack(
                [img.type(dtype=torch.float32).view(-1)[shuffle_seed] / 255.0
                 for img in self.train_data])
            self.num_data = self.permuted_train_data.shape[0]
        else:
            self.permuted_test_data = torch.stack(
                [img.type(dtype=torch.float32).view(-1)[shuffle_seed] / 255.0
                 for img in self.test_data])
            self.num_data = self.permuted_test_data.shape[0]

    def __getitem__(self, index):
        if self.train:
            input, label = self.permuted_train_data[index], self.train_labels[index]
        else:
            input, label = self.permuted_test_data[index], self.test_labels[index]
        return input, label

    def getNumData(self):
        return self.num_data


batch_size = 64
learning_rate = 1e-3
num_task = 10
criterion = torch.nn.CrossEntropyLoss()
cuda_available = False
if torch.cuda.is_available():
    cuda_available = True


def permute_mnist():
    train_loader = {}
    test_loader = {}
    train_data_num = 0
    test_data_num = 0
    for i in range(num_task):
        shuffle_seed = np.arange(28 * 28)
        np.random.shuffle(shuffle_seed)

        train_PMNIST_DataLoader = PermutedMNISTDataLoader(train=True, shuffle_seed=shuffle_seed)
        test_PMNIST_DataLoader = PermutedMNISTDataLoader(train=False, shuffle_seed=shuffle_seed)

        train_data_num += train_PMNIST_DataLoader.getNumData()
        test_data_num += test_PMNIST_DataLoader.getNumData()

        train_loader[i] = torch.utils.data.DataLoader(train_PMNIST_DataLoader, batch_size=batch_size)
        test_loader[i] = torch.utils.data.DataLoader(test_PMNIST_DataLoader, batch_size=batch_size)

    return train_loader, test_loader, int(train_data_num / num_task), int(test_data_num / num_task)


train_loader, test_loader, train_data_num, test_data_num = permute_mnist()
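A minimal sketch of consuming the per-task loaders returned by permute_mnist() in a sequential-task training loop; the two-layer model and the Adam optimizer are placeholder assumptions, not part of the original code:

# Hypothetical continual-learning loop over the permuted-MNIST tasks.
model = torch.nn.Sequential(torch.nn.Linear(28 * 28, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10))
if cuda_available:
    model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for task_id in range(num_task):
    for inputs, labels in train_loader[task_id]:
        if cuda_available:
            inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()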

TinyImageNet

import os
import pandas as pd
import warnings
from torchvision.datasets import ImageFolder
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import extract_archive, check_integrity, download_url, verify_str_argclass TinyImageNet(VisionDataset):"""`tiny-imageNet <http://cs231n.stanford.edu/tiny-imagenet-200.zip>`_ Dataset.Args:root (string): Root directory of the dataset.split (string, optional): The dataset split, supports ``train``, or ``val``.transform (callable, optional): A function/transform that  takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``target_transform (callable, optional): A function/transform that takes in thetarget and transforms it.download (bool, optional): If true, downloads the dataset from the internet andputs it in root directory. If dataset is already downloaded, it is notdownloaded again."""base_folder = 'tiny-imagenet-200/'url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'filename = 'tiny-imagenet-200.zip'md5 = '90528d7ca1a48142e341f4ef8d21d0de'def __init__(self, root, split='train', transform=None, target_transform=None, download=False):super(TinyImageNet, self).__init__(root, transform=transform, target_transform=target_transform)self.dataset_path = os.path.join(root, self.base_folder)self.loader = default_loaderself.split = verify_str_arg(split, "split", ("train", "val",))if self._check_integrity():print('Files already downloaded and verified.')elif download:self._download()else:raise RuntimeError('Dataset not found. You can use download=True to download it.')if not os.path.isdir(self.dataset_path):print('Extracting...')extract_archive(os.path.join(root, self.filename))_, class_to_idx = find_classes(os.path.join(self.dataset_path, 'wnids.txt'))self.data = make_dataset(self.root, self.base_folder, self.split, class_to_idx)def _download(self):print('Downloading...')download_url(self.url, root=self.root, filename=self.filename)print('Extracting...')extract_archive(os.path.join(self.root, self.filename))def _check_integrity(self):return check_integrity(os.path.join(self.root, self.filename), self.md5)def __getitem__(self, index):img_path, target = self.data[index]image = self.loader(img_path)if self.transform is not None:image = self.transform(image)if self.target_transform is not None:target = self.target_transform(target)return image, targetdef __len__(self):return len(self.data)def find_classes(class_file):with open(class_file) as r:classes = list(map(lambda s: s.strip(), r.readlines()))classes.sort()class_to_idx = {classes[i]: i for i in range(len(classes))}return classes, class_to_idxdef make_dataset(root, base_folder, dirname, class_to_idx):images = []dir_path = os.path.join(root, base_folder, dirname)if dirname == 'train':for fname in sorted(os.listdir(dir_path)):cls_fpath = os.path.join(dir_path, fname)if os.path.isdir(cls_fpath):cls_imgs_path = os.path.join(cls_fpath, 'images')for imgname in sorted(os.listdir(cls_imgs_path)):path = os.path.join(cls_imgs_path, imgname)item = (path, class_to_idx[fname])images.append(item)else:imgs_path = os.path.join(dir_path, 'images')imgs_annotations = os.path.join(dir_path, 'val_annotations.txt')with open(imgs_annotations) as r:data_info = map(lambda s: s.split('\t'), r.readlines())cls_map = {line_data[0]: line_data[1] for line_data in data_info}for imgname in sorted(os.listdir(imgs_path)):path = os.path.join(imgs_path, imgname)item = (path, class_to_idx[cls_map[imgname]])images.append(item)return imagesif __name__ == '__main__':train_dataset = TinyImageNet('./tiny-imagenet', split='train', download=False)test_dataset = TinyImageNet('./tiny-imagenet', split='val', download=False)

MiniImageNet

##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Yaoyao Liu
## NUS School of Computing
## Email: yaoyao.liu@nus.edu.sg
## Copyright (c) 2019
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import os
import random
import numpy as np
from tqdm import trange
import imageioclass MiniImageNetDataLoader(object):def __init__(self, shot_num, way_num, episode_test_sample_num, shuffle_images = False):self.shot_num = shot_numself.way_num = way_numself.episode_test_sample_num = episode_test_sample_numself.num_samples_per_class = episode_test_sample_num + shot_numself.shuffle_images = shuffle_imagesmetatrain_folder = './processed_images/train'metaval_folder = './processed_images/val'metatest_folder = './processed_images/test'npy_dir = './episode_filename_list/'if not os.path.exists(npy_dir):os.mkdir(npy_dir)self.npy_base_dir = npy_dir + str(self.shot_num) + 'shot_' + str(self.way_num) + 'way_' + str(episode_test_sample_num) + 'shuffled_' + str(self.shuffle_images) + '/'if not os.path.exists(self.npy_base_dir):os.mkdir(self.npy_base_dir)self.metatrain_folders = [os.path.join(metatrain_folder, label) \for label in os.listdir(metatrain_folder) \if os.path.isdir(os.path.join(metatrain_folder, label)) \]self.metaval_folders = [os.path.join(metaval_folder, label) \for label in os.listdir(metaval_folder) \if os.path.isdir(os.path.join(metaval_folder, label)) \]self.metatest_folders = [os.path.join(metatest_folder, label) \for label in os.listdir(metatest_folder) \if os.path.isdir(os.path.join(metatest_folder, label)) \]def get_images(self, paths, labels, nb_samples=None, shuffle=True):if nb_samples is not None:sampler = lambda x: random.sample(x, nb_samples)else:sampler = lambda x: ximages = [(i, os.path.join(path, image)) \for i, path in zip(labels, paths) \for image in sampler(os.listdir(path))]if shuffle:random.shuffle(images)return imagesdef generate_data_list(self, phase='train', episode_num=None):if phase=='train':folders = self.metatrain_foldersif episode_num is None:episode_num = 20000if not os.path.exists(self.npy_base_dir+'/train_filenames.npy'):print('Generating train filenames')all_filenames = []for _ in trange(episode_num):sampled_character_folders = random.sample(folders, self.way_num)random.shuffle(sampled_character_folders)labels_and_images = self.get_images(sampled_character_folders, range(self.way_num), nb_samples=self.num_samples_per_class, shuffle=self.shuffle_images)labels = [li[0] for li in labels_and_images]filenames = [li[1] for li in labels_and_images]all_filenames.extend(filenames)np.save(self.npy_base_dir+'/train_labels.npy', labels)np.save(self.npy_base_dir+'/train_filenames.npy', all_filenames)print('Train filename and label lists are saved')elif phase=='val':folders = self.metaval_foldersif episode_num is None:episode_num = 600if not os.path.exists(self.npy_base_dir+'/val_filenames.npy'):print('Generating val filenames')all_filenames = []for _ in trange(episode_num):sampled_character_folders = random.sample(folders, self.way_num)random.shuffle(sampled_character_folders)labels_and_images = self.get_images(sampled_character_folders, range(self.way_num), nb_samples=self.num_samples_per_class, shuffle=self.shuffle_images)labels = [li[0] for li in labels_and_images]filenames = [li[1] for li in labels_and_images]all_filenames.extend(filenames)np.save(self.npy_base_dir+'/val_labels.npy', labels)np.save(self.npy_base_dir+'/val_filenames.npy', all_filenames)print('Val filename and label lists are saved')elif phase=='test':folders = self.metatest_foldersif episode_num is None:episode_num = 600if not os.path.exists(self.npy_base_dir+'/test_filenames.npy'):print('Generating test filenames')all_filenames = []for _ in trange(episode_num):sampled_character_folders = random.sample(folders, 
self.way_num)random.shuffle(sampled_character_folders)labels_and_images = self.get_images(sampled_character_folders, range(self.way_num), nb_samples=self.num_samples_per_class, shuffle=self.shuffle_images)labels = [li[0] for li in labels_and_images]filenames = [li[1] for li in labels_and_images]all_filenames.extend(filenames)np.save(self.npy_base_dir+'/test_labels.npy', labels)np.save(self.npy_base_dir+'/test_filenames.npy', all_filenames)print('Test filename and label lists are saved')else:print('Please select vaild phase')def load_list(self, phase='train'):if phase=='train':self.train_filenames = np.load(self.npy_base_dir + 'train_filenames.npy').tolist()self.train_labels = np.load(self.npy_base_dir + 'train_labels.npy').tolist()elif phase=='val':self.val_filenames = np.load(self.npy_base_dir + 'val_filenames.npy').tolist()self.val_labels = np.load(self.npy_base_dir + 'val_labels.npy').tolist()elif phase=='test':self.test_filenames = np.load(self.npy_base_dir + 'test_filenames.npy').tolist()self.test_labels = np.load(self.npy_base_dir + 'test_labels.npy').tolist()elif phase=='all':self.train_filenames = np.load(self.npy_base_dir + 'train_filenames.npy').tolist()self.train_labels = np.load(self.npy_base_dir + 'train_labels.npy').tolist()self.val_filenames = np.load(self.npy_base_dir + 'val_filenames.npy').tolist()self.val_labels = np.load(self.npy_base_dir + 'val_labels.npy').tolist()self.test_filenames = np.load(self.npy_base_dir + 'test_filenames.npy').tolist()self.test_labels = np.load(self.npy_base_dir + 'test_labels.npy').tolist()else:print('Please select vaild phase')def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True):new_path_list = []new_label_list = []for k in range(batch_sample_num):class_idxs = list(range(0, self.way_num))random.shuffle(class_idxs)for class_idx in class_idxs:true_idx = class_idx*batch_sample_num + knew_path_list.append(input_filename_list[true_idx])new_label_list.append(input_label_list[true_idx])img_list = []for filepath in new_path_list:this_img = imageio.imread(filepath)this_img = this_img / 255.0img_list.append(this_img)if reshape_with_one:img_array = np.array(img_list)label_array = self.one_hot(np.array(new_label_list)).reshape([1, self.way_num*batch_sample_num, -1])else:img_array = np.array(img_list)label_array = self.one_hot(np.array(new_label_list)).reshape([self.way_num*batch_sample_num, -1])return img_array, label_arraydef one_hot(self, inp):n_class = inp.max() + 1n_sample = inp.shape[0]out = np.zeros((n_sample, n_class))for idx in range(n_sample):out[idx, inp[idx]] = 1return outdef get_batch(self, phase='train', idx=0):if phase=='train':all_filenames = self.train_filenameslabels = self.train_labels elif phase=='val':all_filenames = self.val_filenameslabels = self.val_labels elif phase=='test':all_filenames = self.test_filenameslabels = self.test_labelselse:print('Please select vaild phase')one_episode_sample_num = self.num_samples_per_class*self.way_numthis_task_filenames = all_filenames[idx*one_episode_sample_num:(idx+1)*one_episode_sample_num]epitr_sample_num = self.shot_numepite_sample_num = self.episode_test_sample_numthis_task_tr_filenames = []this_task_tr_labels = []this_task_te_filenames = []this_task_te_labels = []for class_k in range(self.way_num):this_class_filenames = this_task_filenames[class_k*self.num_samples_per_class:(class_k+1)*self.num_samples_per_class]this_class_label = labels[class_k*self.num_samples_per_class:(class_k+1)*self.num_samples_per_class]this_task_tr_filenames += 
this_class_filenames[0:epitr_sample_num]this_task_tr_labels += this_class_label[0:epitr_sample_num]this_task_te_filenames += this_class_filenames[epitr_sample_num:]this_task_te_labels += this_class_label[epitr_sample_num:]this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False)this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False)return this_inputa, this_labela, this_inputb, this_labelb

CINIC10

Reference: CINIC10

import torchvision
import torchvision.transforms as transforms
import torch

cinic_directory = '/path/to/cinic/directory'
cinic_mean = [0.47889522, 0.47227842, 0.43047404]
cinic_std = [0.24205776, 0.23828046, 0.25874835]
cinic_train = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(cinic_directory + '/train',
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=cinic_mean, std=cinic_std)
                                     ])),
    batch_size=128, shuffle=True)
