P2PNet Code Reading Notes

1. Backbone Network


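The backbone is VGG16. Both the plain vgg16 and the batch-normalized vgg16_bn variants are supported, and vgg16_bn is the default.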

```python
import torch.nn as nn
import torchvision.models as models


class BackboneBase_VGG(nn.Module):
    def __init__(self, backbone: nn.Module, num_channels: int, name: str, return_interm_layers: bool):
        super().__init__()
        features = list(backbone.features.children())
        if return_interm_layers:
            # split the VGG feature extractor into four stages (output strides 2, 4, 8, 16)
            if name == 'vgg16_bn':
                self.body1 = nn.Sequential(*features[:13])
                self.body2 = nn.Sequential(*features[13:23])
                self.body3 = nn.Sequential(*features[23:33])
                self.body4 = nn.Sequential(*features[33:43])
            else:
                self.body1 = nn.Sequential(*features[:9])
                self.body2 = nn.Sequential(*features[9:16])
                self.body3 = nn.Sequential(*features[16:23])
                self.body4 = nn.Sequential(*features[23:30])
        else:
            if name == 'vgg16_bn':
                # note: features[:44] also keeps the last MaxPool (index 43), so this
                # path is actually 32x down-sampled; features[:43] would be 16x
                self.body = nn.Sequential(*features[:44])
            elif name == 'vgg16':
                self.body = nn.Sequential(*features[:30])  # 16x down-sample
        self.num_channels = num_channels
        self.return_interm_layers = return_interm_layers

    def forward(self, tensor_list):
        out = []

        if self.return_interm_layers:
            xs = tensor_list
            # run the stages in sequence and keep every intermediate feature map
            for layer in [self.body1, self.body2, self.body3, self.body4]:
                xs = layer(xs)
                out.append(xs)
        else:
            xs = self.body(tensor_list)
            out.append(xs)
        return out


class Backbone_VGG(BackboneBase_VGG):
    """VGG16 / VGG16_bn backbone with ImageNet-pretrained weights."""
    def __init__(self, name: str, return_interm_layers: bool):
        # pretrained=True is deprecated since torchvision 0.13 in favour of the weights= argument
        if name == 'vgg16_bn':
            backbone = models.vgg16_bn(pretrained=True)
        elif name == 'vgg16':
            backbone = models.vgg16(pretrained=True)
        else:
            raise ValueError(f'unsupported backbone: {name}')
        num_channels = 256
        super().__init__(backbone, num_channels, name, return_interm_layers)
```
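You can check the slice boundaries above without downloading any weights by rebuilding the layer list in plain Python. This is a minimal sketch. It assumes torchvision's VGG convention for configuration D: every number is a 3x3 conv, followed by a BatchNorm in the _bn variant and then a ReLU, and "M" is a 2x2 max-pool. The helpers `vgg_layers` and `stride_after` are hypothetical.

```python
# Configuration "D" (VGG16): numbers are conv output channels, "M" is a 2x2 max-pool.
CFG_D = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
         512, 512, 512, "M", 512, 512, 512, "M"]


def vgg_layers(batch_norm: bool) -> list:
    """Layer types of VGG16(_bn).features in order (hypothetical helper)."""
    layers = []
    for v in CFG_D:
        if v == "M":
            layers.append("MaxPool")
        else:
            layers.append("Conv")
            if batch_norm:
                layers.append("BN")
            layers.append("ReLU")
    return layers


def stride_after(layers: list, end: int) -> int:
    """Down-sampling factor of layers[:end]: each MaxPool halves H and W."""
    return 2 ** layers[:end].count("MaxPool")


bn = vgg_layers(batch_norm=True)
plain = vgg_layers(batch_norm=False)
print(len(bn), len(plain))                                # 44 31
print([stride_after(bn, e) for e in (13, 23, 33, 43)])    # [2, 4, 8, 16]
print([stride_after(plain, e) for e in (9, 16, 23, 30)])  # [2, 4, 8, 16]
print(stride_after(bn, 44), stride_after(plain, 30))      # 32 16
```

So the features[1], features[2] and features[3] maps that P2PNet feeds to the FPN come from strides 4, 8 and 16. The single-body vgg16_bn path (features[:44]) ends at stride 32.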

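VGG16 and VGG16_bn differ only in whether a BatchNorm2d (batch normalization) layer follows every convolution.
In the figure below, configuration D is the VGG16 architecture.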

1.1 VGG16

```
Backbone_VGG16(
  (body): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
  )
)
```

1.2 VGG16_bn

```
Backbone_VGG16_bn(
  (body): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): ReLU(inplace=True)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): ReLU(inplace=True)
    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (26): ReLU(inplace=True)
    (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (32): ReLU(inplace=True)
    (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (36): ReLU(inplace=True)
    (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (39): ReLU(inplace=True)
    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (42): ReLU(inplace=True)
    (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
)
```

2. The P2PNet Network

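The architecture actually implemented differs from the one described in the paper. The code adds an FPN (feature pyramid) layer, as shown in the figure below: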

2.1 P2PNet

```python
import torch.nn as nn

# RegressionModel, ClassificationModel, AnchorPoints and Decoder (the FPN) are defined in
# the same model file of the repository; NestedTensor is a helper type from its utilities.


# the definition of the P2PNet model
class P2PNet(nn.Module):
    def __init__(self, backbone, row=2, line=2):
        super().__init__()
        self.backbone = backbone  # VGG16_bn
        self.num_classes = 2      # two classes: object (person) and background
        # the number of anchor points per feature-map cell
        num_anchor_points = row * line

        # regression branch: predicts point offsets
        self.regression = RegressionModel(num_features_in=256, num_anchor_points=num_anchor_points)
        # classification branch: predicts class logits (confidence)
        self.classification = ClassificationModel(num_features_in=256,
                                                  num_classes=self.num_classes,
                                                  num_anchor_points=num_anchor_points)

        self.anchor_points = AnchorPoints(pyramid_levels=[3, ], row=row, line=line)

        self.fpn = Decoder(256, 512, 512)

    def forward(self, samples: NestedTensor):
        # get the backbone features
        features = self.backbone(samples)
        # forward the feature pyramid
        features_fpn = self.fpn([features[1], features[2], features[3]])

        batch_size = features[0].shape[0]
        # run the regression and classification branch on the stride-8 level
        regression = self.regression(features_fpn[1]) * 100  # 8x
        classification = self.classification(features_fpn[1])
        anchor_points = self.anchor_points(samples).repeat(batch_size, 1, 1)
        # decode the points as prediction: anchor point + predicted offset
        output_coord = regression + anchor_points
        output_class = classification
        out = {'pred_logits': output_class, 'pred_points': output_coord}

        return out
```
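Both heads run on the stride-8 FPN level, as the `# 8x` comment and `pyramid_levels=[3,]` indicate. Each cell then carries row x line anchor points, so the number of predictions follows directly from the image size. The sketch below does only that shape arithmetic. The helper `p2pnet_output_shapes` is hypothetical and involves no model.

```python
def p2pnet_output_shapes(batch, height, width, row=2, line=2, num_classes=2, stride=8):
    """Shape arithmetic for P2PNet's outputs (hypothetical helper, no model involved)."""
    # feature-map size at the stride-8 level, rounded up like AnchorPoints does
    fh = (height + stride - 1) // stride
    fw = (width + stride - 1) // stride
    num_points = fh * fw * row * line  # one prediction per anchor point
    return {
        "pred_logits": (batch, num_points, num_classes),
        "pred_points": (batch, num_points, 2),
    }


print(p2pnet_output_shapes(1, 128, 128))
# {'pred_logits': (1, 1024, 2), 'pred_points': (1, 1024, 2)}
```

AnchorPoints rounds the grid size up. The VGG max-pools use ceil_mode=False, which rounds down. For sizes that are not multiples of the network stride the two counts could therefore disagree. This is presumably one reason the inference script resizes images to multiples of 128.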

2.2 Regression Branch

```python
import torch.nn as nn


# the network framework of the regression branch
class RegressionModel(nn.Module):
    def __init__(self, num_features_in, num_anchor_points=4, feature_size=256):
        super(RegressionModel, self).__init__()

        self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act4 = nn.ReLU()

        # two outputs (dx, dy) per anchor point
        self.output = nn.Conv2d(feature_size, num_anchor_points * 2, kernel_size=3, padding=1)

    # sub-branch forward
    def forward(self, x):
        # only two 3x3 convolutions are applied; conv3/conv4 are defined but unused
        out = self.conv1(x)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.act2(out)

        out = self.output(out)
        # move channels last: [B, 2*A, H, W] -> [B, H, W, 2*A]
        out = out.permute(0, 2, 3, 1)

        # view() needs contiguous memory after permute()/transpose(), hence contiguous();
        # reshape to [B, H*W*A, 2]: one (dx, dy) pair per anchor point
        return out.contiguous().view(out.shape[0], -1, 2)
```
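The pair `permute(0, 2, 3, 1)` + `view(B, -1, 2)` lines up with the anchor list because it flattens cell by cell in row-major order, and anchor by anchor inside each cell. That is the same order `shift()` produces. Below is a NumPy sketch of the same reshape. It assumes the output channels are laid out per cell as [dx0, dy0, dx1, dy1, ...], which is what `view(..., 2)` implies.

```python
import numpy as np

B, A, H, W = 1, 4, 2, 3  # batch, anchors per cell, feature-map height and width
# fake conv output [B, 2*A, H, W]; each value encodes (channel, y, x) so it can be traced
out = np.zeros((B, 2 * A, H, W))
for c in range(2 * A):
    for y in range(H):
        for x in range(W):
            out[0, c, y, x] = 100 * c + 10 * y + x

# NumPy equivalent of out.permute(0, 2, 3, 1).contiguous().view(B, -1, 2)
flat = out.transpose(0, 2, 3, 1).reshape(B, -1, 2)
print(flat.shape)  # (1, 24, 2)

# row k belongs to cell k // A (row-major: cell = y * W + x) and anchor k % A
k = 7
cell, anchor = divmod(k, A)
y, x = divmod(cell, W)
print(flat[0, k], (y, x, anchor))  # channels 6 and 7 (anchor 3's dx, dy) of cell (y=0, x=1)
```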

2.3 Classification Branch

```python
import torch.nn as nn


# the network framework of the classification branch
class ClassificationModel(nn.Module):
    def __init__(self, num_features_in, num_anchor_points=4, num_classes=80, prior=0.01, feature_size=256):
        super(ClassificationModel, self).__init__()

        self.num_classes = num_classes
        self.num_anchor_points = num_anchor_points

        self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

        self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act4 = nn.ReLU()

        self.output = nn.Conv2d(feature_size, num_anchor_points * num_classes, kernel_size=3, padding=1)
        # defined but never applied in forward(): the branch returns raw logits
        self.output_act = nn.Sigmoid()

    # sub-branch forward
    def forward(self, x):
        # as in the regression branch, only conv1 and conv2 are used
        out = self.conv1(x)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.act2(out)

        out = self.output(out)
        # [B, A*C, H, W] -> [B, H, W, A*C]
        out1 = out.permute(0, 2, 3, 1)

        batch_size, height, width, _ = out1.shape

        out2 = out1.view(batch_size, height, width, self.num_anchor_points, self.num_classes)
        # flatten to [B, H*W*A, num_classes]
        return out2.contiguous().view(x.shape[0], -1, self.num_classes)
```
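At inference time the score is `softmax(pred_logits)[..., 1]`, not the unused Sigmoid. With two classes, the softmax probability of class 1 equals the sigmoid of the logit difference: p1 = sigmoid(l1 - l0). The NumPy check below confirms that identity. The example logits are made up.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# made-up logits for 3 anchor points: columns are [background, person]
logits = np.array([[2.0, -1.0], [0.3, 0.3], [-0.5, 1.5]])
p_person = softmax(logits)[:, 1]
print(np.round(p_person, 3))
print(np.allclose(p_person, sigmoid(logits[:, 1] - logits[:, 0])))  # True
# with the 0.5 threshold of the prediction script only the third point is kept
print(p_person > 0.5)
```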

2.4 Computing Anchor Points

```python
import numpy as np
import torch
import torch.nn as nn


# generate the reference points in grid layout (the meta-anchors inside one cell)
def generate_anchor_points(stride=16, row=3, line=3):
    # worked example in the comments: stride=8, row=line=2
    row_step = stride / row    # 8 / 2 = 4
    line_step = stride / line  # 8 / 2 = 4

    # shift_x = ([1, 2] - 0.5) * 4 - 8 / 2 = [-2, 2]
    shift_x = (np.arange(1, line + 1) - 0.5) * line_step - stride / 2
    # shift_y = ([1, 2] - 0.5) * 4 - 8 / 2 = [-2, 2]
    shift_y = (np.arange(1, row + 1) - 0.5) * row_step - stride / 2
    # meshgrid -> shift_x = [[-2, 2],    shift_y = [[-2, -2],
    #                        [-2, 2]]               [ 2,  2]]
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
    # shift_x.ravel() = [-2, 2, -2, 2], shift_y.ravel() = [-2, -2, 2, 2]
    # (ravel() returns a view where possible, so it shares memory with the original)
    # vstack puts x in row 0 and y in row 1; transpose gives (x, y) pairs:
    # [[-2, -2],
    #  [ 2, -2],
    #  [-2,  2],
    #  [ 2,  2]]
    anchor_points = np.vstack((
        shift_x.ravel(), shift_y.ravel()
    )).transpose()

    return anchor_points


# shift the meta-anchors to get the anchor points of the whole image
def shift(shape, stride, anchor_points):
    # one cell centre every `stride` (= 8) pixels
    shift_x = (np.arange(0, shape[1]) + 0.5) * stride  # [4, 12, 20, ...]
    shift_y = (np.arange(0, shape[0]) + 0.5) * stride  # [4, 12, 20, ...]
    # meshgrid -> shift_x = [[4, 12, 20, ...],   shift_y = [[ 4,  4,  4, ...],
    #                        [4, 12, 20, ...],              [12, 12, 12, ...],
    #                        ...]                           ...]
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
    # shift_x.ravel() = [4, 12, 20, ..., 4, 12, 20, ...]
    # shift_y.ravel() = [4, 4, 4, ..., 12, 12, 12, ...]
    # shifts: (x, y) cell centres in row-major order, x varying fastest:
    # [[4, 4], [12, 4], [20, 4], ..., [4, 12], [12, 12], ...]
    shifts = np.vstack((
        shift_x.ravel(), shift_y.ravel()
    )).transpose()

    A = anchor_points.shape[0]  # anchors per cell, e.g. A = 4
    K = shifts.shape[0]         # number of cells = shape[0] * shape[1]
    # (1, A, 2) + (K, 1, 2) -> (K, A, 2) by broadcasting
    all_anchor_points = (anchor_points.reshape((1, A, 2)) + shifts.reshape((1, K, 2)).transpose((1, 0, 2)))
    # (K * A, 2): all anchor points on the image
    all_anchor_points = all_anchor_points.reshape((K * A, 2))

    return all_anchor_points


# this class generates all reference points on all pyramid levels
class AnchorPoints(nn.Module):
    def __init__(self, pyramid_levels=None, strides=None, row=3, line=3):
        super(AnchorPoints, self).__init__()

        if pyramid_levels is None:
            self.pyramid_levels = [3, 4, 5, 6, 7]
        else:
            self.pyramid_levels = pyramid_levels

        if strides is None:
            self.strides = [2 ** x for x in self.pyramid_levels]
        else:
            self.strides = strides  # the original code never set self.strides in this case

        self.row = row
        self.line = line

    def forward(self, image):
        image_shape = image.shape[2:]
        image_shape = np.array(image_shape)
        # feature-map size per level: ceil(H / 2**p) x ceil(W / 2**p), i.e. 1/8 of the image for p = 3
        image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]

        # empty array with 2 columns (x, y)
        all_anchor_points = np.zeros((0, 2)).astype(np.float32)
        # get reference points for each pyramid level
        for idx, p in enumerate(self.pyramid_levels):
            # offsets of the meta-anchors inside one cell
            anchor_points = generate_anchor_points(2 ** p, row=self.row, line=self.line)
            # all anchor points on the image
            shifted_anchor_points = shift(image_shapes[idx], self.strides[idx], anchor_points)
            all_anchor_points = np.append(all_anchor_points, shifted_anchor_points, axis=0)

        # add a batch dimension: (1, N, 2)
        all_anchor_points = np.expand_dims(all_anchor_points, axis=0)
        # send the reference points to the device and return them
        # (note: CUDA is used whenever it is available, regardless of the model's device)
        if torch.cuda.is_available():
            return torch.from_numpy(all_anchor_points.astype(np.float32)).cuda()
        else:
            return torch.from_numpy(all_anchor_points.astype(np.float32))
```
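Here is a compact NumPy sketch that reproduces the anchor layout for a tiny 16x16 image (stride 8, row = line = 2), to make the ordering and the count concrete. The function `anchors_for_image` is hypothetical. It folds `generate_anchor_points` and `shift` into one step, using `np.stack(..., axis=1)` where the original uses `vstack(...).transpose()`.

```python
import numpy as np


def anchors_for_image(h, w, stride=8, row=2, line=2):
    """All anchor points (x, y) for one pyramid level (hypothetical helper)."""
    # offsets of the row x line meta-anchors around a cell centre
    dx = (np.arange(1, line + 1) - 0.5) * (stride / line) - stride / 2
    dy = (np.arange(1, row + 1) - 0.5) * (stride / row) - stride / 2
    ox, oy = np.meshgrid(dx, dy)
    offsets = np.stack([ox.ravel(), oy.ravel()], axis=1)   # (A, 2)
    # cell centres with x varying fastest; grid size rounded up like AnchorPoints
    gh, gw = (h + stride - 1) // stride, (w + stride - 1) // stride
    cx, cy = np.meshgrid((np.arange(gw) + 0.5) * stride, (np.arange(gh) + 0.5) * stride)
    centres = np.stack([cx.ravel(), cy.ravel()], axis=1)   # (K, 2)
    # (K, 1, 2) + (1, A, 2) -> (K, A, 2) -> (K * A, 2)
    return (centres[:, None, :] + offsets[None, :, :]).reshape(-1, 2)


pts = anchors_for_image(16, 16)
print(pts.shape)  # (16, 2)
print(pts[:4])    # top-left cell (centre (4, 4)): (2, 2), (6, 2), (2, 6), (6, 6)
```

Each anchor sits at the centre of a 4x4 sub-cell of its 8x8 cell. The regression head then only has to predict offsets relative to these points.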

3. Data Loading and Preprocessing

3.1 Dataset

```python
import os
import random

import torch
import torchvision.transforms as standard_transforms
from torch.utils.data import Dataset


def build_dataset(args):
    if args.dataset_file == 'SHHA':
        from crowd_datasets.SHHA.loading_data import loading_data
        return loading_data

    return None


def loading_data(data_root):
    # the pre-processing transform: to tensor + ImageNet normalization
    transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225]),
    ])
    # create the training dataset
    train_set = SHHA(data_root, train=True, transform=transform, patch=True, flip=True)
    # create the validation dataset
    val_set = SHHA(data_root, train=False, transform=transform)

    return train_set, val_set


class SHHA(Dataset):
    def __init__(self, data_root, transform=None, train=False, patch=False, flip=False):
        # dataset root directory
        self.root_path = data_root
        # list files of the training and test splits
        self.train_lists = "train.list"
        self.eval_list = "test.list"
        # there may be several comma-separated list files (only one here)
        if train:
            self.img_list_file = self.train_lists.split(',')
        else:
            self.img_list_file = self.eval_list.split(',')

        self.img_map = {}
        self.img_list = []
        # load the image / ground-truth pairs
        for train_list in self.img_list_file:
            train_list = train_list.strip()  # remove surrounding whitespace
            with open(os.path.join(self.root_path, train_list)) as fin:
                for line in fin:
                    # skip empty lines (anything shorter than 2 characters)
                    if len(line) < 2:
                        continue
                    # line[0]: image path, line[1]: path of the ground-truth point file
                    line = line.strip().split()
                    self.img_map[os.path.join(self.root_path, line[0].strip())] = \
                        os.path.join(self.root_path, line[1].strip())
        # sort by image path
        self.img_list = sorted(list(self.img_map.keys()))
        # number of samples
        self.nSamples = len(self.img_list)

        self.transform = transform
        self.train = train
        self.patch = patch
        self.flip = flip

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index < len(self), 'index range error'

        img_path = self.img_list[index]
        gt_path = self.img_map[img_path]
        # load the image (PIL, RGB) and the ground-truth points (numpy array)
        img, point = load_data((img_path, gt_path), self.train)
        # apply the transform (to tensor + normalization)
        if self.transform is not None:
            img = self.transform(img)

        if self.train:
            # data augmentation -> random scale factor in [0.7, 1.3]
            scale_range = [0.7, 1.3]
            min_size = min(img.shape[1:])
            scale = random.uniform(*scale_range)
            # scale the image and the points by the same factor,
            # but only if the shorter side stays larger than 128
            if scale * min_size > 128:
                # upsample_bilinear is deprecated; it equals
                # F.interpolate(..., mode='bilinear', align_corners=True)
                img = torch.nn.functional.upsample_bilinear(img.unsqueeze(0), scale_factor=scale).squeeze(0)
                point *= scale
        # random crop augmentation
        if self.train and self.patch:
            # crop 4 patches of C x 128 x 128 and get the point coordinates relative to each patch
            img, point = random_crop(img, point)
            for i, _ in enumerate(point):
                point[i] = torch.Tensor(point[i])
        # random horizontal flip with probability 0.5
        # (indexes the 4-D patch batch, so it assumes patch=True)
        if random.random() > 0.5 and self.train and self.flip:
            img = torch.Tensor(img[:, :, :, ::-1].copy())
            for i, _ in enumerate(point):
                point[i][:, 0] = 128 - point[i][:, 0]

        if not self.train:
            point = [point]

        img = torch.Tensor(img)
        # pack up the related information
        target = [{} for i in range(len(point))]
        for i, _ in enumerate(point):
            target[i]['point'] = torch.Tensor(point[i])
            # image id parsed from the file name, e.g. IMG_12.jpg -> 12
            image_id = int(img_path.split('/')[-1].split('.')[0].split('_')[-1])
            image_id = torch.Tensor([image_id]).long()
            target[i]['image_id'] = image_id
            target[i]['labels'] = torch.ones([point[i].shape[0]]).long()

        # returns the image tensor and a list of target dicts (point, image_id, labels of size N)
        return img, target
```
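The flip step mirrors the pixel columns (`[..., ::-1]`) and maps every point with x -> 128 - x. The NumPy check below shows that the two agree. It assumes point coordinates are continuous, with pixel j covering [j, j + 1) so that its centre is j + 0.5. The helper `flip_patch` is hypothetical.

```python
import numpy as np

PATCH = 128  # crop size used by random_crop


def flip_patch(img, points):
    """Horizontal flip of a [C, H, W] patch and its (x, y) points (hypothetical helper)."""
    flipped_img = img[:, :, ::-1].copy()
    flipped_pts = points.copy()
    flipped_pts[:, 0] = PATCH - flipped_pts[:, 0]
    return flipped_img, flipped_pts


img = np.zeros((1, PATCH, PATCH))
col, r = 10, 40                          # mark the pixel in column 10, row 40
img[0, r, col] = 1.0
pts = np.array([[col + 0.5, r + 0.5]])   # a point at that pixel's centre

f_img, f_pts = flip_patch(img, pts)
new_col = int(np.floor(f_pts[0, 0]))
print(new_col, f_img[0, r, new_col])  # 117 1.0
```

Column 10 becomes column 127 - 10 = 117 in the image, and the point moves from x = 10.5 to x = 117.5, which is inside the same pixel.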

3.2 Loading Data

```python
import cv2
import numpy as np
from PIL import Image


def load_data(img_gt_path, train):
    # (the train flag is not used here)
    img_path, gt_path = img_gt_path
    # load the image; cv2.imread returns BGR, so convert to RGB (otherwise colours look bluish)
    img = cv2.imread(img_path)
    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # load the ground-truth points: one "x y" pair per line
    points = []
    with open(gt_path) as f_label:
        for line in f_label:
            x = float(line.strip().split(' ')[0])
            y = float(line.strip().split(' ')[1])
            points.append([x, y])

    return img, np.array(points)
```
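The ground-truth file is read as one `x y` pair per line. The sketch below does the same parsing on an in-memory file. It also shows the BGR -> RGB swap done with NumPy slicing instead of OpenCV, which is equivalent for a 3-channel image. `parse_points` and the sample text are hypothetical.

```python
import io

import numpy as np


def parse_points(f):
    """Read 'x y' lines into an (N, 2) float array (hypothetical helper)."""
    points = []
    for line in f:
        parts = line.strip().split()
        if len(parts) < 2:  # skip blank or malformed lines
            continue
        points.append([float(parts[0]), float(parts[1])])
    # reshape keeps the (0, 2) shape even for an image without annotations
    return np.array(points, dtype=np.float64).reshape(-1, 2)


sample = io.StringIO("12.5 30.0\n100 7.25\n\n")
pts = parse_points(sample)
print(pts.shape)  # (2, 2)

# BGR -> RGB for an H x W x 3 image: reverse the last axis
bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0] = 255   # pure blue in BGR order
rgb = bgr[..., ::-1]
print(rgb[0, 0])    # blue is now the last channel
```

`split()` without an argument also tolerates repeated spaces or tabs, which `split(' ')` does not.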

3.3 Random-Crop Augmentation

```python
import random

import numpy as np


# random crop augmentation
def random_crop(img, den, num_patch=4):
    # img: tensor [C, H, W]; den: (N, 2) array of (x, y) points
    half_h = 128  # patch height (despite the name, this is the full patch size)
    half_w = 128  # patch width
    result_img = np.zeros([num_patch, img.shape[0], half_h, half_w])
    result_den = []
    # crop num_patch (= 4) regions of C x 128 x 128 from each image
    for i in range(num_patch):
        # random integers in [0, H - 128] and [0, W - 128], both ends inclusive
        start_h = random.randint(0, img.size(1) - half_h)
        start_w = random.randint(0, img.size(2) - half_w)
        end_h = start_h + half_h
        end_w = start_w + half_w
        # copy the cropped rectangle
        result_img[i] = img[:, start_h:end_h, start_w:end_w]
        # keep the points inside the crop (idx is a boolean mask)
        idx = (den[:, 0] >= start_w) & (den[:, 0] <= end_w) & (den[:, 1] >= start_h) & (den[:, 1] <= end_h)
        # shift the coordinates so they are relative to the crop
        record_den = den[idx]
        record_den[:, 0] -= start_w
        record_den[:, 1] -= start_h

        result_den.append(record_den)
    # return the 4 cropped images and their point sets
    return result_img, result_den
```
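The point bookkeeping in `random_crop` is just a boolean mask plus a translation. The deterministic NumPy sketch below uses a fixed crop origin instead of `random.randint`, which makes the behaviour easy to check. The function `crop_points` is hypothetical.

```python
import numpy as np


def crop_points(den, start_w, start_h, size=128):
    """Keep (x, y) points inside [start, start + size] and shift them (hypothetical helper)."""
    end_w, end_h = start_w + size, start_h + size
    idx = ((den[:, 0] >= start_w) & (den[:, 0] <= end_w)
           & (den[:, 1] >= start_h) & (den[:, 1] <= end_h))
    kept = den[idx]  # boolean indexing returns a copy, so den itself is not modified
    kept[:, 0] -= start_w
    kept[:, 1] -= start_h
    return kept


den = np.array([[50.0, 60.0], [200.0, 60.0], [130.0, 170.0]])
print(crop_points(den, start_w=40, start_h=50))  # (10, 10) and (90, 120); x = 200 is dropped
```

As in the original, the bounds are inclusive at both ends. A point exactly on the right or bottom edge therefore ends up at coordinate 128, just outside the last pixel of the patch.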

4. Training

4.1 Training Module

```python
# create the P2PNet model
def build(args, training):
    # treat persons as a single class
    num_classes = 1

    backbone = build_backbone(args)
    model = P2PNet(backbone, args.row, args.line)
    # for testing, return the model only
    if not training:
        return model

    weight_dict = {'loss_ce': 1, 'loss_points': args.point_loss_coef}
    losses = ['labels', 'points']
    # the matcher returns the indices of matched predicted / ground-truth points
    matcher = build_matcher_crowd(args)
    # the criterion computes the losses
    criterion = SetCriterion_Crowd(num_classes,
                                   matcher=matcher, weight_dict=weight_dict,
                                   eos_coef=args.eos_coef, losses=losses)
    # return the model and the criterion
    return model, criterion
```
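`weight_dict` is meant to scale each loss before the losses are summed. The training loop is not shown in these notes. Assuming it follows the DETR-style pattern of summing only the keys present in `weight_dict`, the plain-Python sketch below shows why the key names matter. `SetCriterion_Crowd.loss_points` returns the key 'loss_point', while `weight_dict` uses 'loss_points'. The loss values are made up, and `POINT_LOSS_COEF` is a placeholder for `args.point_loss_coef`.

```python
def weighted_total(loss_dict, weight_dict):
    """DETR-style weighted sum (assumed pattern): keys missing from weight_dict are ignored."""
    return sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)


POINT_LOSS_COEF = 2.0  # placeholder for args.point_loss_coef
weight_dict = {'loss_ce': 1, 'loss_points': POINT_LOSS_COEF}

# keys as returned by SetCriterion_Crowd: 'loss_ce' and 'loss_point' (singular)
loss_dict = {'loss_ce': 0.5, 'loss_point': 3.0}
print(weighted_total(loss_dict, weight_dict))  # 0.5 -> the point loss is not counted

# with matching key names the point loss contributes
loss_dict_fixed = {'loss_ce': 0.5, 'loss_points': 3.0}
print(weighted_total(loss_dict_fixed, weight_dict))  # 6.5
```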

4.2 Matcher

```python
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn


class HungarianMatcher_Crowd(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    # during training the weights are set to cost_class = 1, cost_point = 0.05
    def __init__(self, cost_class: float = 1, cost_point: float = 3):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the foreground object (classification) cost
            cost_point: This is the relative weight of the L2 distance between the points in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_point = cost_point
        assert cost_class != 0 or cost_point != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_points": Tensor of dim [batch_size, num_queries, 2] with the predicted point coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_points] (where num_target_points is the number of ground-truth
                           objects in the target) containing the class labels
                 "point": Tensor of dim [num_target_points, 2] containing the target point coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_points)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_points = outputs["pred_points"].flatten(0, 1)  # [batch_size * num_queries, 2]

        # Also concat the target labels and points
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_points = torch.cat([v["point"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L2 cost between points.
        # torch.cdist(x1, x2, p=2.0) computes the distance between every pair of points:
        #   x1: [B, P, M], x2: [B, R, M] -> output: [B, P, R] (the batch dim is optional);
        #   p=2 gives the Euclidean (L2) distance
        cost_point = torch.cdist(out_points, tgt_points, p=2)

        # Final cost matrix for the Hungarian matching,
        # e.g. C = 0.05 * cost_point + 1 * cost_class with the training settings
        C = self.cost_point * cost_point + self.cost_class * cost_class
        C = C.view(bs, num_queries, -1).cpu()  # [bs, num_queries, total number of targets]

        # number of ground-truth points in each image
        sizes = [len(v["point"]) for v in targets]
        # Hungarian algorithm: scipy.optimize.linear_sum_assignment(cost_matrix, maximize=False)
        # solves the linear sum assignment problem; C.split(sizes, -1) cuts the target columns
        # per image, and c[i] picks the block of image i
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        # return the indices of the matched predictions and of their ground-truth points
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


def build_matcher_crowd(args):
    return HungarianMatcher_Crowd(cost_class=args.set_cost_class, cost_point=args.set_cost_point)
```
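The matching step can be reproduced on a toy example with NumPy and SciPy. Build C = cost_point * ||p_pred - p_gt||_2 - cost_class * prob(person), then solve it with `scipy.optimize.linear_sum_assignment`. The coordinates and probabilities below are made up. The weights follow the training values quoted in the comment above (cost_class = 1, cost_point = 0.05).

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

cost_class, cost_point = 1.0, 0.05

# 4 predicted points with their "person" probabilities, 2 ground-truth points
pred_points = np.array([[10.0, 10.0], [50.0, 52.0], [12.0, 9.0], [90.0, 90.0]])
pred_prob = np.array([0.2, 0.9, 0.8, 0.6])
gt_points = np.array([[11.0, 10.0], [50.0, 50.0]])

# pairwise L2 distances, shape (num_pred, num_gt), like torch.cdist(..., p=2)
dist = np.linalg.norm(pred_points[:, None, :] - gt_points[None, :, :], axis=-1)
# classification cost: negative probability of the target class (same for every gt column)
C = cost_point * dist - cost_class * pred_prob[:, None]

pred_idx, gt_idx = linear_sum_assignment(C)
print(pred_idx, gt_idx)  # [1 2] [1 0]
```

Prediction 0 is the closest to ground truth 0 (1 px versus about 1.4 px), but prediction 2 wins the match because its confidence is much higher. The two unmatched predictions become background targets in the classification loss.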

4.3 Loss Function

```python
import torch
import torch.nn.functional as F
from torch import nn

# is_dist_avail_and_initialized and get_world_size are helpers from the repository's utilities


class SetCriterion_Crowd(nn.Module):

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        # values used in training:
        #   num_classes = 1
        #   weight_dict = {'loss_ce': 1, 'loss_points': args.point_loss_coef}
        #   eos_coef = 0.5
        #   losses = ['labels', 'points']
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        # per-class cross-entropy weights: [1, 1] -> [eos_coef, 1] = [0.5, 1],
        # i.e. the background (no-object) class 0 is down-weighted
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[0] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)

    def loss_labels(self, outputs, targets, indices, num_points):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # unmatched predictions get class 0 (background); matched ones get their label (1)
        target_classes = torch.full(src_logits.shape[:2], 0,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        return losses

    def loss_points(self, outputs, targets, indices, num_points):

        assert 'pred_points' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_points = outputs['pred_points'][idx]
        target_points = torch.cat([t['point'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.mse_loss(src_points, target_points, reduction='none')

        losses = {}
        # note: this key is 'loss_point', whereas weight_dict in build() uses 'loss_points'
        losses['loss_point'] = loss_bbox.sum() / num_points

        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_points, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'points': self.loss_points,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_points, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        output1 = {'pred_logits': outputs['pred_logits'], 'pred_points': outputs['pred_points']}
        # indices from the Hungarian matching
        indices1 = self.matcher(output1, targets)
        # number of ground-truth points in the batch, as a tensor
        num_points = sum(len(t["labels"]) for t in targets)
        num_points = torch.as_tensor([num_points], dtype=torch.float, device=next(iter(output1.values())).device)
        # distributed (multi-GPU) training: sum the count over all processes
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_points)
        # average per process, at least 1
        num_boxes = torch.clamp(num_points / get_world_size(), min=1).item()

        losses = {}
        # compute and collect every requested loss
        for loss in self.losses:
            losses.update(self.get_loss(loss, output1, targets, indices1, num_boxes))

        return losses
```
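With the default `reduction='mean'`, `F.cross_entropy(..., weight)` divides by the sum of the weights of the target classes, not by the number of predictions. The NumPy sketch below implements that formula for the two-class case, with `empty_weight = [eos_coef, 1] = [0.5, 1]`. The logits and labels are made up.

```python
import numpy as np


def weighted_cross_entropy(logits, targets, weight):
    """Mean-reduced weighted CE: sum(w[y] * -log p[y]) / sum(w[y])."""
    z = logits - logits.max(axis=1, keepdims=True)             # numerical stability
    log_prob = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    nll = -log_prob[np.arange(len(targets)), targets]           # -log p(target) per sample
    w = weight[targets]                                         # per-sample class weight
    return (w * nll).sum() / w.sum()


eos_coef = 0.5
empty_weight = np.array([eos_coef, 1.0])  # class 0 = background, class 1 = person

# 4 predictions: 1 matched to a person (label 1), 3 unmatched (label 0)
logits = np.array([[0.0, 2.0], [1.0, 0.0], [0.5, 0.5], [2.0, -1.0]])
targets = np.array([1, 0, 0, 0])
print(weighted_cross_entropy(logits, targets, empty_weight))
```

Most anchor points are unmatched background. Halving their weight keeps the many easy background terms from drowning out the few matched person terms.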

5. Inference

```python
import argparse
import os
import time

import cv2
import numpy as np
import torch
import torchvision.transforms as standard_transforms
from PIL import Image

from models import build_model  # module of the P2PNet repository


def get_args_parser():
    parser = argparse.ArgumentParser('Set parameters for P2PNet evaluation', add_help=False)

    # * Backbone
    parser.add_argument('--backbone', default='vgg16_bn', type=str,
                        help="name of the convolutional backbone to use")

    parser.add_argument('--row', default=2, type=int,
                        help="row number of anchor points")
    parser.add_argument('--line', default=2, type=int,
                        help="line number of anchor points")

    parser.add_argument('--output_dir', default='output/',
                        help='path where to save')
    parser.add_argument('--weight_path', default='weights/best_mae.pth',
                        help='path where the trained weights saved')

    parser.add_argument('--gpu_id', default=-1, type=int, help='the gpu used for evaluation')

    return parser


def main(args, debug=False):
    # gpu_id = -1 hides all GPUs, so inference runs on the CPU
    os.environ["CUDA_VISIBLE_DEVICES"] = '{}'.format(args.gpu_id)

    print(args)
    device = torch.device('cpu')
    # get the P2PNet
    model = build_model(args)
    # move to the device (the CPU here)
    model.to(device)
    # load the trained model
    if args.weight_path is not None:
        checkpoint = torch.load(args.weight_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    # convert to eval mode
    model.eval()
    # create the pre-processing transform
    transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # set your image folder here
    img_path = "vis/test1/"
    imgsfile = os.listdir(img_path)
    os.makedirs(args.output_dir + img_path, exist_ok=True)
    for imgname in imgsfile:
        start = time.time()
        # load the image
        img_raw = Image.open(img_path + imgname).convert('RGB')
        # (commented-out experiments in the original: halve or quarter very large images,
        #  or resize everything to 1024 x 1024)

        # round the size: resize width and height down to multiples of 128
        width, height = img_raw.size
        new_width = width // 128 * 128
        new_height = height // 128 * 128
        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter
        img_raw = img_raw.resize((new_width, new_height), Image.LANCZOS)
        # pre-processing: to tensor + normalization
        img = transform(img_raw)

        samples = torch.Tensor(img).unsqueeze(0)
        samples = samples.to(device)
        # run inference (no gradients needed)
        with torch.no_grad():
            outputs = model(samples)
        outputs_scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
        outputs_points = outputs['pred_points'][0]

        threshold = 0.5  # confidence threshold
        # filter the predictions
        points = outputs_points[outputs_scores > threshold].detach().cpu().numpy().tolist()
        predict_cnt = int((outputs_scores > threshold).sum())  # predicted count

        # draw the predictions
        size = 2
        img_to_draw = cv2.cvtColor(np.array(img_raw), cv2.COLOR_RGB2BGR)
        for p in points:
            img_to_draw = cv2.circle(img_to_draw, (int(p[0]), int(p[1])), size, (0, 0, 255), -1)
        # save the visualized image
        cv2.imwrite(os.path.join(args.output_dir, img_path + imgname + '_pred{}.jpg'.format(predict_cnt)), img_to_draw)
        end = time.time()
        print(imgname + " inference time:", end - start)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('P2PNet evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    main(args)
```
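The inference script floors both sides to a multiple of 128 before running the model. The count is the number of anchor points whose person score exceeds the threshold. Below is a small sketch of both steps. The helpers `round_down_128` and `count_above` are hypothetical. Note that any side shorter than 128 px becomes 0, so such images would need to be padded or upscaled first.

```python
def round_down_128(size):
    """Floor a (width, height) pair to multiples of 128, as in main() (hypothetical helper)."""
    width, height = size
    return width // 128 * 128, height // 128 * 128


def count_above(scores, threshold=0.5):
    """Predicted count: anchor points whose person score is strictly above the threshold."""
    return sum(1 for s in scores if s > threshold)


print(round_down_128((1024, 768)))         # (1024, 768)
print(round_down_128((1000, 700)))         # (896, 640)
print(round_down_128((100, 300)))          # (0, 256) -> too small to use as-is
print(count_above([0.9, 0.4, 0.51, 0.5]))  # 2
```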
