retinaface代码讲解_Pytorch-RetinaFace 详解
参考:
上面那套代码非常的详细,数据集都准备好了,下载后就可以直接训练了。
下面简单对代码进行梳理
Pytorch的实现是忽略了3D建模的那一套东西,有需要的自己看论文哈。
1、网络模型
字丑莫怪
先看看上图有个大概的概念,差不多就是FPN+SSH_Module+MultiLoss。
提一嘴哈,如果希望将retinaface转成onnx然后走tensorRT的话,需要改一下里面FPN使用的上采样方法,从F.interpolate改成nn.ConvTranspose2d.
参考中的代码,模型放在./models/retinaface.py中。我们来看一看,主要看__init__和forward
class RetinaFace(nn.Module):
def __init__(self, cfg = None, phase = 'train'):
"""
:param cfg: Network related settings.
:param phase: train or test.
"""
super(RetinaFace,self).__init__()
self.phase = phase
'''
1、定义了backbone。
2、拿到了中间的若干层结果作为backbone的输出
'''
# 省略了一下backbone的判断,直接用mobilev1
backbone = MobileNetV1()
# {'stage1': 1, 'stage2': 2, 'stage3': 3}
self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
in_channels_stage2 = cfg['in_channel']
in_channels_list = [
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
]
out_channels = cfg['out_channel'] # 64
self.fpn = FPN(in_channels_list,out_channels)
self.ssh1 = SSH(out_channels, out_channels)
self.ssh2 = SSH(out_channels, out_channels)
self.ssh3 = SSH(out_channels, out_channels)
self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
classhead = nn.ModuleList()
for i in range(fpn_num):
classhead.append(ClassHead(inchannels,anchor_num))
return classhead
def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
bboxhead = nn.ModuleList()
for i in range(fpn_num):
bboxhead.append(BboxHead(inchannels,anchor_num))
return bboxhead
def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
landmarkhead = nn.ModuleList()
for i in range(fpn_num):
landmarkhead.append(LandmarkHead(inchannels,anchor_num))
return landmarkhead
def forward(self,inputs):
# 输入过一遍backbone,得到out
# out : [o1,o2,o3]这样的
out = self.body(inputs)
# FPN
# FPN的输出也是[of1,of2,of3]这样的
fpn = self.fpn(out)
# SSH
# FPN的每个分辨率的输出,都过一遍SSH中Context_Module
feature1 = self.ssh1(fpn[0])
feature2 = self.ssh2(fpn[1])
feature3 = self.ssh3(fpn[2])
features = [feature1, feature2, feature3]
'''
对每个分支的结果,分别输入bbox,classification和landmarks的头中
得到预测结果
'''
bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1)
ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)
if self.phase == 'train':
output = (bbox_regressions, classifications, ldm_regressions)
else:
output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
return output
模型还是很简单的,就不多看了。
2、MultiLoss
这个真的是看吐了…
参考中loss的链接,虽然讲的是SSD的,但是里面用的loss是一致的,这里也可以看哈。
因为loss的代码量实在太多了,这里总结一下:
1、在3个尺度上,模型一共会有16800个anchors。
2、训练的时候,我们需要选择若干个anchors,让基于这些被选中的anchors的预测结果,参与loss的计算。
基本的选择规则是,在anchors和某个ground_truth的overlap大于阈值。
3、对于bbox和landmarks的回归而言,需要把原来基于原图归一化的位置参数,转化为基于anchors的归一化参数,然后计算Smooth-L1-LOSS即可。
4、对于classification而言,2中筛选出来的都是正样本,因此还需要选择一定的负样本,来训练分类任务。
以上就是loss设计的总思路了,下面我们来看代码。
以输入为[1,3,640,640]为例子,ground_truth中,bbox数量为num_ground_truth = 11)找到被选中的anchors,主要内容是在utils/box_utils.py中的match里面
class MultiBoxLoss(nn.Module):
def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
super(MultiBoxLoss, self).__init__()
。。。。。。
def forward(self, predictions, priors, targets):
# 坐标,前背景,landmarks
# [1,16800,4],[1,16800,2],[1,16800,10]
loc_data, conf_data, landm_data = predictions
# [16800,4]
priors = priors
num = loc_data.size(0) # batchsize
num_priors = (priors.size(0)) # num_anchors
# match priors (default boxes) and ground truth boxes
# 这几是预备存储内容的
loc_t = torch.Tensor(num, num_priors, 4)
landm_t = torch.Tensor(num, num_priors, 10)
conf_t = torch.LongTensor(num, num_priors)
for idx in range(num):
# 对batch中的每一个内容
truths = targets[idx][:, :4].data # 坐标
labels = targets[idx][:, -1].data # 置信度
landms = targets[idx][:, 4:14].data # landmarks
defaults = priors.data
# 最后结果都在loc_t,conf_t,landm_t里面
# 一句话,找到每个anchors该负责的gt框,
# 并将该gt框转化为对应与该anchors坐标的归一化参数
match(self.threshold, truths, defaults, self.variance, labels, landms,loc_t, conf_t, landm_t,idx)2)match,第一部分:
下面这一部分,通过overlap阈值过滤的方式,找到了那些被选中的anchors,并在他们的位置上标志为非0.
def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
# 计算ground_truth和先验框之间的overlap。
# [num_ground_truth, num_prior]
# [1, 16800]
overlaps = jaccard(
truths,
point_form(priors)
)
'''
best_prior_overlap: [num_ground_thrth,1],放的是overlap的值
best_prior_idx:[num_ground_thrth,1],放的是最大值的id
表示,每一个gt选了一个和它overlap最大的anchors
'''
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
# ignore hard gt
# gt框如果和anchors的overlap太小,那就不要这个gt框
valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
if best_prior_idx_filter.shape[0] <= 0:
# 如果没有剩下就算了
loc_t[idx] = 0
conf_t[idx] = 0
return
# [1,num_priors] best ground truth for each prior
'''
best_truth_overlap: [1,num_prior_box],overlap
best_truth_idx: [1,num_prior_box],放的是最大值的id
表示,每一个anchors选了一个和它overlap最大的gt
'''
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
# 把多余的维度干掉
best_truth_idx.squeeze_(0)
best_truth_overlap.squeeze_(0)
best_prior_idx.squeeze_(1)
best_prior_overlap.squeeze_(1)
# 过滤过的框
best_prior_idx_filter.squeeze_(1)
# 那过滤过的框,把best_prior_idx_filter上的数字都写为2
# 找到最合适的prior,最合适的anchors
best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
# 把被选中的框补充进去,以防漏了
for j in range(best_prior_idx.size(0)):
# best_prior_idx[j]:对于每一个gt框而言,拿到和它overlap最高的anchors的id
# best_truth_idx[best_prior_idx[j]] = j:对于anchors,指定它对应的gt框id
best_truth_idx[best_prior_idx[j]] = j
# best_truth_idx中所有非零的位置,都表示这是一个被选中的框3)match,第二部分 先看懂下面那个例子:
a = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
b = torch.Tensor([0,1,0,1,0,1,0]).long()
a[b]:
tensor([[1., 2., 3.],
[4., 5., 6.],
[1., 2., 3.],
[4., 5., 6.],
[1., 2., 3.],
[4., 5., 6.],
[1., 2., 3.]])
看懂上面那个例子后,下面的一些写法就比较好理解了
# 以best_truth_idx中的数字作为索引,
# 生成了一个[len(best_truth_idx),len(truth)]的张量
# 这样对每个anchors指定了一个和它overlap最大的gt位置
matches = truths[best_truth_idx]
print(matches[0])
# 这个类似
conf = labels[best_truth_idx]
conf[best_truth_overlap < threshold] = 0
# 把位置基于anchors进行编码
# 把基于图像的位置比例,调整到基于这个anchors的位置比例
# variance用来放缩loss的,使得multiloss中的各个位置的loss差不多大
loc = encode(matches, priors, variances)
matches_landm = landms[best_truth_idx]
# 一样的思路,landmarks中,将原来基于图像的归一化位置,转到基于anchrors中心的归一化位置。
landm = encode_landm(matches_landm, priors, variances)
# 放到各自的batch中去
loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
conf_t[idx] = conf # [num_priors] top class label for each prior
landm_t[idx] = landm4)match看完了,我们看后面计算loss的部分:
.........
match(self.threshold, truths, defaults, self.variance, labels, landms,
loc_t, conf_t, landm_t,
idx)
zeros = torch.tensor(0)
# landm Loss (Smooth L1)
# Shape: [batch,num_priors,10]
pos1 = conf_t > zeros
# 拿到anchors中,和gt框的overlap大于阈值的anchors数目
num_pos_landm = pos1.long().sum(1, keepdim=True)
# 拿到所有batch中,最大的num_pos_landm
N1 = max(num_pos_landm.data.sum().float(), 1)
# 把1*16800变成1*16800*10
# 单纯的复制了每一位置10遍
pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
'''
landmark_loss
'''
# 过滤,并拿到预测结果中,对应的landmarks
landm_p = landm_data[pos_idx1].view(-1, 10)
# 过滤,并拿到labels中对应的数据
landm_t = landm_t[pos_idx1].view(-1, 10)
# 根据预测结果和target结果,计算landmarks的loss
loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
'''
bbox_loss
'''
# 和之前那个>zeros过滤一样
pos = conf_t != zeros
conf_t[pos] = 1
# 同理,将过滤数从1*16800,放大道1*16800*4
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
# 拿到过滤后的预测数据
loc_p = loc_data[pos_idx].view(-1, 4)
# 拿到过滤后的label数据
loc_t = loc_t[pos_idx].view(-1, 4)
# 计算bbox的F1loss
loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
'''
1、根据阈值过滤,拿到正样本的位置pos和数量num_pos
2、num_pos*7拿到负样本数量num_neg。
3、根据loss_c,拿到loss最大的num_neg个anchors的位置neg。
4、从conf_t和conf_d中,拿出pos和neg对应的anchors,计算cla_loss
'''
# Compute max conf across batch for hard negative mining
# [16800*batch_size,2]
batch_conf = conf_data.view(-1, self.num_classes)
'''
这个是用来排序用的
可以理解为,每个anchors会预测两个概率,【p1,p2】,
并且,根据gt,我们知道它应该得到的是[0,1]
那么,这个anchors的loss=log(e^p1+e^p2)-p2
loss_c就存了这样的16800个这样的loss
'''
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
# 总不能用所有的16800个loss来计算
# 被选中的anchors的loss写为0
loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
# 重新按照[batch_size,16800]进行reshape
loss_c = loss_c.view(num, -1)
'''
a = torch.Tensor([[10,50,70,20,30,40]])
_,b = a.sort(1,descending=True) # b=tensor([[2, 1, 5, 4, 3, 0]])
_,c = b.sort(1) # c=tensor([[5, 1, 0, 4, 3, 2]])
'''
# 每个batch单独,按照value大小,降序排序
_, loss_idx = loss_c.sort(1, descending=True) #对每张图的priorbox的conf loss从大到小排序,每一列的值为prior box的index;相当于最不是前景的排在第一个
_, idx_rank = loss_idx.sort(1) # 对上面每一列,按照存储内容大小进行排序
print('idx',idx_rank.shape,idx_rank) # idx_rank为在loss_idx中的位置
# num_pos是正样本数量,pos = conf_t != zeros
num_pos = pos.long().sum(1, keepdim=True)
# num_neg是负样本数量,最大不超过总样本数量
num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
# 拿到负样本id
'''
neg中拿到的是idx_rank中从0~num_neg的这些下标
这些下标对应的是loss_idx中loss比较的位置
'''
neg = idx_rank < num_neg.expand_as(idx_rank)
# Confidence Loss Including Positive and Negative Examples
pos_idx = pos.unsqueeze(2).expand_as(conf_data)
neg_idx = neg.unsqueeze(2).expand_as(conf_data)
# 拿到prediction中的数据
conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
# 拿到labels中的数据
targets_weighted = conf_t[(pos+neg).gt(0)]
print(conf_p.shape,targets_weighted.shape)
loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
# Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
# 归一化
N = max(num_pos.data.sum().float(), 1)
loss_l /= N
loss_c /= N
loss_landm /= N1
return loss_l, loss_c, loss_landm
3、推理
# img的预处理就没放进来
loc, conf, landms = net(img) # forward pass
# 根据输入图片大小的不同计算需要使用的priorbox
# 实际中输入大小一定就不需要了
priorbox = PriorBox(cfg, image_size=(im_height, im_width))
priors = priorbox.forward()
priors = priors.to(device)
prior_data = priors.data
# 将bbox从基于anchors的情况下解码到在原图中的位置
boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
boxes = boxes * scale / resize
boxes = boxes.cpu().numpy()
# 处理classification预测结果
scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
# 将landmarks从基于anchors的情况下解码到在原图中的位置
landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])
scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
img.shape[3], img.shape[2], img.shape[3], img.shape[2],
img.shape[3], img.shape[2]])
scale1 = scale1.to(device)
landms = landms * scale1 / resize
landms = landms.cpu().numpy()
# ignore low scores
# 直接阈值过滤
inds = np.where(scores > args.confidence_threshold)[0]
boxes = boxes[inds]
landms = landms[inds]
scores = scores[inds]
# keep top-K before NMS
# order = scores.argsort()[::-1][:args.top_k]
order = scores.argsort()[::-1]
boxes = boxes[order]
landms = landms[order]
scores = scores[order]
# do NMS
dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
keep = py_cpu_nms(dets, args.nms_threshold)
dets = dets[keep, :]
landms = landms[keep]
# keep top-K faster NMS
# dets = dets[:args.keep_top_k, :]
# landms = landms[:args.keep_top_k, :]
dets = np.concatenate((dets, landms), axis=1)
retinaface代码讲解_Pytorch-RetinaFace 详解相关推荐
- 泛型java 代码讲解_Java泛型详解
2516326-5475e88a458a09e4.png 一,打破砂锅问到底 泛型存在的意义? 泛型类,泛型接口,泛型方法如何定义? 如何限定类型变量? 泛型中使用的约束和局限性有哪些? 泛型类型的继 ...
- DL之YoloV3:Yolo V3算法的简介(论文介绍)、各种DL框架代码复现、架构详解、案例应用等配图集合之详细攻略
DL之YoloV3:Yolo V3算法的简介(论文介绍).各种DL框架代码复现.架构详解.案例应用等配图集合之详细攻略 目录 Yolo V3算法的简介(论文介绍) 0.YoloV3实验结果 1.Yol ...
- java构造块_java中的静态代码块、构造代码块、构造方法详解
运行下面这段代码,观察其结果: package com.test; public class HelloB extends HelloA { public HelloB() { } { System. ...
- python代码覆盖率测试_unittest+coverage单元测试代码覆盖操作实例详解_python
这篇文章主要为大家详细介绍了unittest+coverage单元测试代码覆盖操作的实例,具有一定的参考价值,感兴趣的小伙伴们可以参考一下 基于上一篇文章,这篇文章是关于使用coverage来实现代码 ...
- yolov5——detect.py代码【注释、详解、使用教程】
yolov5--detect.py代码[注释.详解.使用教程] yolov5--detect.py代码[注释.详解.使用教程] 1. 函数parse_opt() 2. 函数main() 3. 函数ru ...
- python的爱心曲线公式_六行python代码的爱心曲线详解
前些日子在做绩效体系的时候,遇到了一件囧事,居然忘记怎样在Excel上拟合正态分布了,尽管在第二天重新拾起了Excel中那几个常见的函数和图像的做法,还是十分的惭愧.实际上,当时有效偏颇了,忽略了问题 ...
- 消除冗长Java代码的工具——Lombok详解
消除冗长Java代码的工具--Lombok详解 文章目录 消除冗长Java代码的工具--Lombok详解 什么是Lombok Lombok的作用 Lombok常用注解 Lombok安装 什么是Lomb ...
- 国际C语言混乱代码大赛优胜作品详解之“A clock in one line
国际C语言混乱代码大赛优胜作品详解之"A clock in one line" 发表于2013-04-11 17:22| 9419次阅读| 来源StackOverflow| 53 ...
- vc读取北通手柄按键_噬血代码手柄怎么操作 噬血代码北通手柄按键功能详解-游侠网...
噬血代码手柄怎么操作?应该很多朋友都还不是很清楚吧,所以呢小编今天给大家带来的就是噬血代码北通手柄按键功能详解,需要的朋友不妨进来看看. 北通手柄按键功能详解 游戏介绍 本作是由<噬神者> ...
- yolov5——train.py代码【注释、详解、使用教程】
yolov5--train.py代码[注释.详解.使用教程] yolov5--train.py代码[注释.详解.使用教程] yolov5--train.py代码[注释.详解.使用教程] 前言 1. p ...
最新文章
- HTMLButton控件下的Confirm()
- 【错误记录】NDK 动态库报错 ( dlopen failed: file offset for the library /lib/arm64/libwebp.so“ >= file size:0)
- Python使用中文注释和输出中文(原创)
- jsp中request.getAttributeNames()报红
- AutoMapper的介绍与使用(二)
- java ee jstl_Java EE之JSTL(下)
- APP启动页HTML,启动页.html
- PHP仿百度实现弹窗登录效果,js仿百度登录页实现拖动窗口效果
- 88相似标准形09——JJordan-Chevalley分解、幂零矩阵与幂零变换、幂零矩阵的判别、中国剩余定理、可换线性变换的性质
- Matlab矩阵的运算
- python与开源_Python与开源GIS
- 深富策略:指数横盘震荡整理 汽车整车表现亮眼
- 如何解决tomcat启动时出现 Server Tomcat v9.0 Server at localhost failed to start.
- token的颁发、保存与携带
- java微信公众号获取地理位置_Java微信公众平台开发之获取地理位置
- openvino下载模型
- VS 杂项文件全面解决方法
- 汇编语言复习题及详细答案2
- android10.0(Q) android11(R) 时区相关问题
- 智能仪器仪表行业数字化供应链管理系统:加速企业智慧供应链平台转型
热门文章
- Win10上注册OCX文件
- linux基础(十四)定时任务和管理系统的临时文件
- python读取excel数据为矩阵_用Python实现excel中“矩阵”式列表转“向量”式列表...
- 远离奸商-查看CPU信息是否被修改
- 宏文件下载_技能 | WPS如何启用宏功能,VBA组件安装
- windows10 企业版 ltsc系统的激活
- IT大公司面试流程与总结
- AUTOCAD——超级填充命令3
- estore简版商城思路
- 基于贝叶斯决策理论的分类方法