前面一篇写创建数据集的博文--- Pytorch创建自己的数据集1 是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签本篇介绍输入的标签label亦为图像的数据集,并包含一些常用的处理手段。比如做图像语义分割时就会用到这种数据输入方式。

1、数据集简介

以VOC2012数据集为例,图像是RGB3通道的,label是1通道的,(其实label原来是几通道的无所谓,只要读取的时候转化成灰度图就行)。

训练数据:

语义label:

这里我们看到label图片都是黑色的,只有白色的轮廓而已

其实是因为label图片里的像素值取值范围是0 ~ 20,即像素点可能的类别共有21类(对此数据集来说),详情如下:

所以对于灰度值0---20来说,我们肉眼看上去就确实都是黑色的,因为灰度值太低了,而白色的轮廓的灰度值是255!

但是这些边界在计算损失值的时候是不作为有效值的,也就是对于灰度值=255的点是忽略的。

如果想看的话,可以用一些色彩变换,对0--20这每一个数字对应一个色彩,就能看出来了,示例如下

这不是重点,只是给大家看一下方便理解而已

2、文本信息

同样有一个文本来指导我对数据的读取,我的信息如下

这其实就是一个记载了图像ID的文本文档,连后缀都没有,但我们依然可以根据这个去数据集中读取相应的image和label

3、代码示例

这个代码是我自己在利用deeplabV2 跑semantic segmentation 任务时写的一个,也许写的并不优美,但反正是可以用的,

可以做个抛砖引玉的目的,对于才入门的朋友,理解这个思路就可,不必照搬我的代码风格……

import os
import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
import cv2
from PIL import Image
import torchvision.transforms as transforms
from torch.utils import dataclass VOCDataSet(data.Dataset):def __init__(self, root, list_path,  crop_size=(321, 321), mean=(104.008, 116.669, 122.675), mirror=True, scale=True, ignore_label=255):super(VOCDataSet,self).__init__()self.root = rootself.list_path = list_pathself.crop_h, self.crop_w = crop_sizeself.ignore_label = ignore_labelself.mean = np.asarray(mean, np.float32)self.is_mirror = mirrorself.is_scale = scaleself.img_ids = [i_id.strip() for i_id in open(list_path)]self.files = []for name in self.img_ids:img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)label_file = os.path.join(self.root, "SegmentationClassAug/%s.png" % name)self.files.append({"img": img_file,"label": label_file,"name": name})def __len__(self):return len(self.files)def __getitem__(self, index):datafiles = self.files[index]'''load the datas'''name = datafiles["name"]image = Image.open(datafiles["img"]).convert('RGB')label = Image.open(datafiles["label"]).convert('L')size_origin = image.size # W * H'''random scale the images and labels'''if self.is_scale: #如果我在定义dataset时选择了scale=True,就执行本语句对尺度进行随机变换ratio = 0.5 + random.randint(0, 11) // 10.0 #0.5~1.5out_h, out_w = int(size_origin[1]*ratio), int(size_origin[0]*ratio)# (H,W)for Resizeimage = transforms.Resize((out_h, out_w), Image.LANCZOS)(image)label = transforms.Resize((out_h, out_w), Image.NEAREST)(label)'''pad the inputs if their size is smaller than the crop_size'''pad_w = max(self.crop_w - out_w, 0)pad_h = max(self.crop_h - out_h, 0)img_pad = transforms.Pad( padding=(0,0,pad_w,pad_h), fill=0, padding_mode='constant')(image)label_pad = transforms.Pad( padding=(0,0,pad_w,pad_h), fill=self.ignore_label, padding_mode='constant')(label)out_size = img_pad.size'''random crop the inputs'''if (self.crop_h != 0 or self.crop_w != 0):#select a random start-point for croping operationh_off = random.randint(0, out_size[1] - self.crop_h)w_off = random.randint(0, out_size[0] - self.crop_w)#crop the image and the labelimage = img_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))label = label_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))'''mirror operation'''if self.is_mirror:if np.random.random() < 0.5:#0:FLIP_LEFT_RIGHT, 1:FLIP_TOP_BOTTOM, 2:ROTATE_90, 3:ROTATE_180, 4:or ROTATE_270.image = image.transpose(0)label = label.transpose(0)'''convert PIL Image to numpy array'''I = np.asarray(image,np.float32) - self.meanI = I.transpose((2,0,1))#transpose the  H*W*C to C*H*WL = np.asarray(np.array(label), np.int64)#print(I.shape,L.shape)return I.copy(), L.copy(), np.array(size_origin), name#这是一个测试函数,也即我的代码写好后,如果直接python运行当前py文件,就会执行以下代码的内容,以检测我上面的代码是否有问题,这其实就是方便我们调试,而不是每次都去run整个网络再看哪里报错
if __name__ == '__main__':DATA_DIRECTORY = '/home/teeyo/STA/Data/voc_aug/'DATA_LIST_PATH = '../dataset/list/val.txt'Batch_size = 4MEAN = (104.008, 116.669, 122.675)dst = VOCDataSet(DATA_DIRECTORY,DATA_LIST_PATH, mean=(0,0,0))# just for test,  so the mean is (0,0,0) to show the original images.# But when we are training a model, the mean should have another valuetrainloader = data.DataLoader(dst, batch_size = Batch_size)plt.ion()for i, data in enumerate(trainloader):imgs, labels,_,_= dataif i%1 == 0:img = torchvision.utils.make_grid(imgs).numpy()img = img.astype(np.uint8) #change the dtype from float32 to uint8, because the plt.imshow() need the uint8img = np.transpose(img, (1, 2, 0))#transpose the Channels*H*W to  H*W*Channels#img = img[:, :, ::-1]plt.imshow(img)plt.show()plt.pause(0.5)#input()

我个人觉得我应该注释的地方都有相应的注释,虽然有点长, 因为实现了crop和翻转以及scale等功能,但是大家可以下去慢慢揣摩,理解其中的主要思路,与我前一篇的博文Pytorch创建自己的数据集1做对比,那篇博文相当于是提供了最基本的骨架,而这篇就在骨架上长肉生发而已,有疑问的欢迎评论探讨~~

Pytorch打怪路(三)Pytorch创建自己的数据集2相关推荐

  1. Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练

    pytorch进行CIFAR-10分类(4)训练 我的系列博文: Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理 Pytorch打怪路(一)pyt ...

  2. [深度学习] 分布式Pytorch介绍(三)

    [深度学习] 分布式模式介绍(一) [深度学习] 分布式Tensorflow介绍(二) [深度学习] 分布式Pytorch介绍(三) [深度学习] 分布式Horovod介绍(四)  一  Pytorc ...

  3. PyTorch框架学习三——张量操作

    PyTorch框架学习三--张量操作 一.拼接 1.torch.cat() 2.torch.stack() 二.切分 1.torch.chunk() 2.torch.split() 三.索引 1.to ...

  4. data后缀文件解码_小白学PyTorch | 17 TFrec文件的创建与读取

    [机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 16 TF2读取图片的方法 小白学PyTorch | 15 TF2实现一个简单的服装分类任务 小白学PyTorch | 14 ...

  5. Pytorch机器学习(三)——VOC数据集转换为YOLO数据集

    Pytorch机器学习(三)--VOC数据集转换为YOLO数据集 目录 Pytorch机器学习(三)--VOC数据集转换为YOLO数据集 前言 一.yolo格式 二.代码 总结 前言 本文为利用pyt ...

  6. 深入浅出pytorch笔记——第三章,第四章

    文章目录 第三章 深度学习组成部分 配置环境 数据读取与加载 模型设计 损失函数 优化器 训练与评估 第四章(FMN分类实战) 4.1导入包 4.2配置训练环境和超参数 4.3数据读入与加载(Data ...

  7. 深度学习三(PyTorch物体检测实战)

    深度学习三(PyTorch物体检测实战) 文章目录 深度学习三(PyTorch物体检测实战) 1.网络骨架:Backbone 1.1.神经网络基本组成 1.1.1.卷积层 1.1.2.激活函数层 1. ...

  8. 裸机搭建深度学习服务器,ubuntu ssh服务器,pytorch, tensorflow, paddle三种框架安装。以及各种避雷。

    努力是为了不平庸~ 深度学习有些时候是枯燥的,AI的成长总是被说不可能,但是越是被说不可能,我们就越要创造无限可能!加油每一个AI人,我们会用行动来证明 我们的! rufurufus 需求知识点 提示 ...

  9. pytorch创建自己的数据集(分类任务)

    pytorch创建自己的数据集(分类任务) 转载于:https://www.cnblogs.com/cititude/p/11615158.html

最新文章

  1. Java统计1到300_java程序员的从0到1:统计某字符串在某文件中出现的次数(面试题)...
  2. 通过 UDP 发送数据的简单范例
  3. C++ string源码
  4. Summer Training day4 欧拉降幂
  5. c语言函数man,Linux下C语言编程有困难找man
  6. P2313 [HNOI2005]汤姆的游戏
  7. linux命令文本处理(一)grep
  8. 系统分析与设计课程项目总结
  9. Mirth Connect的简单使用
  10. 中小板企业上市要走哪些流程
  11. java excel换行_java poi出excel换行问题
  12. 这2个PDF转Word免费不限页数工具很多人没用过
  13. 商业数据分析-战略分析读后感
  14. Okra框架(一) 简介
  15. 客户之前使用的其他财务软件,现在需要把其他软件的财务凭证导入到用友T3软件中使用,如何能快速实现。
  16. 成为顶流平台后 新氧阳谋峥嵘显露
  17. 大哥要我实现天干地支的组合
  18. html5 video标签实现手机端视频播放全屏显示
  19. mysql数据库性能优化—my.cnf详解
  20. GB28181国标2016版本协议文档(正式版)解读(三)

热门文章

  1. 音频处理中的瞬态概念 Transient phenomena of Audio Signal Proccess
  2. 全国计算机等级考试python(刷题软件)
  3. C语言输入end时结束程序,c语言输入eof结束怎么写
  4. python 结束 serve_forever_如何使用Python脚本启动和停止包含“http.server.serveforever”的Python脚本...
  5. 大咖云集!9月18日 Imagination Technologies 受邀参加2020中关村论坛
  6. Java别踩白块外挂(附源码)
  7. 批量查询谷歌PR权重的方法有哪些?是什么影响着谷歌PR值?
  8. 企业微信加载html模板,企业微信公众号页面模板使用的方法是什么?
  9. 家里公司自动ip切换,批处理
  10. signature=cf2a4ebb3fc32cddedd659609006f5f5,Таджикистан. Трудныйпутьразвития...