代码出处

吃水不忘打井人,分析github上的基于keras的实现:

xuannianz/keras-CenterNet​github.com

代码主体结构

模型训练的主函数流程如下所示,该流程也是使用keras的较为标准的流程。其中代码篇幅较大的是数据准备的部分,通常的代码也亦如此。下面按照不同的部分分别进行说明。

create_generators 数据集准备

该代码支持Pascal VOC格式、COCO格式以及CSV格式。keras中有三个函数可以用来进行模型的训练:分别是fit,fit_generator和train_on_batch。

fit(train_x, train_y, batchsize, epochs)

在使用fit进行模型训练时,通常假设整个训练集都可以放入RAM,并且没有数据增强(即不需要keras生成器)。常用于简单小型的数据集训练。

fit_generator;常常使用的模型训练函数

fit_generator适用于大数据集无法直接全部放入内存中,以及标注数据较少需要使用数据增强来增加训练模型的泛化能力。fit_generator需要传入一个数据生成器,数据生成器可以每次动态的生成一个batchsize的训练数据,通常我们也将数据增强放入数据生成器中,这样便可以动态的生成增强后的数据。在使用fit_generator时,需要传入steps_per_epoch的值,而fit函数则不需要,这是因为fit函数的steps_per_epoch默认等于总的训练数据/batchsize,而对于fit_generator来说,如果采用了数据增强,则可以产生无限的batchsize训练数据,因此需要指定该参数。

By the way,数据生成器可以使用keras的API或者直接自己手码python的代码,因为其本质上也就是python的函数。

train_on_batch(batchX, batchY)

train_on_batch用于需要对训练迭代进行精细控制,给其传入一批数据即可(数据大小任意),不需要提供batchsize的大小。通常很少使用该函数进行模型训练。

  • 本算法的实现过程就是采用的fit_generator进行的模型训练。因此需要为其构建数据生成器。common.py文件:class Generator(keras.utils.Sequence)构建数据生成器的基类,咱们先说道说道keras.utils.Sequence这个类。
keras.utils.Sequence:这个基类通常应用于数据集生成一个数据序列。使用时需构建一个python类继承自该
基类,并必须实现__len__和__getitem__两个函数,如果要在每个epoch间修改数据集则需要实现on_epoch_end
方法。
NOTE:特别注意,__getitem__要返回一个完整的batchsize数据,__len__统计的也是有多少个batch

Generator类可以当成一个抽象基类,其中主要实现的是batch的划分、数据增强的处理、以及标注数据的转换(将bounding box的标注形式转换成高斯分布的标注)。而真正使用的数据集的生成器如下所示。主要按照不同的数据集生成的类,并均都继承于Generator抽象类,这里区分不同的数据集主要为了能方便区分其不同的数据标注格式,使用起来更为方便。主要是load_annotations()和load_image()函数的实现。至此数据生成器便构建完成了。

class PascalVocGenerator(Generator)
class CocoGenerator(Generator)

centernet网络构建

算法实现采用的Resnet50作为网络的backbone,采用下述引用网络。网络构建这里相对就比较简单了,取出Resnet的C5,先添加了一层dropout,然后进行了上采样,然后分别构建网络head,主要有三支:中心点预测、中心点偏移值预测以及bouding box的size预测。

from keras.applications.resnet50 import ResNet50

最后构建model,使用keras的Lambda层构建loss,作为model的output

loss_ = Lambda(loss, name='centernet_loss')([y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input])
model = Model(inputs=[image_input, hm_input, wh_input, reg_input, reg_mask_input, index_input], outputs=[loss_])

预训练模型权重加载

keras的模型加载可以使用load_weights来实现,其模型加载可以按照模型结构加载,此时by_name需设置为False。否则将按照网络层的名字来加载,此时通常将skip_mismatch也设置成True,即仅加载名字相同的层,其他名字不同的层直接跳过。因此可以利用这个特性,对已训练好的网络局部进行修改,然后再加载之前训练好的模型,方便进行模型的调优。

model.load_weights(args.snapshot, by_name=True, skip_mismatch=True)

模型配置

其中loss参数的传递有几种形式。

  • 目标函数/损失函数的字符串,比如keras内置的一些损失函数
  • 目标函数/损失函数,通常为自定义的损失函数
  • 将目标函数/损失函数定义成model的一个层,类似本代码的实现。本代码实现时,因为直接把loss作为model的输出,因此输入y_true和y_pred,实际使用y_pred即输出loss,对其进行优化。
model.compile(optimizer=Adam(lr=1e-3), loss={'centernet_loss': lambda y_true, y_pred: y_pred})def compile(self, optimizer,loss=None,metrics=None,loss_weights=None,sample_weight_mode=None,weighted_metrics=None,target_tensors=None,**kwargs):"""Configures the model for training.# Argumentsoptimizer: String (name of optimizer) or optimizer instance.See [optimizers](/optimizers).loss: String (name of objective function) or objective function.See [losses](/losses).If the model has multiple outputs, you can use a different losson each output by passing a dictionary or a list of losses.The loss value that will be minimized by the modelwill then be the sum of all individual losses.metrics: List of metrics to be evaluated by the modelduring training and testing.Typically you will use `metrics=['accuracy']`.To specify different metrics for different outputs of amulti-output model, you could also pass a dictionary,such as `metrics={'output_a': 'accuracy'}`.loss_weights: Optional list or dictionary specifying scalarcoefficients (Python floats) to weight the loss contributionsof different model outputs.The loss value that will be minimized by the modelwill then be the *weighted sum* of all individual losses,weighted by the `loss_weights` coefficients.If a list, it is expected to have a 1:1 mappingto the model's outputs. If a dict, it is expected to mapoutput names (strings) to scalar coefficients.sample_weight_mode: If you need to do timestep-wisesample weighting (2D weights), set this to `"temporal"`.`None` defaults to sample-wise weights (1D).If the model has multiple outputs, you can use a different`sample_weight_mode` on each output by passing adictionary or a list of modes.weighted_metrics: List of metrics to be evaluated and weightedby sample_weight or class_weight during training and testing.target_tensors: By default, Keras will create placeholders for themodel's target, which will be fed with the target data duringtraining. If instead you would like to use your owntarget tensors (in turn, Keras will not expect externalNumpy data for these targets at training time), youcan specify them via the `target_tensors` argument. It can bea single tensor (for a single-output model), a list of tensors,or a dict mapping output names to target tensors.**kwargs: When using the Theano/CNTK backends, these argumentsare passed into `K.function`.When using the TensorFlow backend,these arguments are passed into `tf.Session.run`.

keras优化算法_目标检测算法 - CenterNet - 代码分析相关推荐

  1. 基于haar特征的adaboost算法_目标检测算法介绍

    什么是目标检测 目标检测是指从图像中找出目标,包括检测和识别两个过程,现实中由于环境的复杂性以及各类物体的形状.外观以及光照,遮挡等因素的干扰,所以目标检测一直也是计算机视觉最常见的挑战之一. 目标检 ...

  2. 路面裂痕检测YOLO算法、目标检测算法实现地面裂缝检测

    道路裂纹检测YOLO算法,目标检测,目标识别,裂纹检测 路面裂痕检测YOLO算法.目标检测算法实现地面裂缝检测 车头定位 交通标志识别 车道线识别 自己标注数据,训练模型,效果很好4360063193 ...

  3. python ssd目标检测_目标检测算法之SSD的数据增强策略

    前言 这篇文章是对前面<目标检测算法之SSD代码解析>,推文地址如下:点这里的补充.主要介绍SSD的数据增强策略,把这篇文章和代码解析的文章放在一起学最好不过啦.本节解析的仍然是上篇SSD ...

  4. 找不到匹配的key exchange算法_目标检测--匹配策略

    CVPR2020中的文章ATSS揭露到anchor-based和anchor-free的目标检测算法之间的效果差异原因是由于正负样本的选择造成的.而在目标检测算法中正负样本的选择是由gt与anchor ...

  5. ap 目标检测算法map_目标检测算法的评估指标:mAP定义及计算方式

    前面依次介绍了: 本节介绍目标检测算法的评估指标:mAP定义及计算方式 mAP:mean Average Precision,平均精度均值,即AP(Average Precision)的平均值,它是目 ...

  6. 2018目标检测最新算法+经典目标检测算法

    干货 CVPR2018的目标检测总结(论文+开源代码)https://blog.csdn.net/wfei101/article/details/80861681 目标检测算法集合(论文+开源代码)h ...

  7. 点在不规则图形内算法python_目标检测算法中规则矩形和不规则四边形IOU的Python实现...

    交并比(Intersection-over-Union,IoU),目标检测中使用的一个概念,我们在进行目标检测算法测试时,重要的指标,是产生的预测框(candidate bound)与标记框(grou ...

  8. 目标检测算法 2020_One-stage目标检测算法综述

    yolo-v1: YOLO 就是使用回归这种做法的典型算法. 首先将图片 Resize 到固定尺寸,然后通过一套卷积神经网络,最后接上 FC 直接输出结果,这就他们整个网络的基本结构. 更具体地做法, ...

  9. 目标检测算法的大体框架-------backbone、head、neck

        在基于深度学习算法的目标检测算法其实大体上都是由三部分组成的,即backbone.head.neck.整个算法的设计流程基本都是:输入->backbone->neck->he ...

最新文章

  1. R语言使用treemap包中的treemap函数可视化treemap图:treemap将分层数据显示为一组嵌套矩形、自定义设置treemap图的调色板、自定义设置treemap标题字体的大小
  2. 剑指offer:平衡二叉树
  3. active英语怎么读音_必须收藏!英语48个音标发音(附详细图解+视频教程)
  4. PIC模拟从入门到熟练系列之组会PPT20210906《Note of PIC》
  5. 使用final关键字修饰一个变量时,是引用不能变,还是引用的对象不能变?
  6. linux的基础知识——全局变量异步I/O
  7. 苏宁6亿会员是如何做到精确快速分析的?
  8. Tensorflow学习—— 预创建的 Estimator
  9. android 时间戳 转日期格式,在Android中转换为简单日期格式或Unix时间戳日期?
  10. 快速了解babel工作原理
  11. 文件系统[HDU-1413]
  12. OnlineDict:Chrome取词翻译扩展
  13. html添加哔哩哔哩视频,哔哩哔哩在线视频编辑器使用教程汇总
  14. Eviews创建散点图的具体方法
  15. 电视卡众说纷纭(二):2007年度市面常见电视卡软硬件性能
  16. 个人时间和任务管理工具GTD大盘点!你适合哪一款?
  17. 程序人生 - 车辆年检与费用你知道多少?
  18. 使用XMind编写测试用例
  19. python27安装get-pip
  20. 香农-范诺编码(Shannon–Fano Coding)

热门文章

  1. 灵活、高效、智慧,宁畅发布新品及“智定+”战略
  2. SRE 是如何保障稳定性的
  3. 深度揭秘:腾讯存储技术发展史
  4. 云漫圈 | 谈谈怎么做【服务隔离】
  5. 必须建筑师附体!像盖大楼那样打造数据即服务
  6. 福利 | 2018 OpenInfra Days China限量版免费票任性放出
  7. Service Mesh 在华为公有云的实践
  8. docker Redis集群
  9. ./mysqld: error while loading shared libraries: libaio.so.1: cannot open shared object file: No such
  10. RuoYi-Cloud 部署篇_04(windows环境 mysql+nginx版本)