最近用p2pnet进行了复现,还对自己做的数据集进行了训练。在复现的过程中的一些注意事项在此做了记录。
代码地址:https://github.com/TencentYoutuResearch/CrowdCounting-P2PNet
论文地址:https://arxiv.org/abs/2107.12746

P2PNET网络图

测试:
在用p2pnet 测试自己的图像数据时,需要先下载vgg的训练好的模型,因为p2pnet是由vgg提取出的特征图上进行预测,计数的。原程序中没有给vgg的模型需要自己下载。
vgg_.py

"""
import torch
import torch.nn as nn__all__ = ['VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn','vgg19_bn', 'vgg19',
]model_urls = {'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth','vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth','vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth','vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth','vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth','vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth','vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth','vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}model_paths = {'vgg16_bn': 'weights/vgg16_bn-6c64b313.pth','vgg16': '/apdcephfs/private_changanwang/checkpoints/vgg16-397923af.pth',}

这里我们可以看到在model_paths中给了vgg16_bn和vgg16两种模型,我们只需要下载其中一种模型并把路径填到里面就行了。
run_test.py

def get_args_parser():parser = argparse.ArgumentParser('Set parameters for P2PNet evaluation', add_help=False)# * Backboneparser.add_argument('--backbone', default='vgg16_bn', type=str,help="name of the convolutional backbone to use")parser.add_argument('--row', default=2, type=int,help="row number of anchor points")parser.add_argument('--line', default=2, type=int,help="line number of anchor points")parser.add_argument('--output_dir', default='vis',#添加输出位置定位的路径文件夹help='path where to save')parser.add_argument('--weight_path', default='weights/SHTechA.pth',#添加p2pnet的训练好的模型help='path where the trained weights saved')parser.add_argument('--gpu_id', default=0, type=int, help='the gpu used for evaluation')return parserdef main(args, debug=False):os.environ["CUDA_VISIBLE_DEVICES"] = '{}'.format(args.gpu_id)print(args)device = torch.device('cuda')# get the P2PNetmodel = build_model(args)# move to GPUmodel.to(device)# load trained modelif args.weight_path is not None:checkpoint = torch.load(args.weight_path, map_location='cpu')model.load_state_dict(checkpoint['model'])# convert to eval modemodel.eval()# create the pre-processing transformtransform = standard_transforms.Compose([standard_transforms.ToTensor(), standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])# set your image path hereimg_path = "./vis/demo1.jpg"#测试数据图片# load the imagesimg_raw = Image.open(img_path).convert('RGB')# round the sizewidth, height = img_raw.sizenew_width = width // 128 * 128new_height = height // 128 * 128img_raw = img_raw.resize((new_width, new_height), Image.ANTIALIAS)# pre-proccessingimg = transform(img_raw)

这里只需要添加自己的输出路径、权重路径、和测试数据的路径就行了。
注意如果出现下面问题

这是版本不统一的问题,只需要将misc.py里的if语句注释掉就行了,如下图

训练:
如果要训练自己的数据集,那需要注意,这个程序进行训练需要的是数据图像和标注位置的txt文件,注意是txt文件,不是公共数据集的mat或csv文件。

import os
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import cv2
import glob
import scipy.io as ioclass SHHA(Dataset):def __init__(self, data_root, transform=None, train=False, patch=False, flip=False):self.root_path = data_rootself.train_lists = "shanghai_tech_part_a_train.list"#训练集路径self.eval_list = "shanghai_tech_part_a_test.list"#测试集路径# there may exist multiple list filesself.img_list_file = self.train_lists.split(',')if train:self.img_list_file = self.train_lists.split(',')else:self.img_list_file = self.eval_list.split(',')self.img_map = {}

这上面的训练集和测试集的文件路径,其实就将训练图片和其标注文件的路径放在一个txt文件中,最后把.txt后缀名改成.list。这个txt文件里的内容是:
图片1路径+空格+ 其对应的是标注文件txt文件的路径
图片2路径+空格+ 其对应的是标注文件txt文件的路径

然后改变train.py中的路径

# dataset parametersparser.add_argument('--dataset_file', default='SHHA')parser.add_argument('--data_root', default='./new_public_density_data',#你放这个程序文件夹的位置help='path where the dataset is')parser.add_argument('--output_dir', default='./log',#需要自己建立文件夹help='path where to save, empty for no saving')parser.add_argument('--checkpoints_dir', default='./ckpt',#需要自己建立文件夹,用来存放训练后的权重文件help='path where to save checkpoints, empty for no saving')parser.add_argument('--tensorboard_dir', default='./runs',help='path where to save, empty for no saving')parser.add_argument('--seed', default=42, type=int)parser.add_argument('--resume', default='', help='resume from checkpoint')parser.add_argument('--start_epoch', default=0, type=int, metavar='N',help='start epoch')parser.add_argument('--eval', action='store_true')parser.add_argument('--num_workers', default=8, type=int)parser.add_argument('--eval_freq', default=5, type=int,help='frequency of evaluation, default setting is evaluating in every 5 epoch')parser.add_argument('--gpu_id', default=0, type=int, help='the gpu used for training')return parserdef main(args):os.environ["CUDA_VISIBLE_DEVICES"] = '{}'.format(args.gpu_id)# create the logging filerun_log_name = os.path.join(args.output_dir, 'run_log.txt')#这个run_log.txt需要自己建立with open(run_log_name, "w") as log_file:log_file.write('Eval Log %s\n' % time.strftime("%c"))

改变上面的程序文件夹存放的路径,添加训练好的模型存放的文件夹路径,和最后在自己建立的output_dir文件夹中建立run_log.txt文件,就可以进行训练了。
以上就是我们在复现的过程中需要调整的信息了。
如果有哪些小伙伴出现其他问题欢迎在评论区留言。

Crowd Counting P2PNet 复现相关推荐

  1. [CAN] [CVPR2019]:Context-Aware Crowd Counting论文+代码解读

    1.论文 论文链接:https://arxiv.org/pdf/1811.10452.pdf 代码链接:GitHub - weizheliu/Context-Aware-Crowd-Counting: ...

  2. 人群密度估计--Structured Inhomogeneous Density Map Learning for Crowd Counting

    Structured Inhomogeneous Density Map Learning for Crowd Counting https://arxiv.org/abs/1801.06642 针对 ...

  3. 人群密度估计--Leveraging Unlabeled Data for Crowd Counting by Learning to Rank

    Leveraging Unlabeled Data for Crowd Counting by Learning to Rank CVPR2018 https://github.com/xialeil ...

  4. 人群密度估计--Crowd Counting Via Scale-adaptive Convolutional Nerual Network

    Crowd Counting Via Scale-adaptive Convolutional Nerual Network https://arxiv.org/abs/1711.04433v2 Co ...

  5. 人群密度估计--Learning a perspective-embedded deconvolution network for crowd counting

    Learning a perspective-embedded deconvolution network for crowd counting 没有找到代码 本文在人群密度估计这个问题上的创新点: ...

  6. 越线人群计数--Crossing-line Crowd Counting with Two-phase Deep Neural Networks

    Crossing-line Crowd Counting with Two-phase Deep Neural Networks ECCV2016 人群计数有两种做法:1) region-of-int ...

  7. 快速人群密度估计--Multi-scale Convolutional Neural Networks for Crowd Counting

    Multi-scale Convolutional Neural Networks for Crowd Counting https://arxiv.org/abs/1702.02359 对于人群密度 ...

  8. 人群密度估计--Fully Convolutional Crowd Counting On Highly Congested Scenes

    Fully Convolutional Crowd Counting On Highly Congested Scenes The 12th International Conference on C ...

  9. 人群密度估计--CrowdNet: A Deep Convolutional Network for Dense Crowd Counting

    CrowdNet: A Deep Convolutional Network for Dense Crowd Counting published in the proceedings of ACM ...

  10. 人群密度估计--Spatiotemporal Modeling for Crowd Counting in Videos

    Spatiotemporal Modeling for Crowd Counting in Videos ICCV2017 针对视频人群密度估计问题,这里主要侧重视频中的 temporal infor ...

最新文章

  1. 腾讯微视AI新技术曝光:斩获VCR榜单第一
  2. python下载库报错_下载python中Crypto库报错:ModuleNotFoundError: No module named ‘Crypto’的解决...
  3. 如何解决分布式系统中的“幽灵复现”?
  4. P2714-四元组统计【数论,容斥】
  5. python3.7安装, 解决pip is configured with locations that require TLS/SSL问题
  6. usaco1.5.3(sprime)
  7. 钝化 会钝化 订单审批流程 码一会er
  8. 线上jvm 内存飙高排查
  9. 四足机器人足端轨迹规划--摆线
  10. python打印小星星案例详解_音乐案例 《小星星》
  11. linux 磁盘io技术3------libaio使用介绍
  12. 张艾迪(创始人):世界最高级创始人
  13. Linux:安装 telnet 命令
  14. 分享一个看起来挺酷眩的canvas做的粒子漩涡
  15. java通讯源码_GuQiu-JAVA做的局域网通讯源码
  16. php 栏目名称,PHPCMS V9调用栏目ID,栏目名称,父栏目,顶级父栏目
  17. 用Python爬虫技术怎么挣点小钱,这四种方法可行
  18. 用python做C语言的猜数字游戏,[Python3 练习] 007 简单的猜数字小游戏
  19. 【企业信息化】第5集 免费开源ERP: Odoo 16 inventory仓库管理系统 现代化线上仓库管理软件
  20. 【office考试】二级MS操作题试题解析-电子表格题

热门文章

  1. Docker 学习笔记(八)-- Dockerfile 构建CentOS 实战测试
  2. 嵌入式学习路径之单片机 | 月薪5个k到5个w的路径全在这了
  3. stm32采集脉冲信号_stm32用ETR采集外部脉冲个数出现二分频问题,请教哪里设置......
  4. cf950f Curfew
  5. winrar.msi_如何使WinRAR自动化以从setup.exe和MSI文件制作单个文件安装程序
  6. 算法导论第三章思考题
  7. 差动变压器的振动测量实验 思考题
  8. Linux第二章:6.Xftp安装教程、使用Xftp进行远程文件传输
  9. 【论文阅读】Advances and challenges in conversational recommender systems: A survey
  10. 电子护照阅读器|酒店机场高铁自助机录入系统