Dataset自定义VOC2012数据集代码

from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etreeclass VOC2012DataSet(Dataset):"""读取解析PASCAL VOC2012数据集"""def __init__(self, voc_root, transforms, txt_name: str = "train.txt"):self.root = os.path.join(voc_root, "VOCdevkit", "VOC2012")self.img_root = os.path.join(self.root, "JPEGImages")self.annotations_root = os.path.join(self.root, "Annotations")# read train.txt or val.txt filetxt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)assert os.path.exists(txt_path), "not found {} file.".format(txt_name)with open(txt_path) as read:self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")for line in read.readlines()]# check fileassert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)for xml_path in self.xml_list:assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)# read class_indictjson_file = './pascal_voc_classes.json'assert os.path.exists(json_file), "{} file not exist.".format(json_file)json_file = open(json_file, 'r')self.class_dict = json.load(json_file)self.transforms = transformsdef __len__(self):return len(self.xml_list)def __getitem__(self, idx):# read xmlxml_path = self.xml_list[idx]with open(xml_path) as fid:xml_str = fid.read()xml = etree.fromstring(xml_str.encode('utf-8'))data = self.parse_xml_to_dict(xml)["annotation"]img_path = os.path.join(self.img_root, data["filename"])image = Image.open(img_path)if image.format != "JPEG":raise ValueError("Image '{}' format not JPEG".format(img_path))boxes = []labels = []iscrowd = []assert "object" in data, "{} lack of object information.".format(xml_path)for obj in data["object"]:xmin = float(obj["bndbox"]["xmin"])xmax = float(obj["bndbox"]["xmax"])ymin = float(obj["bndbox"]["ymin"])ymax = float(obj["bndbox"]["ymax"])# 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nanif xmax <= xmin or ymax <= ymin:print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))continueboxes.append([xmin, ymin, xmax, ymax])labels.append(self.class_dict[obj["name"]])if "difficult" in obj:iscrowd.append(int(obj["difficult"]))else:iscrowd.append(0)# convert everything into a torch.Tensorboxes = torch.as_tensor(boxes, dtype=torch.float32)labels = torch.as_tensor(labels, dtype=torch.int64)iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)image_id = torch.tensor([idx])area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])target = {}target["boxes"] = boxestarget["labels"] = labelstarget["image_id"] = image_idtarget["area"] = areatarget["iscrowd"] = iscrowdif self.transforms is not None:image, target = self.transforms(image, target)return image, targetdef get_height_and_width(self, idx):# read xmlxml_path = self.xml_list[idx]with open(xml_path) as fid:xml_str = fid.read()xml = etree.fromstring(xml_str)data = self.parse_xml_to_dict(xml)["annotation"]data_height = int(data["size"]["height"])data_width = int(data["size"]["width"])return data_height, data_widthdef parse_xml_to_dict(self, xml):"""将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dictArgs:xml: xml tree obtained by parsing XML file contents using lxml.etreeReturns:Python dictionary holding XML contents."""if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息return {xml.tag: xml.text}result = {}for child in xml:child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息if child.tag != 'object':result[child.tag] = child_result[child.tag]else:if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里result[child.tag] = []result[child.tag].append(child_result[child.tag])return {xml.tag: result}def coco_index(self, idx):"""该方法是专门为pycocotools统计标签信息准备,不对图像和标签作任何处理由于不用去读取图片,可大幅缩减统计时间Args:idx: 输入需要获取图像的索引"""# read xmlxml_path = self.xml_list[idx]with open(xml_path) as fid:xml_str = fid.read()xml = etree.fromstring(xml_str)data = self.parse_xml_to_dict(xml)["annotation"]data_height = int(data["size"]["height"])data_width = int(data["size"]["width"])# img_path = os.path.join(self.img_root, data["filename"])# image = Image.open(img_path)# if image.format != "JPEG":#     raise ValueError("Image format not JPEG")boxes = []labels = []iscrowd = []for obj in data["object"]:xmin = float(obj["bndbox"]["xmin"])xmax = float(obj["bndbox"]["xmax"])ymin = float(obj["bndbox"]["ymin"])ymax = float(obj["bndbox"]["ymax"])boxes.append([xmin, ymin, xmax, ymax])labels.append(self.class_dict[obj["name"]])iscrowd.append(int(obj["difficult"]))# convert everything into a torch.Tensorboxes = torch.as_tensor(boxes, dtype=torch.float32)labels = torch.as_tensor(labels, dtype=torch.int64)iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)image_id = torch.tensor([idx])area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])target = {}target["boxes"] = boxestarget["labels"] = labelstarget["image_id"] = image_idtarget["area"] = areatarget["iscrowd"] = iscrowdreturn (data_height, data_width), target@staticmethoddef collate_fn(batch):return tuple(zip(*batch))import transforms
from draw_box_utils import draw_box
from PIL import Image
import json
import matplotlib.pyplot as plt
import torchvision.transforms as ts
import random# read class_indict
category_index = {}
try:json_file = open('pascal_voc_classes.json', 'r')class_dict = json.load(json_file)category_index = {v: k for k, v in class_dict.items()}
except Exception as e:print(e)exit(-1)data_transform = {"train": transforms.Compose([transforms.ToTensor(),transforms.RandomHorizontalFlip(0.5)]),"val": transforms.Compose([transforms.ToTensor()])
}voc_path = 'D:\\Pytorch\\zsffuture-yolov5-master'
# load train data set
train_data_set = VOC2012DataSet(voc_path, data_transform["train"], "train.txt")
print(len(train_data_set))
for index in random.sample(range(0, len(train_data_set)), k=5):img, target = train_data_set[index]img = ts.ToPILImage()(img)draw_box(img,target["boxes"].numpy(),target["labels"].numpy(),[1 for i in range(len(target["labels"].numpy()))],category_index,thresh=0.5,line_thickness=5)plt.imshow(img)plt.show()

pytorch经常使用的代码(持续更新)相关推荐

  1. Matlab常用代码---持续更新

    Matlab中的一些常用代码---持续更新 1. 获取当前的工作目录路径:添加文件夹到工作路径 2. 获取某个.m文件的绝对路径 3. 使用随机颜色进行可视化 1. 获取当前的工作目录路径:添加文件夹 ...

  2. C语言图形函数代码~持续更新中

    下面总结的是一些C语言图形函数代码~持续更新中 画三类圆 #include#include#include#include#includeint main(void) { initgraph(640, ...

  3. Android开发人员不得不收集的代码(持续更新中)(http://www.jianshu.com/p/72494773aace,原链接)

    Android开发人员不得不收集的代码(持续更新中) Blankj 关注 2016.07.31 04:22* 字数 370 阅读 102644评论 479喜欢 3033赞赏 14 utilcode D ...

  4. 课后习题代码持续更新。。。。。。。。。。。。。

    持续在更新! 转载于:https://www.cnblogs.com/PerZhu/p/10867519.html

  5. 工业视觉需要时可抄的代码---持续更新

    1.批量访问图片,等待键盘 for (int i = 2; i <= 23;){if (KEY_DOWN('S')){std::string path = "";char t ...

  6. 知道创宇爬虫题--代码持续更新中

    网上流传着知道创宇的一道爬虫题,虽然一直写着一些实用的爬虫,但真正写出这个一个规范要求的"工具",还是学到了不少东西.先看下题目: 使用python编写一个网站爬虫程序,支持参数如 ...

  7. 阿里开发10年大牛:Android开发人员不得不收集的代码(持续更新中)

    前言 1.软件吃掉世界,而机器学习正吃掉软件 在数据爆炸的时代,如何创建「智能系统」成为焦点.这些应用程序内所体现的智能技术,并非是将实用指令添加到代码中,而是可以让软件自己去识别真实世界中发生的事件 ...

  8. java程序员一天多少行有效代码,持续更新~

    Java程序员应该知道的20个有用的库经验丰富的优秀Java开发人员的一个特点是对API(包括JDK和第三方库)有广泛的了解.今天分享一些Java开发人员应该熟悉的最有用.最基本 程序员经常会因为不编 ...

  9. PyTorch:tensor、torch.nn、autograd、loss等神经网络学习手册(持续更新)

    PyTorch1:tensor2.torch.nn.autograd.loss等神经网络学习手册(持续更新) 链接:画图.读写图片 文章目录 一.tensor 二.完整训练过程:数据.模型.可学习参数 ...

  10. 1个人70万行代码,20年持续更新,这款游戏号称开发到死,永不停更

    梦晨 博雯 发自 凹非寺 量子位 报道 | 公众号 QbitAI 这是一款「开发到死」,「永不停更」的游戏. 兄弟两人,一人开发,一人剧情,共同维持了这款游戏近20年. 现在的玩家刚刚打开它,往往会发 ...

最新文章

  1. 需求分析的过程是什么?_7大需求分析方法与5大分析过程
  2. 基于visual Studio2013解决面试题之0802数字最多元素
  3. Ubuntu 20 04 提示“检测到系统程序出现问题”
  4. mysql 360 atlas_360 Atlas中间件安装及使用
  5. Python 语言 Hello world
  6. java 复制一个对象_Java如何完全复制一个对象
  7. PCQQ - 发送自定义的XML卡片消息
  8. linux内核支持浮点吗,浅谈linux kernel对于浮点运算的支持
  9. DDD如何区分实体和值对象
  10. OS - 浅谈操作系统的内存管理
  11. 在Word和OneNote中插入数学公式
  12. 应届毕业生工作7个月小结
  13. 【个人笔记】嵌入式多种通讯方式总结
  14. Electron桌面悬浮球工具,支持拖动及配置,提供了待办事项、快速笔记等功能。
  15. ifix5.8连接ab plc的点-通过igs
  16. 苹果cms官网源码下载
  17. JavaScript学习笔记(四)(DOM)
  18. 关于ElasticSearch的十道经典面试题
  19. Python资源汇总
  20. 【目标检测】基于yolov6的钢筋检测和计数(附代码和数据集)

热门文章

  1. VDI SolutionTrack - 上海站:11月20日
  2. fgets()逐行读取文件内容
  3. 问题十二:怎么用ray tracing画第一张图
  4. 数据资产管理直面企业哪些痛点
  5. arrayvalue php,phparrayvalue
  6. BOA软件服务的移植和BOA服务的配置
  7. java int a=b指向_java里int a=3,给a赋值的时候,是给它3的地址,还是直接赋值二进制3?...
  8. 第一章数据分析与挖掘概述
  9. java 有序不重复_Java中自定义有序不重复的集合——SetList
  10. 航拍+AI︱极简的视频风格迁移体验