1. Environment

2. Documentation

DETR source code
DETR paper

3. Dataset

A custom dataset in COCO format
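DETR's COCO dataset builder (datasets/coco.py) reads from fixed locations under `--coco_path`: `train2017/`, `val2017/`, `annotations/instances_train2017.json` and `annotations/instances_val2017.json`. A custom dataset therefore has to be saved in that layout. The sketch below uses only the standard library. It checks that layout and reads the category ids from the training annotations. `check_coco_layout` is a hypothetical helper and is not part of the DETR repository.

```python
import json
from pathlib import Path

# Paths that DETR's COCO dataset builder reads, relative to --coco_path
EXPECTED = [
    "train2017",
    "val2017",
    "annotations/instances_train2017.json",
    "annotations/instances_val2017.json",
]


def check_coco_layout(coco_path):
    """Return (missing paths, sorted category ids) for a COCO-style dataset.

    Hypothetical helper for illustration; not part of the DETR repo.
    """
    root = Path(coco_path)
    missing = [p for p in EXPECTED if not (root / p).exists()]
    category_ids = []
    train_json = root / "annotations" / "instances_train2017.json"
    if train_json.exists():
        with open(train_json, encoding="utf-8") as f:
            data = json.load(f)
        category_ids = sorted(c["id"] for c in data.get("categories", []))
    return missing, category_ids
```

For example, `check_coco_layout("coco")` returns an empty `missing` list when the `coco` directory is complete. It also returns the category ids, which matter when you choose `num_classes` later.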

4. Model

Download the pretrained model from GitHub.

Baidu Netdisk link: https://pan.baidu.com/s/1fmOYAOZ4yYx_rYquOS6Ycw
Extraction code: 74l5

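5. Weight file

Generate the weight file for your own number of classes:

import torch

# Change this path to the downloaded pretrained model
pretrained_weights = torch.load('detr-r50.pth', map_location='cpu')
# Set your own number of classes (it must match num_classes in models/detr.py)
num_classes = 3
# class_embed outputs num_classes + 1 logits (your classes plus "no object");
# 256 is the transformer hidden dimension
pretrained_weights["model"]["class_embed.weight"].resize_(num_classes + 1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_classes + 1)
torch.save(pretrained_weights, "detr_r50_%d.pth" % num_classes)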
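`Tensor.resize_` keeps the leading elements of the existing storage. After the script above, the new head therefore consists of the first `num_classes + 1` rows of the 92-row COCO head. The last row, which was trained for the no-object class, is dropped. Because the head is fine-tuned anyway, this is often harmless. One alternative, sketched here as an assumption rather than the author's method, is to keep the original no-object row as the new last row. The sketch uses plain Python lists so the index logic is easy to check. `shrink_class_head` is a hypothetical name.

```python
def shrink_class_head(weight_rows, bias, num_classes):
    """Keep the first `num_classes` class rows plus the final no-object row.

    weight_rows: one row per output class, with the no-object class last
    bias: one value per output class, no-object last
    Hypothetical helper operating on plain lists for illustration.
    """
    if num_classes >= len(weight_rows):
        raise ValueError("num_classes must be smaller than the pretrained head size")
    new_w = weight_rows[:num_classes] + [weight_rows[-1]]
    new_b = bias[:num_classes] + [bias[-1]]
    return new_w, new_b


# Toy head: 5 classes + no-object (6 rows), reduced to 3 classes + no-object
w = [[float(i)] * 2 for i in range(6)]
b = [float(i) for i in range(6)]
new_w, new_b = shrink_class_head(w, b, 3)
print(new_b)  # [0.0, 1.0, 2.0, 5.0]
```

With tensors, the same selection can be written as `torch.cat([w[:num_classes], w[-1:]])` for the weight and likewise for the bias. The result has the same shape as the `resize_` version, so it loads into a model built with the same `num_classes`.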

6. Modify the code

Update the corresponding settings in main.py.

In models/detr.py, change the number of classes num_classes at the corresponding place.
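A comment in DETR's models/detr.py points out that `num_classes` actually means `max_obj_id + 1`, where max_obj_id is the largest category id in the dataset. COCO's largest id is 90, so DETR uses 91. The value therefore depends on whether your COCO converter numbers categories from 0 or from 1. The sketch below derives it from the `categories` list of a COCO-style annotation file. `detr_num_classes` is a hypothetical helper.

```python
def detr_num_classes(categories):
    """Return the value DETR expects for num_classes: max category id + 1.

    `categories` is the "categories" list from a COCO-style annotation JSON.
    Hypothetical helper for illustration.
    """
    if not categories:
        raise ValueError("no categories found")
    return max(c["id"] for c in categories) + 1


# Ids starting at 0: num_classes equals the number of classes
print(detr_num_classes([{"id": 0, "name": "green"},
                        {"id": 1, "name": "purple"},
                        {"id": 2, "name": "yellow"}]))  # 3
# Ids starting at 1: one extra slot is needed
print(detr_num_classes([{"id": 1}, {"id": 2}, {"id": 3}]))  # 4
```

If the ids start at 1, one option is to use the larger value both in models/detr.py and in the weight script. In that case, also put a placeholder name at index 0 of the class-name list used at test time.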

7. Train the model

python main.py
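Running `python main.py` with no arguments uses the defaults set inside main.py. The same settings can also be passed as flags that main.py defines, such as `--dataset_file`, `--coco_path`, `--resume`, `--output_dir`, `--epochs` and `--batch_size`. The sketch below assembles such a command. The paths (`coco`, `detr_r50_3.pth`, `output`) are placeholders that match the earlier steps, and `build_train_cmd` / `run_training` are hypothetical helpers.

```python
import shlex
import subprocess
import sys


def build_train_cmd(coco_path, resume, output_dir, epochs=300, batch_size=2):
    """Assemble a DETR training command from flags defined in main.py."""
    return [
        sys.executable, "main.py",
        "--dataset_file", "coco",
        "--coco_path", coco_path,      # directory with train2017/, val2017/, annotations/
        "--resume", resume,            # the detr_r50_<num_classes>.pth from step 5
        "--output_dir", output_dir,    # checkpoints and logs are written here
        "--epochs", str(epochs),
        "--batch_size", str(batch_size),
    ]


def run_training(cmd, cwd="."):
    """Start training; cwd should be the root of the DETR repository."""
    subprocess.run(cmd, check=True, cwd=cwd)


cmd = build_train_cmd("coco", "detr_r50_3.pth", "output")
print(shlex.join(cmd))
```

Printing the command first makes it easy to paste into a terminal. Call `run_training(cmd, cwd=...)` only after checking that the paths are correct.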

8. Test the model

import argparse
import os
import random
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision.ops.boxes import batched_nms

from models import build_model


def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")
    # * Backbone
    # If you set this to resnet101, also change the checkpoint path (--resume) below
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')

    # * Segmentation
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")

    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    # * Matcher
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float,
                        help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")
    # * Loss coefficients
    parser.add_argument('--mask_loss_coef', default=1, type=float)
    parser.add_argument('--dice_loss_coef', default=1, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--eos_coef', default=0.1, type=float,
                        help="Relative classification weight of the no-object class")

    # dataset parameters
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', type=str, default="coco")
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')

    # Change to the directory containing the images to run detection on
    parser.add_argument('--source_dir', default='/root/autodl-tmp/Deformable-DETR-main/data/data-labelme/test',
                        help='directory of input images')
    # Change to the directory where detection results are saved
    parser.add_argument('--output_dir', default='result/',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cpu',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)

    # Change to the checkpoint trained with the resnet50 backbone
    parser.add_argument('--resume', default='output/checkpoint0299.pth',
                        help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=2, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser


def box_cxcywh_to_xyxy(x):
    # (center_x, center_y, w, h) -> (x0, y0, x1, y1)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    # DETR predicts boxes normalized to [0, 1]; scale them to pixel coordinates
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


def filter_boxes(scores, boxes, confidence=0.7, apply_nms=True, iou=0.5):
    # Keep queries whose best class probability exceeds `confidence`
    keep = scores.max(-1).values > confidence
    scores, boxes = scores[keep], boxes[keep]
    if apply_nms:
        # Class-aware NMS: boxes with different labels never suppress each other
        top_scores, labels = scores.max(-1)
        keep = batched_nms(boxes, top_scores, labels, iou)
        scores, boxes = scores[keep], boxes[keep]
    return scores, boxes


# Custom classes: CLASSES[i] must be the name of category id i used in training
CLASSES = ['green', 'purple', 'yellow']


def plot_one_box(x, img, color=None, label=None, line_thickness=1):
    # Draw one xyxy box (and an optional label) onto img in place
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [255, 255, 255],
                    thickness=tf, lineType=cv2.LINE_AA)


def main(args):
    print(args)
    device = torch.device(args.device)

    model, criterion, postprocessors = build_model(args)
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)
    model.to(device)
    model.eval()  # disable dropout for inference
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("parameters:", n_parameters)

    # Same input normalization as DETR's training transforms (ImageNet mean/std)
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    for image_item in os.listdir(args.source_dir):
        print("inference_image:", image_item)
        image_path = os.path.join(args.source_dir, image_item)
        image = Image.open(image_path).convert('RGB')  # drop alpha / grayscale channels
        image_tensor = transform(image).unsqueeze(0).to(device)  # add batch dimension

        time1 = time.time()
        with torch.no_grad():
            inference_result = model(image_tensor)
        time2 = time.time()
        print("inference_time:", time2 - time1)

        # Softmax over classes, then drop the last column (the no-object class)
        probas = inference_result['pred_logits'].softmax(-1)[0, :, :-1].cpu()
        bboxes_scaled = rescale_bboxes(inference_result['pred_boxes'][0].cpu(),
                                       (image_tensor.shape[3], image_tensor.shape[2]))
        scores, boxes = filter_boxes(probas, bboxes_scaled)
        scores = scores.numpy()
        boxes = boxes.numpy()

        image = np.array(image)
        for i in range(boxes.shape[0]):
            class_id = scores[i].argmax()
            label = CLASSES[class_id]
            confidence = scores[i].max()
            text = f"{label} {confidence:.3f}"
            print(text)
            plot_one_box(boxes[i], image, label=text)
        # cv2.imshow("images", cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        # cv2.waitKey()
        image = Image.fromarray(image)
        image.save(os.path.join(args.output_dir, image_item))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
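The post-processing in `filter_boxes` and `rescale_bboxes` can be traced step by step without PyTorch. The sketch below re-implements it on plain lists for a single image:

1. Apply a softmax to each query's logits.
2. Drop the last (no-object) column.
3. Keep queries whose best score exceeds 0.7.
4. Convert normalized `(cx, cy, w, h)` boxes to pixel `(x0, y0, x1, y1)` boxes.
5. Run greedy class-aware NMS, with the same suppression rule `torchvision.ops.batched_nms` uses: a box is discarded when its IoU with a higher-scoring box of the same label exceeds the threshold.

It is an illustration, not a drop-in replacement, and all function names are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cxcywh_to_xyxy(box, img_w, img_h):
    """Normalized (cx, cy, w, h) -> pixel (x0, y0, x1, y1)."""
    cx, cy, w, h = box
    return ((cx - 0.5 * w) * img_w, (cy - 0.5 * h) * img_h,
            (cx + 0.5 * w) * img_w, (cy + 0.5 * h) * img_h)


def iou(a, b):
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def postprocess(pred_logits, pred_boxes, img_w, img_h, confidence=0.7, iou_thr=0.5):
    """Return a list of (label, score, xyxy box) for one image."""
    candidates = []
    for logits, box in zip(pred_logits, pred_boxes):
        probs = softmax(logits)[:-1]  # drop the no-object class
        score = max(probs)
        if score > confidence:
            candidates.append((probs.index(score), score, cxcywh_to_xyxy(box, img_w, img_h)))
    # Greedy NMS, highest score first; only same-label boxes suppress each other
    candidates.sort(key=lambda d: d[1], reverse=True)
    kept = []
    for det in candidates:
        if all(k[0] != det[0] or iou(k[2], det[2]) <= iou_thr for k in kept):
            kept.append(det)
    return kept


# Toy input: 4 queries, 2 classes + no-object, on a 100x100 image
dets = postprocess(
    pred_logits=[[5.0, 0.0, 0.0], [4.0, 0.0, 0.0], [0.0, 0.0, 5.0], [0.0, 4.0, 0.0]],
    pred_boxes=[(0.5, 0.5, 0.2, 0.2), (0.51, 0.5, 0.2, 0.2),
                (0.3, 0.3, 0.1, 0.1), (0.5, 0.5, 0.2, 0.2)],
    img_w=100, img_h=100,
)
for label, score, box in dets:
    print(label, round(score, 3), [round(v, 1) for v in box])
```

In the toy input, the third query is mostly "no object" and is filtered out. The second query overlaps the first with the same label and is suppressed. The fourth query shares the first query's box but has a different label, so it survives NMS.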

9. Results

In testing, DETR does not detect small objects very well; Deformable DETR performs better by comparison.
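For a per-size number, one option is to run `python main.py --eval --resume <checkpoint>`. DETR's COCO evaluator then prints AP and AR broken down into small, medium and large objects. COCO treats an object as small when its area is below 32² pixels, medium up to 96², and large above that. For a quick informal look at your own images, the sketch below buckets `(x0, y0, x1, y1)` pixel boxes, such as those returned by `rescale_bboxes`, into the same ranges. Comparing ground-truth counts with detection counts shows roughly where objects are missed. It approximates area by box area, whereas the official evaluation uses the annotation's `area` field. The helpers are hypothetical.

```python
# COCO area thresholds in pixels^2 for the small / medium / large split
SMALL_MAX = 32 ** 2
MEDIUM_MAX = 96 ** 2


def size_bucket(box):
    """Classify an (x0, y0, x1, y1) pixel box by its area, COCO-style."""
    area = max(0.0, box[2] - box[0]) * max(0.0, box[3] - box[1])
    if area < SMALL_MAX:
        return "small"
    if area < MEDIUM_MAX:
        return "medium"
    return "large"


def count_by_size(boxes):
    counts = {"small": 0, "medium": 0, "large": 0}
    for box in boxes:
        counts[size_bucket(box)] += 1
    return counts


# Toy example: the small ground-truth box has no matching detection
gt = [(0, 0, 20, 20), (10, 10, 60, 60), (0, 0, 200, 150)]
pred = [(10, 12, 61, 59), (0, 0, 198, 152)]
print("ground truth:", count_by_size(gt))
print("detections:  ", count_by_size(pred))
```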
