摘要:本案例代码是FCOS论文复现的体验案例,此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程

本文分享自华为云社区《通用物体检测算法 FCOS(目标检测/Pytorch)》,作者: HWCloudAI 。

FCOS:Fully Convolutional One-Stage Object Detection

本案例代码是FCOS论文复现的体验案例

此模型为FCOS论文中所提出算法在ModelArts + PyTorch框架下的实现。该算法使用MS-COCO公共数据集进行训练和评估。本代码支持FCOS + ResNet-101在MS-COCO数据集上完整的训练和测试流程

具体的算法介绍:AI Gallery_算法_模型_云市场-华为云

注意事项:

1.本案例使用框架: PyTorch1.0.0

2.本案例使用硬件: GPU

3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码

1.数据和代码下载

import os
import moxing as mox
# 数据代码下载
mox.file.copy_parallel('obs://obs-aigallery-zc/algorithm/FCOS.zip','FCOS.zip')
# 解压缩
os.system('unzip  FCOS.zip -d ./')

2.模型训练

2.1依赖库安装及加载

"""
Basic training script for PyTorch
"""
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (no not reorder)
import os
import argparse
import torch
import shutil
src_dir = './FCOS/'
os.chdir(src_dir)
os.system('pip install -r ./pip-requirements.txt')
os.system('python -m pip install ./trained_model/model/framework-2.0-cp36-cp36m-linux_x86_64.whl')
os.system('python setup.py build develop')
from framework.utils.env import setup_environment
from framework.config import cfg
from framework.data import make_data_loader
from framework.solver import make_lr_scheduler
from framework.solver import make_optimizer
from framework.engine.inference import inference
from framework.engine.trainer import do_train
from framework.modeling.detector import build_detection_model
from framework.utils.checkpoint import DetectronCheckpointer
from framework.utils.collect_env import collect_env_info
from framework.utils.comm import synchronize, \get_rank, is_pytorch_1_1_0_or_later
from framework.utils.logger import setup_logger
from framework.utils.miscellaneous import mkdir

2.2训练函数

def train(cfg, local_rank, distributed, new_iteration=False):model = build_detection_model(cfg)device = torch.device(cfg.MODEL.DEVICE)model.to(device)if cfg.MODEL.USE_SYNCBN:assert is_pytorch_1_1_0_or_later(), \"SyncBatchNorm is only available in pytorch >= 1.1.0"model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)optimizer = make_optimizer(cfg, model)scheduler = make_lr_scheduler(cfg, optimizer)if distributed:model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank,# this should be removed if we update BatchNorm statsbroadcast_buffers=False,)arguments = {}arguments["iteration"] = 0output_dir = cfg.OUTPUT_DIRsave_to_disk = get_rank() == 0checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler, output_dir, save_to_disk)print(cfg.MODEL.WEIGHT)extra_checkpoint_data = checkpointer.load_from_file(cfg.MODEL.WEIGHT)print(extra_checkpoint_data)arguments.update(extra_checkpoint_data)if new_iteration:arguments["iteration"] = 0data_loader = make_data_loader(cfg,is_train=True,is_distributed=distributed,start_iter=arguments["iteration"],)do_train(model,data_loader,optimizer,scheduler,checkpointer,device,arguments,)return model

2.3设置参数,开始训练

def main():parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")parser.add_argument('--train_url',default='./outputs',type=str,help='the path to save training outputs')parser.add_argument("--config-file",default="./trained_model/model/fcos_resnet_101_fpn_2x.yaml",metavar="FILE",help="path to config file",type=str,)parser.add_argument("--local_rank", type=int, default=0)parser.add_argument('--train_iterations', default=0, type=int)parser.add_argument('--warmup_iterations', default=500, type=int)parser.add_argument('--train_batch_size', default=8, type=int)parser.add_argument('--solver_lr', default=0.01, type=float)parser.add_argument('--decay_steps', default='120000,160000', type=str)parser.add_argument('--new_iteration',default=False, action='store_true')args, unknown = parser.parse_known_args()cfg.merge_from_file(args.config_file)# load the model trained on MS-COCOif args.train_iterations > 0:cfg.SOLVER.MAX_ITER = args.train_iterationsif args.warmup_iterations > 0:cfg.SOLVER.WARMUP_ITERS = args.warmup_iterationsif args.train_batch_size > 0:cfg.SOLVER.IMS_PER_BATCH = args.train_batch_sizeif args.solver_lr > 0:cfg.SOLVER.BASE_LR = args.solver_lrif len(args.decay_steps) > 0:steps = args.decay_steps.replace(' ', ',')steps = steps.replace(';', ',')steps = steps.replace(';', ',')steps = steps.replace(',', ',')steps = steps.split(',')steps = tuple([int(x) for x in steps])cfg.SOLVER.STEPS = stepscfg.freeze()num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1args.distributed = num_gpus > 1if args.distributed:torch.cuda.set_device(args.local_rank)torch.distributed.init_process_group(backend="nccl", init_method="env://")synchronize()output_dir = args.train_urlif output_dir:mkdir(output_dir)logger = setup_logger("framework", output_dir, get_rank())logger.info("Using {} GPUs".format(num_gpus))logger.info(args)logger.info("Loaded configuration file {}".format(args.config_file))train(cfg, args.local_rank, args.distributed, args.new_iteration)
if __name__ == "__main__":main()

3.模型测试

3.1预测函数

from framework.engine.predictor import Predictor
from PIL import Image,ImageDraw
import numpy as np
import matplotlib.pyplot as plt
def predict(img_path,model_path): config_file = "./trained_model/model/fcos_resnet_101_fpn_2x.yaml"cfg.merge_from_file(config_file)cfg.defrost()cfg.MODEL.WEIGHT = model_pathcfg.OUTPUT_DIR = Nonecfg.freeze()predictor = Predictor(cfg=cfg, min_image_size=800)src_img = Image.open(img_path)img = src_img.convert('RGB')img = np.array(img)img = img[:, :, ::-1]predictions = predictor.compute_prediction(img)top_predictions = predictor.select_top_predictions(predictions)bboxes = top_predictions.bbox.int().numpy().tolist()bboxes = [[x[1], x[0], x[3], x[2]] for x in bboxes]scores = top_predictions.get_field("scores").numpy().tolist()scores = [round(x, 4) for x in scores]labels = top_predictions.get_field("labels").numpy().tolist()labels = [predictor.CATEGORIES[x] for x in labels]draw = ImageDraw.Draw(src_img)for i,bbox in enumerate(bboxes):draw.text((bbox[1],bbox[0]),labels[i] + ':'+str(scores[i]),fill=(255,0,0))draw.rectangle([bbox[1],bbox[0],bbox[3],bbox[2]],fill=None,outline=(255,0,0))return src_img

3.2开始预测

if __name__ == "__main__":model_path = "./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth" # 训练得到的模型image_path = "./trained_model/model/demo_image.jpg" # 预测的图像img = predict(image_path,model_path)plt.figure(figsize=(10,10)) #设置窗口大小plt.imshow(img)plt.show()
2021-06-09 15:33:15,362 framework.utils.checkpoint INFO: Loading checkpoint from ./outputs/weights/fcos_resnet_101_fpn_2x/model_final.pth

点击关注,第一时间了解华为云新鲜技术~

FCOS论文复现:通用物体检测算法相关推荐

  1. 基于锚框与无需锚框的通用物体检测算法

    物体检测通常是指在图像中检测出物体出现的位置及对应的类别,是计算机视觉的根本问题,也是最基础的问题.它广泛应用于日常生活中,如浏览器的拍照识图.自动驾驶行人车辆检测.道路目标检测(人行道检测)及图像分 ...

  2. 中科院张士峰:基于深度学习的通用物体检测算法对比探索

    https://www.toutiao.com/a6674792954369933838/ 人工智能论坛如今浩如烟海,有硬货.有干货的讲座却百里挑一.由中国科学院大学主办,中国科学院大学学生会承办,读 ...

  3. 中科院自动化所博士带你入门CV物体检测算法

    物体检测通常是指在图像中检测出物体出现的位置及对应的类别,它是计算机视觉中的根本问题之一,同时也是最基础的问题,如图像分割.物体追踪.关键点检测等都依赖物体检测. 从应用来看,物体检测已广泛应用于大家 ...

  4. 详解通用物体检测算法:基于锚框与无需锚框

    物体检测通常是指在图像中检测出物体出现的位置及对应的类别,它是计算机视觉中的根本问题之一,同时也是最基础的问题,如图像分割.物体追踪.关键点检测等都依赖物体检测. 从应用来看,物体检测已广泛应用于大家 ...

  5. 论文阅读笔记 | 目标检测算法——FSAF算法

    如有错误,恳请指出 文章目录 1. Introduction 2. FSAF Module 2.1 Network Architecture 2.2 Ground-truth and Loss 2.2 ...

  6. CVPR 2023|UniDetector:7000类通用目标检测算法(港大清华)

    作者 | CV君  编辑 | 极市平台 点击下方卡片,关注"自动驾驶之心"公众号 ADAS巨卷干货,即可获取 点击进入→自动驾驶之心[目标检测]技术交流群 导读 论文中仅用了500 ...

  7. 激光雷达:最新趋势之基于RangeView的3D物体检测算法

    作者丨巫婆塔里的工程师@知乎 来源丨https://zhuanlan.zhihu.com/p/406674156 编辑丨3D视觉工坊 之前在LiDAR点云物体检测算法的综述中提到了四个发展阶段.在最开 ...

  8. ThunderNet:国防科大、旷视提出首个在ARM上实时运行的通用目标检测算法

    点击我爱计算机视觉标星,更快获取CVML新技术 今天跟大家分享一篇前天新出的论文<ThunderNet: Towards Real-time Generic Object Detection&g ...

  9. 论文阅读笔记 | 目标检测算法——SAPD算法

    如有错误,恳请指出. 文章目录 1. Introduction 2. Soft Anchor-Point Detector 2.1 Detection Formulation with Anchor ...

最新文章

  1. 中国AI开发者真实现状:写代码这条路,会走多久?
  2. Spring Boot 操作 Redis 的各种实现
  3. 大数据可以帮助企业获得资金吗?
  4. 第十次作业是同一个人
  5. 移动端目标识别(1)——使用TensorFlow Lite将tensorflow模型部署到移动端(ssd)之TensorFlow Lite简介...
  6. AES加密解密算法Java实现
  7. boost::system::generic_category相关的测试程序
  8. Kafka的优化建议
  9. SharpReader的效率:支持meme聚合
  10. python输入一个整数列表 列表元素为18_Python-18 (高级变量1--列表)
  11. 联想m100显示耗材_RTX3070显卡搭档高性能显示器,畅玩精美游戏大作!
  12. NSNumber的使用
  13. webpack-dev-server 设置反向代理解决跨域问题
  14. 【股价预测】基于matlab遗传算法优化BP神经网络预测股价【含Matlab源码 1250期】
  15. 记字符编码与转义符的纠缠
  16. _block 的使用 详细介绍
  17. 硬件基础之TTL、CMOS区分比较
  18. 随便说说,我回来啦~
  19. Php开发Dlp加密,DLP与文档透明加密 后防泄露时代之争
  20. 我该如何拯救你,我的考研?

热门文章

  1. [萌新必看]TomCat的WebAPP经常遇到的问题,诸如servelt404和SysTem.out.print无法在控制台输出等。
  2. 【GamePlay】两个ScrollView插件,Super ScrollView UIExtensions
  3. Distributed Deep Learning at the Edge-A Novel Proactive and Cooperative Caching Framework for Mobile
  4. 四:Jenkins日程表配置说明
  5. 业内人说:从中兴到联想,渠道平台商强势逻辑为何是硬伤?
  6. 计算机组成与体系结构——指令流水线
  7. 宇贸电商二期-用户模块开发(一)
  8. 微信小程序:webview头部状态栏被黑色区域填充问题
  9. Lattice原理及在通信中的应用 1 Lattice 基础
  10. 投资理财-如何判断公司价值