一.目标检测中的锚框

前提:

本节锚框代码实现,使用了很多Pytorch内置函数,如果有对应函数看不懂的地方,可以查看前面博客对相应函数的具体解释,如下链接所示:

  1. Pytorch中torch.meshgrid()函数解析
  2. Pytorch中torch.stack() 函数解析
  3. Pytorch中torch.cat()函数解析
  4. Pytorch中tensor.T(torch.T)解析
  5. Pytorch中torch.repeat()函数解析
  6. Pytorch中torch.repeat_interleave()函数解析
  7. Pytorch中torch.unsqueeze()和torch.squeeze()函数解析
  8. Pytorch中torch.sort()和torch.argsort()函数解析
  9. Pytorch中torch.max()函数解析
  10. Pytorch中[:,None]的用法解析
  11. Pytorch中torch.argmax()函数解析
  12. Pytorch中torch.nonzero()函数解析
  13. Pytorch中torch.full(),torch.ones()和torch.zeros()函数解析
  14. Pytorch中torch.numel(),torch.shape,torch.size()和torch.reshape()函数解析
  15. Pytorch中的广播机制(Broadcast)
  16. Pytorch中的广播机制(Broadcast)

1. 概念

目标检测算法通常会在输入图像中采样大量的区域,然后判断这些区域中是否包含我们感兴趣的目标,并调整区域边界从而更准确地预测目标的真实边界框(ground-truth bounding box)。 不同的模型使用的区域采样方法可能不同。 这里介绍其中的一种方法:以每个像素为中心,生成多个缩放比和宽高比(aspect ratio)不同的边界框,这些边界框被称为锚框(anchor box)。

2. 生成多个不同形状的锚框

2.1 假设输入图像的高度为hhh,宽度为www,以图像的每个像素为中心生成不同形状的锚框:缩放比为s∈(0,1]s\in (0, 1]s∈(0,1],宽高比为r>0r > 0r>0。那么锚框的宽度和高度分别是wsrws\sqrt{r}wsr​和hs/rhs/\sqrt{r}hs/r​注意当锚框中心位置给定时,已知宽和高的锚框是确定的。
2.2 要生成多个不同形状的锚框,需要设置许多缩放比(scale)取值s1,…,sns_1,\ldots, s_ns1​,…,sn​和许多宽高比(aspect ratio)取值r1,…,rmr_1,\ldots, r_mr1​,…,rm​。当使用这些比例和长宽比的所有组合以每个像素为中心时,输入图像将总共有whnmwhnmwhnm个锚框。尽管这些锚框可能会覆盖所有真实边界框,但计算复杂性很容易过高。在实践中,只考虑包含s1s_1s1​或r1r_1r1​的组合
((s1,r1),(s1,r2),…,(s1,rm),(s2,r1),(s3,r1),…,(sn,r1).(s_1, r_1), (s_1, r_2), \ldots, (s_1, r_m), (s_2, r_1), (s_3, r_1), \ldots, (s_n, r_1).(s1​,r1​),(s1​,r2​),…,(s1​,rm​),(s2​,r1​),(s3​,r1​),…,(sn​,r1​).)
也即是以同一像素为中心的锚框的数量是n+m−1n+m-1n+m−1。对于整个输入图像,我们将共生成wh(n+m−1)wh(n+m-1)wh(n+m−1)个锚框。
上述生成锚框的方法在下面的multibox_prior()函数中实现,指定输入图像、尺寸列表和宽高比列表,然后此函数将返回所有像素的锚框。

import torch
import d2l.torch
torch.set_printoptions(2) # 精简输出精度
"""生成以每个像素为中心具有不同形状的锚框"""
def multibox_prior(data,sizes,ratios):#图片的高和宽in_height,in_width=data.shape[-2:]device,num_sizes,num_ratios = data.device,len(sizes),len(ratios)tensor_sizes,tensor_ratios = torch.tensor(sizes,device=device),torch.tensor(ratios,device=device)#每个像素点pixel的锚框数boxes_per_pixel = (num_sizes+num_ratios-1)# 为了将锚点移动到像素的中心,需要设置偏移量。# 因为一个像素的的高为1且宽为1,我们选择偏移我们的中心0.5offset_w,offset_h = 0.5,0.5#把图片高和宽归一化(缩放)到1steps_h,steps_w = 1.0/in_height,1.0/in_width # 在y轴上缩放步长,在x轴上缩放步长'''生成锚框的所有中心点'''#生成0.5/in_height到(in_height+0.5)/in_height之间高度的刻度值,代表y轴center_h = (torch.arange(in_height,device=device)+offset_h) * steps_h#生成0.5/in_width到(in_width+0.5)/in_width之间宽度的刻度值,代表x轴center_w = (torch.arange(in_width,device=device)+offset_w) * steps_w#生成网格,shift_y中行相等,列不等,shift_x中行不等,列相等,行的个数为center_h的元素个数,列的个数为center_w的元素个数shift_y,shift_x = torch.meshgrid(center_h,center_w,indexing='ij')#shift_y,shift_x都reshape成一维,维数为:in_height*in_widthshift_y = shift_y.reshape(-1)shift_x = shift_x.reshape(-1)'''每个中心点都将有“boxes_per_pixel”个锚框,所以生成含所有锚框中心的网格,重复了“boxes_per_pixel”次'''# 取(shift_x,shift_y,shift_x,shift_y)相应元素组成一行,因此一行有四个元素,然后将这一行直接复制boxes_per_pixel次,表示一个像素点的坐标(4维,因为用于后面计算左上右下的坐标)复制boxes_per_pixel次(因为需要生成boxes_per_pixel)锚框,out_grid.size=(561x728x5,4)out_grid = torch.stack((shift_x,shift_y,shift_x,shift_y),dim=1).repeat_interleave(boxes_per_pixel,dim=0)'''生成“boxes_per_pixel”个高和宽,之后用于创建锚框的四角坐标(x_min,y_min,x_max,y_max)'''#生成锚框的w,有boxes_per_pixel个锚框,因此有boxes_per_pixel个锚框的宽anchors_w = torch.cat((tensor_sizes*torch.sqrt(tensor_ratios[0]),tensor_sizes[0]*torch.sqrt(tensor_ratios[1:])))*in_height/in_width#生成锚框的h,有boxes_per_pixel个锚框,因此有boxes_per_pixel个锚框的高anchors_h = torch.cat((tensor_sizes/torch.sqrt(tensor_ratios[0]),tensor_sizes[0]/torch.sqrt(tensor_ratios[1:])))#每一行代表一个像素点的锚框的高和宽,因为一个像素点有boxes_per_pixel个锚框,因此每boxes_per_pixel行代表一个像素的所有锚框。因为所有像素点的锚框个数和高宽都是一样的,因此需要复制in_height*in_width次,所以anchor_manipulations.size=(5x561x728,4)anchor_manipulations = torch.stack((-anchors_w,-anchors_h,anchors_w,anchors_h)).T.repeat(in_height*in_width,1)/2 #除以2来获得半高和半宽#因此out_grid与anchor_manipulations相加得到一个像素点中一个锚框的左上,右下的坐标,因此每boxes_per_pixel行代表一个像素点的所有锚框的左上,右下坐标值,也相当于生成所有像素点的所有锚框output = out_grid+anchor_manipulations#output新增一个维度return output.unsqueeze(0)

multibox_prior()函数里面一些变量如下图所示,可以用于理解。注意size指的是图像长宽的缩放比例而非图像面积的缩放比例,ratio是指锚框的宽高比,指的是将原图像归一化为正方形后截取的锚框的宽高比,或者说是在原图像的宽高比基础上乘以ratio,才是真正的锚框的宽高比。上面代码中计算anchors_w时为什么需要再乘以(in_height/in_width),原因参考下面链接:

  1. 计算anchors_w乘以(in_height/in_width)原因
  2. 计算anchors_w乘以(in_height/in_width)原因

2.3 返回的锚框变量output的形状是(批量大小,锚框的数量,4)。

img = d2l.torch.plt.imread('../images/catdog.jpg')
h,w = img.shape[:2]
data = torch.rand(size=(1,3,h,w))
output = multibox_prior(data,sizes=[0.75,0.5,0.25],ratios=[1,2,0.5])
#返回的锚框变量output的形状是(批量大小,锚框的数量,4)。
print(output.shape)
print(h,w)
输出结果如下:
torch.Size([1, 2042040, 4])
561 728

2.4 将锚框变量Y的形状更改为(图像高度,图像宽度,以同一像素为中心的锚框的数量,4)后,可以获得以指定像素的位置为中心的所有锚框,访问以(250,250,0,:)为中心的第一个锚框 ,它有四个元素:锚框左上角的 (

李沐动手学深度学习v2-目标检测中的锚框和代码实现相关推荐

  1. 李沐动手学深度学习v2/总结1

    总结 编码过程 数据 数据预处理 模型 参数,初始化参数 超参数 损失函数,先计算损失,清空梯度(防止有累积的梯度),再对损失后向传播计算损失关于参数的梯度 优化算法,使用优化算法更新参数 训练求参数 ...

  2. 李沐动手学深度学习V2-目标检测边界框

    一. 目标检测边界框 加载本节使用的示例图像,可以看到图像左边是一只狗,右边是一只猫,它们是这张图像里的两个主要目标,如下图所示. import torch import d2l import d2l ...

  3. 14李沐动手学深度学习v2/权重衰退简洁实现

    # 权重衰退是广泛应用的正则化技术 %matplotlib inline import torch from torch import nn from d2l import torch as d2l ...

  4. 李沐动手学深度学习V2-全卷积网络FCN和代码实现

    一.全卷积网络FCN 1. 介绍 语义分割是对图像中的每个像素分类,全卷积网络(fully convolutional network,FCN)采用卷积神经网络实现了从图像像素到像素类别的变换 ,与前 ...

  5. 李沐动手学深度学习(pytorch版本)d2lzh_pytorch包的缺少安装问题

    学习深度学习时候,很多人参考的是李沐的动手学深度学习Pytorch版本(附上官方地址:https://tangshusen.me/Dive-into-DL-PyTorch/#/). 在学习3.5.1节 ...

  6. 【李沐动手学深度学习】读书笔记 01前言

    虽然之前已经学过这部分内容和深度学习中的基础知识,但总觉得学的不够系统扎实,所以希望再通过沐神的课程以及书籍,系统条理的学习一遍.在读书过程中,利用导图做了一下梳理,形成了这个读书笔记.如有侵权,请联 ...

  7. 关于李沐动手学深度学习(d2l)pytorch环境本地配置

    本地安装d2l 由于之前试了很多次d2l课本的安装方法失败了,这里提供一种我可以成功安装d2l包的方法. pytorch安装 首先安装cuda.cudnn.pytroch(gpu版本).可以参考这篇文 ...

  8. 动手学深度学习之目标检测基础

    参考伯禹学习平台<动手学深度学习>课程内容内容撰写的学习笔记 原文链接:https://www.boyuai.com/elites/course/cZu18YmweLv10OeV/less ...

  9. 李沐动手学深度学习V2-多尺度目标检测

    一. 多尺度目标检测 以输入图像的每个像素为中心,生成多个锚框,这些锚框代表了图像不同区域的样本. 然而,如果为每个像素都生成的锚框,最终可能会得到太多需要计算的锚框. 想象一个 561×728 的输 ...

最新文章

  1. 小学生python入门-周边 | 小学生都开始学Python了,你还在等什么?
  2. 密码找回功能可能存在的问题
  3. Flash基本工具练习
  4. 教你如何窃取网络信息
  5. CentOS关机大法之shutdown命令格式
  6. HTML的基本知识-和常用标签-以及相对路径和绝对路径的区别
  7. c#餐饮系统打印机_C#实现打印机功能
  8. 进销存系统测试实战-功能测试
  9. SSM大学生心理健康服务平台毕业设计-附源码071131
  10. 基于Java技术的汽车维修管理软件的设计与实现
  11. 哪里有周末java培训_北京哪里里Java周末学习班
  12. LoadRunner11代理在Win10操作系统启动不起来,或者报错:该内存不能为written
  13. LINUX编译xcb/xcb-proto
  14. td nowrap 属性 中多个input 不换行 水平排列
  15. Boost搜索引擎项目
  16. A*算法 JAVA实现
  17. 920quiz+922复杂度+927quiz2
  18. Glide系列(四) — Glide缓存流程分析
  19. 导学目录-学如逆水行舟
  20. python写socket代理_为python设置socket代理的方法

热门文章

  1. spider_爬取斗图啦所有表情包(图片保存)
  2. Scrapy-Redis使用教程将现有爬虫修改为分布式爬虫
  3. Ubuntu安装Gcc时,显示“无法解析域名cn.archive.ubuntu.com”,如下方式可解决
  4. android:详细解读DialogFragment
  5. vue学习-v-if v-for优先级、data、key、diff算法、vue组件化、vue设计原则、组件模板只有一个根元素、MVC.MVP,MVVM
  6. 谷歌翻译退出,idea谷歌翻译无法使用(解决)
  7. Tree Traversal(二叉树的遍历)
  8. Mac上将mp4视频做成屏保
  9. BRAF蛋白F595S G615R突变的影响
  10. KettleError connecting to database: (using class org.gjt.mm.mysql.Driver)Communications link failure