
  • 模型结构
    • backbone dla34
    • head
  • train
    • dataloader
  • eval
    • topk


backbone dla34

DLA (Deep Layer Aggregation)
We introduce two structures for deep layer aggregation (DLA): iterative deep aggregation (IDA) and hierarchical deep aggregation (HDA).
Hierarchical deep aggregation merges blocks and stages in a tree to preserve and combine feature channels.
我们介绍两种用于深层聚合(DLA)的结构:迭代深层聚合(IDA)和层次深层聚合(HDA)。
IDA focuses on fusing resolutions and scales while HDA focuses on merging features from all modules and channels.


class BasicBlock(nn.Module):
    """ResNet-style basic block: two 3x3 conv + BN layers with an additive shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        dilation: dilation (and matching padding) of both convolutions.
    """

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        shared = dict(kernel_size=3, padding=dilation, bias=False,
                      dilation=dilation)
        self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, **shared)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **shared)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.stride = stride

    def forward(self, x, residual=None):
        """Run the block.

        *residual* replaces the shortcut; it defaults to *x*. Callers must
        pass a shape-matching residual whenever stride > 1 or
        inplanes != planes, otherwise the addition fails.
        """
        shortcut = x if residual is None else residual
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)


class Root(nn.Module):
    """DLA aggregation node.

    Concatenates its inputs on the channel axis, then applies a 1x1 conv and
    BN, optionally adds the first input back, and finishes with a ReLU.

    Args:
        in_channels: total channel count of all concatenated inputs.
        out_channels: output channel count.
        kernel_size: only used to derive the conv padding (the conv is 1x1).
        residual: if true, add the first input before the final ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        # NOTE(review): the kernel is fixed at 1 while padding follows
        # kernel_size, so any kernel_size > 1 would enlarge the output.
        # Kept exactly as in the original.
        self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=1,
                              bias=False, padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        merged = self.bn(self.conv(torch.cat(x, 1)))
        if self.residual:
            merged += x[0]
        return self.relu(merged)


class Tree(nn.Module):
    """Hierarchical deep aggregation (HDA) tree built from *block* modules.

    A level-1 tree is two blocks joined by a Root. Deeper trees recurse, and
    only level-1 subtrees own a Root. Intermediate outputs are threaded down
    through *children* so the innermost Root can aggregate them.
    """

    def __init__(self, levels, block, in_channels, out_channels, stride=1,
                 level_root=False, root_dim=0, root_kernel_size=1,
                 dilation=1, root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels

        if levels == 1:
            self.tree1 = block(in_channels, out_channels, stride,
                               dilation=dilation)
            self.tree2 = block(out_channels, out_channels, 1,
                               dilation=dilation)
            self.root = Root(root_dim, out_channels, root_kernel_size,
                             root_residual)
        else:
            sub_kwargs = dict(root_kernel_size=root_kernel_size,
                              dilation=dilation, root_residual=root_residual)
            self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
                              stride, root_dim=0, **sub_kwargs)
            self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                              root_dim=root_dim + out_channels, **sub_kwargs)

        self.level_root = level_root
        self.root_dim = root_dim
        self.levels = levels
        # Max-pool the input when striding; project channels when they change.
        self.downsample = (nn.MaxPool2d(stride, stride=stride)
                           if stride > 1 else None)
        self.project = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM))
            if in_channels != out_channels else None)

    def forward(self, x, residual=None, children=None):
        # The incoming `residual` argument is ignored and recomputed here.
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        residual = self.project(bottom) if self.project else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels != 1:
            children.append(x1)
            return self.tree2(x1, children=children)
        x2 = self.tree2(x1)
        return self.root(x2, x1, *children)



class IDAUp(nn.Module):
    """Iterative deep aggregation (IDA) upsampling node.

    For each input level i >= 1 it registers:
      - ``proj_i``: a DeformConv from channels[i] to *o* channels;
      - ``up_i``: a depthwise ConvTranspose2d upsampling by ``up_f[i]``;
      - ``node_i``: a DeformConv that merges the result with the previous level.

    Args:
        o: output channel count.
        channels: input channel count per level.
        up_f: upsampling factor per level.
    """

    def __init__(self, o, channels, up_f):
        super(IDAUp, self).__init__()
        for i in range(1, len(channels)):
            c = channels[i]
            f = int(up_f[i])
            proj = DeformConv(c, o)
            node = DeformConv(o, o)
            up = nn.ConvTranspose2d(o, o, f * 2, stride=f, padding=f // 2,
                                    output_padding=0, groups=o, bias=False)
            fill_up_weights(up)
            setattr(self, 'proj_' + str(i), proj)
            setattr(self, 'up_' + str(i), up)
            setattr(self, 'node_' + str(i), node)

    def forward(self, layers, startp, endp):
        """Merge ``layers[startp:endp]`` in place (returns None).

        Afterwards ``layers[endp - 1]`` holds the fully aggregated map.
        """
        for i in range(startp + 1, endp):
            upsample = getattr(self, 'up_' + str(i - startp))
            project = getattr(self, 'proj_' + str(i - startp))
            layers[i] = upsample(project(layers[i]))
            node = getattr(self, 'node_' + str(i - startp))
            layers[i] = node(layers[i] + layers[i - 1])


class DLAUp(nn.Module):
    """Stack of IDAUp nodes that progressively fuse deeper levels into shallower ones.

    Args:
        startp: index of the first backbone level that is fused.
        channels: channel count per fused level (shallow to deep).
        scales: relative scale per level (e.g. [1, 2, 4, 8]).
        in_channels: input channel counts per level; defaults to *channels*.
    """

    def __init__(self, startp, channels, scales, in_channels=None):
        super(DLAUp, self).__init__()
        self.startp = startp
        # Work on private copies. The loop below rewrites in_channels, and the
        # previous version aliased it to the caller's list (and to
        # self.channels), silently mutating both.
        if in_channels is None:
            in_channels = channels
        in_channels = list(in_channels)
        channels = list(channels)
        self.channels = list(channels)
        scales = np.array(scales, dtype=int)
        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(self, 'ida_{}'.format(i),
                    IDAUp(channels[j], in_channels[j:],
                          scales[j:] // scales[j]))
            # After fusing, every deeper level now has level j's scale/channels.
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, layers):
        """Fuse backbone outputs; *layers* is modified in place.

        Returns a list of fused maps, shallowest first.
        """
        out = [layers[-1]]  # deepest level first (original note: stride 32)
        for i in range(len(layers) - self.startp - 1):
            ida = getattr(self, 'ida_{}'.format(i))
            ida(layers, len(layers) - i - 2, len(layers))
            out.insert(0, layers[-1])
        return out


class DLASeg(nn.Module):
    """DLA backbone followed by DLAUp + IDAUp, producing one fused feature map.

    Args:
        base_name: name of a backbone constructor in this module's globals
            (e.g. 'dla34').
        pretrained: forwarded to the backbone constructor.
        down_ratio: output stride; must be one of 2, 4, 8, 16.
        final_kernel: accepted for interface compatibility; unused here.
        last_level: exclusive upper backbone level fused by ``ida_up``.
        out_channel: channels of the fused output; 0 means use the channel
            count of the first fused level.
    """

    def __init__(self, base_name, pretrained, down_ratio, final_kernel,
                 last_level, out_channel=0):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
        self.first_level = int(np.log2(down_ratio))  # e.g. down_ratio=4 -> 2
        self.last_level = last_level
        self.base = globals()[base_name](pretrained=pretrained)

        chans = self.base.channels
        lo = self.first_level
        level_scales = [2 ** k for k in range(len(chans[lo:]))]
        self.dla_up = DLAUp(lo, chans[lo:], level_scales)

        if out_channel == 0:
            out_channel = chans[lo]
        up_factors = [2 ** k for k in range(self.last_level - lo)]
        self.ida_up = IDAUp(out_channel, chans[lo:self.last_level], up_factors)

    def forward(self, x):
        fused = self.dla_up(self.base(x))
        n_levels = self.last_level - self.first_level
        y = [fused[k].clone() for k in range(n_levels)]
        self.ida_up(y, 0, len(y))
        return y[-1]


class KeypointHead(nn.Module):
    """Multi-person pose prediction heads applied to one shared feature map.

    Every head is 3x3 conv -> ReLU -> 1x1 conv. ``forward`` returns, in order:
      hm         (1 channel)          center heatmap logits
      wh         (2)                  box width / height
      hps        (2 * num_joints)     joint offsets from the center
      reg        (2)                  center sub-pixel offset
      hm_hp      (num_joints)         joint heatmap logits
      hp_offset  (2)                  joint sub-pixel offset

    Args:
        intermediate_channel: channels of the input feature map.
        head_conv: hidden channels inside each head.
        num_joints: number of keypoints. The default of 17 gives the original
            34-channel ``hps`` and 17-channel ``hm_hp`` outputs.
    """

    # Output-bias prior for heatmap heads: -log((1 - 0.1) / 0.1) ~= -2.19,
    # i.e. an initial sigmoid output of about 0.1.
    HEATMAP_PRIOR_BIAS = -2.19

    def __init__(self, intermediate_channel, head_conv, num_joints=17):
        super(KeypointHead, self).__init__()
        make = self._make_head
        self.hm = make(intermediate_channel, head_conv, 1)
        self.wh = make(intermediate_channel, head_conv, 2)
        self.hps = make(intermediate_channel, head_conv, 2 * num_joints)
        self.reg = make(intermediate_channel, head_conv, 2)
        self.hm_hp = make(intermediate_channel, head_conv, num_joints)
        self.hp_offset = make(intermediate_channel, head_conv, 2)
        self.init_weights()

    @staticmethod
    def _make_head(in_channels, head_conv, out_channels):
        """Build one head: 3x3 conv -> ReLU -> 1x1 conv to *out_channels*."""
        return nn.Sequential(
            nn.Conv2d(in_channels, head_conv, kernel_size=3, padding=1,
                      bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, out_channels, kernel_size=1, stride=1,
                      padding=0))

    def init_weights(self):
        """Initialise head biases.

        The output bias of the heatmap heads (hm, hm_hp) is set to the
        low-probability prior. All conv biases of the regression heads are
        set to zero.
        """
        for head in (self.hm, self.hm_hp):
            nn.init.constant_(head[-1].bias, self.HEATMAP_PRIOR_BIAS)
        for head in (self.wh, self.hps, self.reg, self.hp_offset):
            for m in head.modules():
                if isinstance(m, nn.Conv2d) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return [self.hm(x), self.wh(x), self.hps(x), self.reg(x),
                self.hm_hp(x), self.hp_offset(x)]




ground truth 框先通过仿射变换映射到输出特征图的分辨率,得到新的 bounding box 坐标;以该 bounding box 的中心点作为正样本点,其他位置均为负样本(负样本的损失按以中心点为峰值的高斯分布降低权重)。


def __getitem__(self, index):
    """Build one multi-person pose sample for the COCO image at position *index*.

    Steps:
      1. Load the image and its non-crowd annotations of valid categories.
      2. In the 'train' split, apply augmentation: random crop or shift/scale,
         optional rotation, horizontal flip, and color_aug.
      3. Warp the image to the network input resolution and normalise it.
      4. Render the training targets at the output resolution.

    Returns:
        dict with 'input', 'hm', 'reg_mask', 'ind' and 'wh', plus:
          - either 'hps'/'hps_mask' or, with cfg.LOSS.DENSE_HP,
            'dense_hps'/'dense_hps_mask';
          - 'reg', 'hm_hp' and 'hp_offset'/'hp_ind'/'hp_mask' when the
            matching cfg.LOSS flags are set;
          - 'meta' when cfg.DEBUG > 0 or the split is not 'train'.
    """
    # Image id for this dataset position.
    img_id = self.images[index]
    # Image file name from the COCO index.
    file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
    # Full image path = image directory + file name.
    img_path = os.path.join(self.img_dir, file_name)
    # All annotation ids of this image.
    ann_ids = self.coco.getAnnIds(imgIds=[img_id])
    # Annotation dicts for those ids.
    anns = self.coco.loadAnns(ids=ann_ids)
    # Keep annotations whose category_id is in self._valid_ids and that are not crowd-labelled.
    anns = list(filter(lambda x:x['category_id'] in self._valid_ids and x['iscrowd']!= 1 , anns))
    # Cap the number of objects per image.
    num_objs = min(len(anns), self.max_objs)
    # Read the image (cv2 returns an H x W x C array).
    img = cv2.imread(img_path)
    # Original image size.
    height, width = img.shape[0], img.shape[1]
    # Image center, stored as (x, y).
    c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
    # Scale = longest image edge.
    s = max(img.shape[0], img.shape[1]) * 1.0
    # Rotation angle; stays 0 unless rotation augmentation fires below.
    rot = 0
    flipped = False
    if self.split == 'train':
        if self.cfg.DATASET.RANDOM_CROP: # original note: true in the author's config
            # Random scale in [0.6, 1.3] and a random center at least
            # _get_border(128, ...) pixels away from the image edges.
            s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
            w_border = self._get_border(128, img.shape[1])
            h_border = self._get_border(128, img.shape[0])
            c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
            c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
        else:
            # Randomly jitter the center and scale (clipped Gaussian noise).
            sf = self.cfg.DATASET.SCALE
            cf = self.cfg.DATASET.SHIFT
            c[0] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
            c[1] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
            s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
        if np.random.random() < self.cfg.DATASET.AUG_ROT:
            # Random rotation, clipped to +-2 * ROTATE.
            rf = self.cfg.DATASET.ROTATE
            rot = np.clip(np.random.randn()*rf, -rf*2, rf*2)
        if np.random.random() < self.cfg.DATASET.FLIP:
            # Horizontal flip of the image; the center is mirrored to match.
            flipped = True
            img = img[:, ::-1, :]
            c[0] = width - c[0] - 1
    # Affine transform: original image -> network input (INPUT_RES x INPUT_RES), including rotation.
    trans_input = get_affine_transform(c, s, rot, 
        [self.cfg.MODEL.INPUT_RES, self.cfg.MODEL.INPUT_RES])
    # Warp the image to the input resolution.
    inp = cv2.warpAffine(img, trans_input, (self.cfg.MODEL.INPUT_RES, self.cfg.MODEL.INPUT_RES),flags=cv2.INTER_LINEAR)
    # Scale pixel values to [0, 1].
    inp = (inp.astype(np.float32) / 255.)
    if self.split == 'train' and not self.cfg.DATASET.NO_COLOR_AUG:
        # Color augmentation, applied in place on inp.
        color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
    # Normalise with the dataset mean / std.
    inp = (inp - np.array(self.cfg.DATASET.MEAN).astype(np.float32)) / np.array(self.cfg.DATASET.STD).astype(np.float32)
    # HWC -> CHW channel order.
    inp = inp.transpose(2, 0, 1)
    output_res = self.cfg.MODEL.OUTPUT_RES
    num_joints = self.num_joints
    # Original image -> output map, WITH rotation (used for keypoints).
    trans_output_rot = get_affine_transform(c, s, rot, [output_res, output_res])
    # Original image -> output map, WITHOUT rotation (used for boxes).
    trans_output = get_affine_transform(c, s, 0, [output_res, output_res])
    # Original image -> output map, WITHOUT rotation (used for the instance masks).
    trans_seg_output = get_affine_transform(c, s, 0, [output_res, output_res])
    # Center heatmap target, one channel per class.
    hm = np.zeros((self.num_classes, output_res, output_res), dtype=np.float32)
    # Keypoint heatmap target, one channel per joint.
    hm_hp = np.zeros((num_joints, output_res, output_res), dtype=np.float32)
    # Dense per-pixel keypoint offset targets and masks (used with cfg.LOSS.DENSE_HP).
    dense_kps = np.zeros((num_joints, 2, output_res, output_res), dtype=np.float32)
    dense_kps_mask = np.zeros((num_joints, output_res, output_res), dtype=np.float32)
    # Box size (w, h) per object slot.
    wh = np.zeros((self.max_objs, 2), dtype=np.float32)
    # Keypoint offsets from the integer center, in output coordinates.
    kps = np.zeros((self.max_objs, num_joints * 2), dtype=np.float32)
    # Sub-pixel offset between the float center and its integer location.
    reg = np.zeros((self.max_objs, 2), dtype=np.float32)
    # Flattened index (y * output_res + x) of each object center.
    ind = np.zeros((self.max_objs), dtype=np.int64)
    # 1 for real objects; one slot per object, up to self.max_objs.
    reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
    # 1 for each visible keypoint coordinate that lands inside the output map.
    kps_mask = np.zeros((self.max_objs, self.num_joints * 2), dtype=np.uint8)
    # Sub-pixel offset of each keypoint from its integer location.
    hp_offset = np.zeros((self.max_objs * num_joints, 2), dtype=np.float32)
    # Flattened output index of each keypoint.
    hp_ind = np.zeros((self.max_objs * num_joints), dtype=np.int64)
    # Per-keypoint validity mask (same role as kps_mask).
    hp_mask = np.zeros((self.max_objs * num_joints), dtype=np.int64)
    # Within each object, keypoint gaussians are drawn first, then the center gaussian.
    draw_gaussian = draw_msra_gaussian if self.cfg.LOSS.MSE_LOSS else \
                    draw_umich_gaussian
    gt_det = []
    for k in range(num_objs):
        ann = anns[k]
        bbox = self._coco_box_to_bbox(ann['bbox'])
        cls_id = int(ann['category_id']) - 1
        # Keypoints as (x, y, visibility) rows.
        pts = np.array(ann['keypoints'], np.float32).reshape(num_joints, 3)
        # NOTE(review): `segment` is computed and warped below but never added
        # to `ret` in this method — confirm whether it is still needed.
        segment = self.coco.annToMask(ann)
        if flipped:
            # Mirror the box, the keypoints and the mask, and swap left/right joint pairs.
            bbox[[0, 2]] = width - bbox[[2, 0]] - 1
            pts[:, 0] = width - pts[:, 0] - 1
            for e in self.flip_idx:
                pts[e[0]], pts[e[1]] = pts[e[1]].copy(), pts[e[0]].copy()
            segment = segment[:, ::-1]
        # Map the box to output coordinates and clip it to the map.
        bbox[:2] = affine_transform(bbox[:2], trans_output)
        bbox[2:] = affine_transform(bbox[2:], trans_output)
        bbox = np.clip(bbox, 0, output_res - 1)
        segment= cv2.warpAffine(segment, trans_seg_output,(output_res, output_res),flags=cv2.INTER_LINEAR)
        segment = segment.astype(np.float32)
        h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
        if (h > 0 and w > 0) or (rot != 0):
            # Gaussian radius derived from the box size.
            radius = gaussian_radius((math.ceil(h), math.ceil(w)))
            # NOTE(review): `self.cfg.hm_gauss` does not follow the cfg.LOSS / cfg.MODEL
            # naming used elsewhere here; verify it exists if MSE_LOSS is enabled.
            # Original note: the latter branch (max(0, int(radius))) is the one used.
            radius = self.cfg.hm_gauss if self.cfg.LOSS.MSE_LOSS else max(0, int(radius))
            # Float object center in output coordinates.
            ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
            # Integer center location.
            ct_int = ct.astype(np.int32)
            # Width and height of object k.
            wh[k] = 1. * w, 1. * h
            # Flattened index of object k's center.
            ind[k] = ct_int[1] * output_res + ct_int[0]
            # Discretisation error between the float and the integer center.
            reg[k] = ct - ct_int
            reg_mask[k] = 1
            # Number of visible keypoints (sum of the visibility column).
            num_kpts = pts[:, 2].sum()
            if num_kpts == 0:
                # No labelled joints: set the center cell to 0.9999 and mask
                # this object out of the wh / reg targets.
                hm[cls_id, ct_int[1], ct_int[0]] = 0.9999
                reg_mask[k] = 0
            hp_radius = gaussian_radius((math.ceil(h), math.ceil(w)))
            hp_radius = self.cfg.hm_gauss \
                if self.cfg.LOSS.MSE_LOSS else max(0, int(hp_radius)) 
            for j in range(num_joints):
                if pts[j, 2] > 0:
                    # Keypoints use the rotation-aware transform.
                    pts[j, :2] = affine_transform(pts[j, :2], trans_output_rot)
                    if pts[j, 0] >= 0 and pts[j, 0] < output_res and \
                            pts[j, 1] >= 0 and pts[j, 1] < output_res:
                        # Offset of the keypoint from the integer object center.
                        kps[k, j * 2: j * 2 + 2] = pts[j, :2] - ct_int
                        kps_mask[k, j * 2: j * 2 + 2] = 1
                        pt_int = pts[j, :2].astype(np.int32)
                        # Discretisation error of the keypoint itself.
                        hp_offset[k * num_joints + j] = pts[j, :2] - pt_int
                        hp_ind[k * num_joints + j] = pt_int[1] * output_res + pt_int[0]
                        hp_mask[k * num_joints + j] = 1
                        if self.cfg.LOSS.DENSE_HP:
                            # Must run before the center gaussian is drawn into hm.
                            draw_dense_reg(dense_kps[j], hm[cls_id], ct_int, pts[j, :2] - ct_int, radius, is_offset=True)
                            draw_gaussian(dense_kps_mask[j], ct_int, radius)
                        draw_gaussian(hm_hp[j], pt_int, hp_radius)
            draw_gaussian(hm[cls_id], ct_int, radius)
            # Row layout: [x1, y1, x2, y2, 1, 2 * num_joints keypoint coords, cls_id].
            gt_det.append([ct[0] - w / 2, ct[1] - h / 2, ct[0] + w / 2, ct[1] + h / 2, 1] + pts[:, :2].reshape(num_joints * 2).tolist() + [cls_id])
    if rot != 0:
        # Rotated sample: fill the whole center heatmap with 0.9999 and clear
        # the center / keypoint masks.
        # NOTE(review): presumably meant to disable these losses for rotated
        # samples; hm_hp is not cleared — confirm intent.
        hm = hm * 0 + 0.9999
        reg_mask *= 0
        kps_mask *= 0
    ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh,
           'hps': kps, 'hps_mask': kps_mask}
    if self.cfg.LOSS.DENSE_HP:
        # Flatten the dense targets to (2 * num_joints, H, W) and duplicate the
        # mask for the x and y channels; they replace the sparse hps targets.
        dense_kps = dense_kps.reshape(num_joints * 2, output_res, output_res)
        dense_kps_mask = dense_kps_mask.reshape(num_joints, 1, output_res, output_res)
        dense_kps_mask = np.concatenate([dense_kps_mask, dense_kps_mask], axis=1)
        dense_kps_mask = dense_kps_mask.reshape(num_joints * 2, output_res, output_res)
        ret.update({'dense_hps': dense_kps, 'dense_hps_mask': dense_kps_mask})
        del ret['hps'], ret['hps_mask']
    if self.cfg.LOSS.REG_OFFSET:
        ret.update({'reg': reg})
    if self.cfg.LOSS.HM_HP:
        ret.update({'hm_hp': hm_hp})
    if self.cfg.LOSS.REG_HP_OFFSET:
        ret.update({'hp_offset': hp_offset, 'hp_ind': hp_ind, 'hp_mask': hp_mask})
    if self.cfg.DEBUG > 0 or not self.split == 'train':
        # NOTE(review): the empty fallback has 40 columns, which matches the
        # gt_det row length (6 + 2 * num_joints) only when num_joints == 17.
        gt_det = np.array(gt_det, dtype=np.float32) if len(gt_det) > 0 else \
            np.zeros((1, 40), dtype=np.float32)
        meta = {'c': c, 's': s, 'gt_det': gt_det, 'img_id': img_id}
        ret['meta'] = meta
    return ret


class MultiPoseLoss(torch.nn.Module):
    """Weighted sum of the multi-person pose training losses.

    total = HM_WEIGHT * hm_loss + WH_WEIGHT * wh_loss + OFF_WEIGHT * off_loss
          + HP_WEIGHT * hp_loss + HM_HP_WEIGHT * hm_hp_loss
          + OFF_WEIGHT * hp_offset_loss
    (all weights read from cfg.LOSS).
    """

    def __init__(self, cfg, local_rank):
        super(MultiPoseLoss, self).__init__()
        self.crit = FocalLoss()  # center heatmap (hm)
        self.crit_hm_hp = FocalLoss()  # joint heatmaps (hm_hp)
        self.crit_kp = RegWeightedL1Loss()  # joint offsets from the center (hps)
        self.crit_reg = RegL1Loss()  # wh, reg and hp_offset
        self.cfg = cfg
        self.local_rank = local_rank

    def forward(self, outputs, batch):
        """Compute the total loss.

        Args:
            outputs: either one head output list
                [hm, wh, hps, reg, hm_hp, hp_offset] (as returned by
                KeypointHead), or a sequence of such lists, one per stack.
            batch: target dict produced by the dataset.

        Returns:
            (loss, loss_stats) where loss_stats maps each component name to
            its value.
        """
        cfg = self.cfg
        # Normalise to a list of per-stack outputs. The previous version looped
        # cfg.MODEL.NUM_STACKS times over the same tensors and re-applied the
        # sigmoid on every iteration.
        stacks = [outputs] if torch.is_tensor(outputs[0]) else list(outputs)
        num_stacks = len(stacks)

        hm_loss = wh_loss = off_loss = 0
        hp_loss = hm_hp_loss = hp_offset_loss = 0
        for hm, wh, hps, reg, hm_hp, hp_offset in stacks:
            hm = _sigmoid(hm)
            if cfg.LOSS.HM_HP and not cfg.LOSS.MSE_LOSS:
                hm_hp = _sigmoid(hm_hp)

            # Center heatmap: focal loss.
            hm_loss += self.crit(hm, batch['hm']) / num_stacks
            # Joint offsets from the center, gathered at the center indices.
            hp_loss += self.crit_kp(hps, batch['hps_mask'],
                                    batch['ind'], batch['hps']) / num_stacks
            if cfg.LOSS.WH_WEIGHT > 0:
                # Box size, gathered at the center indices.
                wh_loss += self.crit_reg(wh, batch['reg_mask'],
                                         batch['ind'], batch['wh']) / num_stacks
            if cfg.LOSS.REG_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:
                # Center discretisation error.
                off_loss += self.crit_reg(reg, batch['reg_mask'],
                                          batch['ind'], batch['reg']) / num_stacks
            if cfg.LOSS.REG_HP_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:
                # Keypoint discretisation error, gathered at keypoint indices.
                hp_offset_loss += self.crit_reg(
                    hp_offset, batch['hp_mask'],
                    batch['hp_ind'], batch['hp_offset']) / num_stacks
            if cfg.LOSS.HM_HP and cfg.LOSS.HM_HP_WEIGHT > 0:
                hm_hp_loss += self.crit_hm_hp(hm_hp, batch['hm_hp']) / num_stacks

        loss = (cfg.LOSS.HM_WEIGHT * hm_loss + cfg.LOSS.WH_WEIGHT * wh_loss +
                cfg.LOSS.OFF_WEIGHT * off_loss + cfg.LOSS.HP_WEIGHT * hp_loss +
                cfg.LOSS.HM_HP_WEIGHT * hm_hp_loss +
                cfg.LOSS.OFF_WEIGHT * hp_offset_loss)

        loss_stats = {'loss': loss, 'hm_loss': hm_loss, 'hp_loss': hp_loss,
                      'hm_hp_loss': hm_hp_loss, 'hp_offset_loss': hp_offset_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss}
        return loss, loss_stats


def _neg_loss(pred, gt):
    """Penalty-reduced pixel-wise focal loss (CornerNet / CenterNet variant).

    Args:
        pred: predicted probabilities, (batch x c x h x w). A value of exactly
            0 or 1 in the wrong place makes a log term infinite.
        gt: target heatmap of the same shape. Cells equal to 1 are positives;
            cells < 1 are negatives whose penalty is scaled by (1 - gt) ** 4.

    Returns:
        Scalar tensor: -(pos + neg) / num_pos, or -neg when there are no
        positives.
    """
    positives = gt.eq(1).float()
    negatives = gt.lt(1).float()
    neg_weights = (1 - gt).pow(4)

    pos_term = (pred.log() * (1 - pred).pow(2) * positives).sum()
    neg_term = ((1 - pred).log() * pred.pow(2) * neg_weights * negatives).sum()
    num_pos = positives.sum()

    if num_pos == 0:
        return -neg_term
    return -(pos_term + neg_term) / num_pos




def _topk(scores, K=40):
    """Select the K highest-scoring locations across all categories.

    Args:
        scores: heatmap tensor of shape (batch, cat, height, width).
        K: number of peaks to keep; must not exceed height * width.

    Returns:
        Tuple (topk_score, topk_inds, topk_clses, topk_ys, topk_xs), each of
        shape (batch, K):
          - topk_score: the peak score;
          - topk_inds: the flattened spatial index y * width + x;
          - topk_clses: the category index (int32);
          - topk_ys, topk_xs: the y and x coordinates as floats.
    """
    batch, cat, height, width = scores.size()

    # Per category: the K best locations in the flattened height*width map.
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)  # batch x cat x K
    # Exact integer division. The old float `/` followed by `.int()` relied on
    # float precision to recover the row index.
    topk_ys = torch.div(topk_inds, width, rounding_mode='floor').float()
    topk_xs = (topk_inds % width).float()

    # Across categories: the K best of the cat * K candidates.
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)  # batch x K
    topk_clses = torch.div(topk_ind, K, rounding_mode='floor').int()
    topk_inds = topk_inds.view(batch, -1).gather(1, topk_ind)
    topk_ys = topk_ys.view(batch, -1).gather(1, topk_ind)
    topk_xs = topk_xs.view(batch, -1).gather(1, topk_ind)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

post process

def whole_body_decode(heat, wh, kps, seg_feat=None, seg=None, reg=None,
                      hm_hp=None, hp_offset=None, K=100, hp_thresh=0.1):
    """Decode multi-person pose predictions, plus optional instance masks.

    Args:
        heat: center heatmap (batch, cat, H, W); presumably already
            sigmoid-activated — TODO confirm.
        wh: box-size map (2 channels).
        kps: joint offsets from the center (2 * num_joints channels).
        seg_feat, seg: optional mask features and per-location dynamic conv
            weights. pred_seg is computed only when both are given.
        reg: optional center sub-pixel offset (2 channels).
        hm_hp: optional joint heatmaps (num_joints channels), used to snap
            regressed joints onto nearby heatmap peaks.
        hp_offset: optional joint sub-pixel offsets (2 channels).
        K: number of detections kept per image.
        hp_thresh: minimum joint-heatmap score for a peak to be used.

    Returns:
        (detections, pred_seg):
          - detections: (batch, K, 5 + 3 * num_joints) laid out as
            [x1, y1, x2, y2, score, 2 * num_joints joint coords,
            num_joints joint scores], in output-map coordinates. Joint scores
            are 0 when hm_hp is None.
          - pred_seg: the dynamic-conv mask output, or None.
    """
    batch, cat, height, width = heat.size()
    num_joints = kps.shape[1] // 2

    # Keep only local maxima of the center heatmap, then the K best peaks.
    heat = _nms(heat)
    scores, inds, _, ys, xs = _topk(heat, K=K)

    # Regressed joints = integer center location + predicted offsets.
    kps = _transpose_and_gather_feat(kps, inds)
    kps = kps.view(batch, K, num_joints * 2)
    kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints)
    kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints)

    # Sub-pixel center: add the regressed offset, or the pixel middle.
    if reg is not None:
        reg = _transpose_and_gather_feat(reg, inds).view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    wh = _transpose_and_gather_feat(wh, inds).view(batch, K, 2)

    # Optional instance masks via per-detection dynamic 3x3 convolutions.
    # Previously this ran unconditionally and crashed with the default seg=None.
    pred_seg = None
    if seg is not None and seg_feat is not None:
        weight = _transpose_and_gather_feat(seg, inds)
        # NOTE(review): view([K, -1, 3, 3]) folds the batch into the filters,
        # so this presumably assumes batch == 1 — confirm. A 1x1 variant could
        # be selected here when weight.size(1) != seg_feat.size(1).
        weight = weight.view([weight.size(1), -1, 3, 3])
        pred_seg = F.conv2d(seg_feat, weight, stride=1, padding=1)

    scores = scores.view(batch, K, 1)
    bboxes = torch.cat([xs - wh[..., 0:1] / 2, ys - wh[..., 1:2] / 2,
                        xs + wh[..., 0:1] / 2, ys + wh[..., 1:2] / 2], dim=2)

    if hm_hp is not None:
        hm_hp = _nms(hm_hp)
        # b x K x 2J -> b x J x K x 2
        kps = kps.view(batch, K, num_joints, 2).permute(0, 2, 1, 3).contiguous()
        # reg_kps[b, j, k, m] = regressed joint j of detection k (repeated over m).
        reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2)
        # Top-K heatmap peaks per joint channel, each b x J x K.
        hm_score, hm_inds, hm_ys, hm_xs = _topk_channel(hm_hp, K=K)
        # Refine peak positions with the joint sub-pixel offset when available.
        if hp_offset is not None:
            hp_offset = _transpose_and_gather_feat(hp_offset, hm_inds.view(batch, -1))
            hp_offset = hp_offset.view(batch, num_joints, K, 2)
            hm_xs = hm_xs + hp_offset[:, :, :, 0]
            hm_ys = hm_ys + hp_offset[:, :, :, 1]
        else:
            hm_xs = hm_xs + 0.5
            hm_ys = hm_ys + 0.5
        # Peaks at or below the threshold get score -1 and are moved far away
        # (-10000) so they never win the nearest-peak match.
        keep = (hm_score > hp_thresh).float()
        hm_score = (1 - keep) * -1 + keep * hm_score
        hm_ys = (1 - keep) * (-10000) + keep * hm_ys
        hm_xs = (1 - keep) * (-10000) + keep * hm_xs
        # hm_kps[b, j, k, m] = heatmap peak m of joint j (repeated over k).
        hm_kps = torch.stack([hm_xs, hm_ys], dim=-1).unsqueeze(2).expand(
            batch, num_joints, K, K, 2)
        # For each detection's regressed joint, find the nearest heatmap peak.
        dist = ((reg_kps - hm_kps) ** 2).sum(dim=4) ** 0.5
        min_dist, min_ind = dist.min(dim=3)  # b x J x K
        hm_score = hm_score.gather(2, min_ind).unsqueeze(-1)  # b x J x K x 1
        min_dist = min_dist.unsqueeze(-1)
        min_ind = min_ind.view(batch, num_joints, K, 1, 1).expand(
            batch, num_joints, K, 1, 2)
        hm_kps = hm_kps.gather(3, min_ind).view(batch, num_joints, K, 2)
        # Box edges broadcast to b x J x K x 1.
        left = bboxes[:, :, 0].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        top = bboxes[:, :, 1].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        right = bboxes[:, :, 2].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        bottom = bboxes[:, :, 3].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        # Fall back to the regressed joint when the matched peak is outside
        # the box, too weak, or farther than 0.3 * the longer box side.
        reject = ((hm_kps[..., 0:1] < left) | (hm_kps[..., 0:1] > right) |
                  (hm_kps[..., 1:2] < top) | (hm_kps[..., 1:2] > bottom) |
                  (hm_score < hp_thresh) |
                  (min_dist > (torch.max(bottom - top, right - left) * 0.3)))
        reject = reject.float().expand(batch, num_joints, K, 2)
        kps = (1 - reject) * hm_kps + reject * kps
        kps = kps.permute(0, 2, 1, 3).contiguous().view(batch, K, num_joints * 2)
        joint_scores = hm_score.squeeze(dim=3).transpose(1, 2)  # b x K x J
    else:
        # No joint heatmaps: keep the regressed joints and report zero joint
        # scores. Previously hm_score was undefined here (NameError).
        joint_scores = scores.new_zeros(batch, K, num_joints)

    detections = torch.cat([bboxes, scores, kps, joint_scores], dim=2)
    return detections, pred_seg


