目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上9.9节
此节功能为:语义分割和数据集
由于此节相对复杂,代码注释量较多

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/
# 锚框
# 注释:黄文俊
# E-mail:hurri_cane@qq.comfrom matplotlib import pyplot as plt
import time
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
from PIL import Image
from tqdm import tqdmimport sys
sys.path.append("..")
import d2lzh_pytorch as d2l# 本函数已保存在d2lzh_pytorch中方便以后使用
def read_voc_images(root="F:/PyCharm/Learning_pytorch/data/VOCdevkit/VOC2012",is_train=True, max_num=None):txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')with open(txt_fname, 'r') as f:images = f.read().split()if max_num is not None:images = images[:min(max_num, len(images))]features, labels = [None] * len(images), [None] * len(images)for i, fname in tqdm(enumerate(images)):# tqdm主要作用是用于显示进度features[i] = Image.open('%s/JPEGImages/%s.jpg' % (root, fname)).convert("RGB")labels[i] = Image.open('%s/SegmentationClass/%s.png' % (root, fname)).convert("RGB")return features, labels # PIL imagevoc_dir = "F:/PyCharm/Learning_pytorch/data/VOCdevkit/VOC2012"
train_features, train_labels = read_voc_images(voc_dir, max_num=100)n = 5
imgs = train_features[0:n] + train_labels[0:n]
d2l.show_images(imgs, 2, n)
plt.show()# 列出标签中每个RGB颜色的值及其标注的类别
# 本函数已保存在d2lzh_pytorch中方便以后使用
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],[64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],[64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],[0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],[0, 64, 128]]
# 本函数已保存在d2lzh_pytorch中方便以后使用
VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat','bottle', 'bus', 'car', 'cat', 'chair', 'cow','diningtable', 'dog', 'horse', 'motorbike', 'person','potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)
for i, colormap in enumerate(VOC_COLORMAP):colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i# 本函数已保存在d2lzh_pytorch中方便以后使用
def voc_label_indices(colormap, colormap2label):"""convert colormap (PIL image) to colormap2label (uint8 tensor)."""colormap = np.array(colormap.convert("RGB")).astype('int32')idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256+ colormap[:, :, 2])'''这里需要注意的是,label图片通过查看可以看到,其其实是有边框的,以第一张飞机图片举例整个图片颜色大致可分为:背景(黑色);飞机(红色);轮廓(白色)而通过colormap2label[idx]来索引的时候,之所以没返回轮廓信息,是因为轮廓在最开始的colormap2label 标签的构成中就没有设定轮廓的编号,默认为0,同背景一致'''return colormap2label[idx]# 以第一张飞机图片为例
y = voc_label_indices(train_labels[0], colormap2label)
print(y[105:115, 130:140], VOC_CLASSES[1])
# 以第四张鸟图片为例
bir = voc_label_indices(train_labels[3],colormap2label)
print(bir[360:380,30:50])
print("*" * 50)# 预处理数据(随机裁剪)
# 本函数已保存在d2lzh_pytorch中方便以后使用
def voc_rand_crop(feature, label, height, width):"""Random crop feature (PIL image) and label (PIL image)."""i, j, h, w = torchvision.transforms.RandomCrop.get_params(feature, output_size=(height, width))feature = torchvision.transforms.functional.crop(feature, i, j, h, w)label = torchvision.transforms.functional.crop(label, i, j, h, w)return feature, labeln = 5
# 裁剪次数设定为5
imgs = []
for _ in range(n):imgs += voc_rand_crop(train_features[0], train_labels[0], 200, 300)
d2l.show_images(imgs[::2] + imgs[1::2], 2, n)
plt.show()# 自定义语义分割数据集类
# 本函数已保存在d2lzh_pytorch中方便以后使用
class VOCSegDataset(torch.utils.data.Dataset):def __init__(self, is_train, crop_size, voc_dir, colormap2label, max_num=None):"""crop_size: (h, w)"""self.rgb_mean = np.array([0.485, 0.456, 0.406])self.rgb_std = np.array([0.229, 0.224, 0.225])self.tsf = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=self.rgb_mean,std=self.rgb_std)])# 传入图片的标准化self.crop_size = crop_size  # (h, w)features, labels = read_voc_images(root=voc_dir,is_train=is_train,max_num=max_num)self.features = self.filter(features)   # PIL imageself.labels = self.filter(labels)       # PIL imageself.colormap2label = colormap2labelprint('read ' + str(len(self.features)) + ' valid examples')def filter(self, imgs):return [img for img in imgs if (img.size[1] >= self.crop_size[0] andimg.size[0] >= self.crop_size[1])]def __getitem__(self, idx):# 从内存中读取特征图和标签图feature, label = voc_rand_crop(self.features[idx], self.labels[idx],*self.crop_size)return (self.tsf(feature), # float32 tensorvoc_label_indices(label, self.colormap2label)) # uint8 tensordef __len__(self):return len(self.features)crop_size = (320, 480)
max_num = 100
voc_train = VOCSegDataset(True, crop_size, voc_dir, colormap2label, max_num)
voc_test = VOCSegDataset(False, crop_size, voc_dir, colormap2label, max_num)# 设批量大小为64,分别定义训练集和测试集的迭代器。
batch_size = 64
num_workers = 0 if sys.platform.startswith('win32') else 4
train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True,drop_last=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(voc_test, batch_size, drop_last=True,num_workers=num_workers)for X, Y in train_iter:# X为原始数据,Y为标签print(X.dtype, X.shape)print(y.dtype, Y.shape)breakprint("*" * 50)

《动手学深度学习》(PyTorch版)代码注释 - 50 【Semantic_segmentation】相关推荐

  1. 伯禹公益AI《动手学深度学习PyTorch版》Task 04 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 04 学习笔记 Task 04:机器翻译及相关技术:注意力机制与Seq2seq模型:Transformer 微信昵称:WarmIce ...

  2. 伯禹公益AI《动手学深度学习PyTorch版》Task 07 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 07 学习笔记 Task 07:优化算法进阶:word2vec:词嵌入进阶 微信昵称:WarmIce 优化算法进阶 emmmm,讲实 ...

  3. 伯禹公益AI《动手学深度学习PyTorch版》Task 03 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 03 学习笔记 Task 03:过拟合.欠拟合及其解决方案:梯度消失.梯度爆炸:循环神经网络进阶 微信昵称:WarmIce 过拟合. ...

  4. 【动手学深度学习PyTorch版】6 权重衰退

    上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...

  5. 【动手学深度学习PyTorch版】12 卷积层

    上一篇移步[动手学深度学习PyTorch版]11 使用GPU_水w的博客-CSDN博客 目录 一.卷积层 1.1从全连接到卷积 ◼ 回顾单隐藏层MLP ◼ Waldo在哪里? ◼ 原则1-平移不变性 ...

  6. 【动手学深度学习PyTorch版】27 数据增强

    上一篇请移步[动手学深度学习PyTorch版]23 深度学习硬件CPU 和 GPU_水w的博客-CSDN博客 目录 一.数据增强 1.1 数据增强(主要是关于图像增强) ◼ CES上的真实的故事 ◼ ...

  7. 【动手学深度学习PyTorch版】13 卷积层的填充和步幅

    上一篇移步[动手学深度学习PyTorch版]12 卷积层_水w的博客-CSDN博客 目录 一.卷积层的填充和步幅 1.1 填充 1.2 步幅 1.3 总结 二.代码实现填充和步幅(使用框架) 一.卷积 ...

  8. 【动手学深度学习PyTorch版】23 深度学习硬件CPU 和 GPU

    上一篇请移步[动手学深度学习PyTorch版]22续 ResNet为什么能训练出1000层的模型_水w的博客-CSDN博客 目录 一.深度学习硬件CPU 和 GPU 1.1 深度学习硬件 ◼ 计算机构 ...

  9. 【动手学深度学习PyTorch版】15 池化层

    上一篇请移步[动手学深度学习PyTorch版]14 卷积层里的多输入多输出通道_水w的博客-CSDN博客 目录 一.池化层 1.1 池化层 ◼池化层原因 ◼ 二维最大池化 1.2 填充.步幅与多个通道 ...

  10. 伯禹公益AI《动手学深度学习PyTorch版》Task 05 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 05 学习笔记 Task 05:卷积神经网络基础:LeNet:卷积神经网络进阶 微信昵称:WarmIce 昨天打了一天的<大革 ...

最新文章

  1. WebForm中DataGrid的20篇经典文章
  2. NOSQL系列-memcached安装管理与repcached高可用性
  3. 心态很容易受别人影响_为什么说缠论中的走势中枢容易影响短线买卖交易心态?...
  4. 栈在前端中的应用,顺便再了解下深拷贝和浅拷贝!
  5. C#中对注册表的操作指南
  6. 怎样阻止电脑开机自动安装大量垃圾软件
  7. App开发定制的种类:企业需要开发哪种App?
  8. 电话聊天狂人(25 分)(散列函数)
  9. jdbctemplate mysql blob_JdbcTemplate 操作Oracle Blob
  10. C语言学习-翁凯(目录总章)
  11. UVALive 4490 Help Bubu
  12. 【论文笔记】Deep Survival: A Deep Cox Proportional Hazards Network
  13. zoho邮箱收费和免费区别_集成MS Office和您的Zoho在线帐户
  14. python def -> : ->什么意思
  15. 情侣空间显示服务器失败,情侣空间error是什么意思
  16. 截取计算机全屏画面的方法有,电脑怎么截图全屏 详细方法介绍
  17. iwlwifi(AC9260)移植总结
  18. 关于易语言 无法加入dll命令 没有dll 的解决方式
  19. 如何把win10的计算机调至桌面,win10如何显示我的电脑在桌面?小编教你显示的方法...
  20. Javascript正则匹配数字,中英文,中横线,下划线,utf-8中文

热门文章

  1. Android 更换皮肤
  2. 关于python语言、下列说法不正确的是-模拟试卷C【单项选择题】
  3. StarUML使用心得
  4. JavaEE | 语言基础部分、对象与类
  5. 多少人,一边疯狂跳槽,一边疯狂后悔
  6. 省心!2021精选APP macOS装机必备清单来了
  7. linux定时脚本编写,如何实现Linux定时任务
  8. upc 9367 雷涛的小猫
  9. JavaScript的prototype是什么?
  10. 驱动开发之五 --- TDI之一(飞雪楚狂人)