一、结构分布

先介绍一下代码的结构分布吧

1、tain.py文件是训练的时候首先执行的文件,里面的函数有eval()评估函数,train()训练函数

2、trainer.py文件是网络的流图,关于如何forward,如何计算loss,如何反向计算,如何保存模型,如何控制权重更新等等,这个里面的函数会在train.py中的train()函数开始的时候调用,先构建fasterrcnn的网络,然后将网络作为参数传给trainer的构造函数

3、data文件夹,下面都是数据读取的方法

(1)dataset.py文件是批量加载数据,这个也会在train()函数开始初始化了一个dataset对象。

(2)voc_dataset.py文件是针对VOC数据集格式准备的,批量加载VOC数据集,解析XML文件都在这个类中,而且他是在dataset.py文件中调用,加载VOC数据集用的。

(3)util.py文件是定义了一些图像预处理工具,包括read_image,resize_box、crop_bbox这些会在其他py文件中调用,比如voc_dataset.py文件中调用了read_image数据进行读取数据。

(4)__init__.py文件好像是python要求类下必须有这个文件,需要确定一下???

4、model文件夹,下面都是网络构建的一些py。

(1)faster_rcnn_vgg16.py文件,是用来构建FasterRCNN-vgg16网络的,该网络分三部分创建,extractor特征提取网络,是利用torchvision.model模块创建的VGG16,然后是RPN网络创建,再就是ROIHeader网络创建,这个网络构建对象在train()函数开始的时候就创建了,用于先构建网络,然后再传入trainer。

(2)faster_rcnn.py文件,是一个base-class,faster-rcnn-vgg16类继承了这个类

(3)region_proposal_network.py文件,用于构建rpn模块,在faster_rcnn_vgg16.py文件中调用生成网络结构。

(4)roi_module.py文件,这个暂时没研究,在faster_rcnn_vgg16.py中调用了,初始化的时候ROIPooling。??

(5)utils文件夹,下面是一些工具,nms文件夹是非极大值抑制,其他的没仔细看,后期研究,主要是在faster_rcnn.py文件中调用了。

6、utils文件夹,这里面是工具

二、训练流程

1、首先调用train.py文件,输入相关参数进行训练。输入的控制台参数用**kwargs来表示,学习了python的控制台参数知道这是个接受字典形参数。https://www.cnblogs.com/zhangzhuozheng/p/8053045.html可参见这个地址有详细说明。

2、进行参数解析,利用了utils文件夹下的config.py文件进行了参数获取,这个文件中自定义一些默认参数,主要返回的是学习率学习策略,数据集地址等。

3、构造数据集对象,包括标签、图像名称列表。

4、根据batch_size,num_workers进行数据加载对象声明。数据是不是这个时候加载的还待定,感觉这个loader就像一个占位符,先占个坑,等运行网络的时候就开始读入了。具体DataLoader的用法需要查看pytorch

5、然后构建网络faster_rcnn_vgg16,构建方式见前面说的faster_rcnn_vgg16.py。

6、构建traner对象,将网络输入进行。

7、判断opt.load_path是否存在,这个load_path是在config.py中定义的,是model的地址,默认值是none,也就是如果在控制台输入没有指定--model这个参数,那么就没有了。如果有model即预训练模型,则调用trainer中的load函数加载预训练模型。

(1)trainer.load()函数解析,首先利用torch.load()函数加载模型,然后判断‘model’字符串是否存在,来判断是单纯加载参数还是加载带模型的参数(这是我个人理解的,具体要看pytorch的load_state_dict函数),最后判断参数是否修改,默认没改,最后判断优化器是否在加载的网络里,是的话加载预训练模型中的优化器。

8、可视化训练数据的label,调用trainer.vis.text函数,函数解析待会。

9、best_map参数干什么用的不知道待定,lr_是学习率获取。

10、下面就是循环训练啦,循环条件是epoch数,这个是opt超参数规定的。

(1)trainer.reset_meters()先重置界面上所有的数据,相当于一个epoch更新一次显示数据。

(2)开启一个for循环,枚举数据啦,从dataloader中按照batch-size循环读取数据,循环条件是把数据取完,tqdm模块是进度条模块具体可以百度。

(3)然后调用array_tool.py文件(在utils文件夹下)中的scalar()函数,传入参数是scale,这个参数是什么意思

(4)把数据传到cuda中,用来加速计算,返回转换后的cuda版本的数据,下面调用的都是cuda版的。

(5)利用trainer.train_step函数进行计算,前面介绍过这个函数,是用来更新一次权重的。

(6)图像进行归一化处理.

(7)然后是显示

(8)预测bboxes,label,这个predict函数是哪来的呢,首先是trainer,而trainer中调用的网络是fasterrcnnvgg16,这个网络继承 的是fasterrcnn的类,fasterrcnn类中有一个predict函数。

(9)下面就是一些可视化操作了,然后跳出了枚举数据的循环。这就是1个epoch完成了

(10)模型评估,利用测试集来做,前面已经加载了测试集,test_dataloader

(11)得到优化器中学习率的数值,并显示日志相关内容,包括lr,map,loss

(12)根据评测结果判断map是否是大于阈值best_map,如果是保存模型

(13)判断当前的epoch是否=9,如果是就加载最好的map和改变学习率

(14)判断如果epoch=13就跳出迭代循环???这个是这个实验里设计的具体原因不清楚。待定

FasterRCNN-pytorch的代码解析相关推荐

  1. Temporal Fusion Transformer (TFT) 各模块功能和代码解析(pytorch)

    Temporal Fusion Transformer (TFT) 各模块功能和代码解析(pytorch) 文章目录 Temporal Fusion Transformer (TFT) 各模块功能和代 ...

  2. pytorch代码解析:loss = y_hat - y.view(y_hat.size())

    pytorch代码解析:pytorch中loss = y_hat - y.view(y_hat.size()) import torchy_hat = torch.tensor([[-0.0044], ...

  3. Hugging Face实战(NLP实战/Transformer实战/预训练模型/分词器/模型微调/模型自动选择/PyTorch版本/代码逐行解析)下篇之模型训练

    模型训练的流程代码是不是特别特别多啊?有的童鞋看过Bert那个源码写的特别特别详细,参数贼多,运行一个模型百八十个参数的. Transformer对NLP的理解是一个大道至简的感觉,Hugging F ...

  4. 单目标跟踪算法:Siamese RPN论文解读和代码解析

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者:周威 | 来源:知乎 https://zhuanlan.zhihu.com/p/16198364 ...

  5. Faster-RCNN.pytorch的搭建、使用过程详解(适配PyTorch 1.0以上版本)

    Faster-RCNN.pytorch的搭建.使用过程详解 引言 faster-rcnn pytorch代码下载 faster-rcnn pytorch配置过程 faster-rcnn pytorch ...

  6. Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 1

    Positional Encodings in ViTs 近期各视觉Transformer中的位置编码方法总结及代码解析 最近CV领域的Vision Transformer将在NLP领域的Transo ...

  7. [GCN] 代码解析 of GitHub:Semi-supervised classification with graph convolutional networks

    本文解析的代码是论文Semi-Supervised Classification with Graph Convolutional Networks作者提供的实现代码. 原GitHub:Graph C ...

  8. 目标检测算法之常见评价指标的详细计算方法及代码解析

    前言 之前简单介绍过目标检测算法的一些评价标准,地址为目标检测算法之评价标准和常见数据集盘点.然而这篇文章仅仅只是从概念性的角度来阐述了常见的评价标准如Acc,Precision,Recall,AP等 ...

  9. YOLO系列 --- YOLOV7算法(二):YOLO V7算法detect.py代码解析

    YOLO系列 - YOLOV7算法(二):YOLO V7算法detect.py代码解析 parser = argparse.ArgumentParser()parser.add_argument('- ...

  10. YOLO-V5 算法和代码解析系列 —— 学习路线规划综述

    目录标题 为什么学习 YOLO-V5 ? 博客文章列表 面向对象 开源项目学习方法 预备知识 项目目录结构 为什么学习 YOLO-V5 ? 算法性能:与YOLO系列(V1,V2,V3,V4)相比,YO ...

最新文章

  1. Spring JDBC-混合框架的事务管理
  2. 神经网络 深度学习 专业术语解释(Step, Batch Size, Iteration,Epoch)
  3. java 取整_javascript 解决默认取整的坑(目前已知的最佳解决方案)
  4. BZOJ1146 [CTSC2008]网络管理Network 树链剖分 主席树 树状数组
  5. arcball原理 旋转视图 关键点总结 及代码
  6. oracle账户解锁28000,oracle 下载 账号密码ORA-28000账户被锁和解锁
  7. leetcode1045. 买下所有产品的客户(SQL)
  8. TCP如何保证可靠性
  9. NPM包管理器跟换国内镜像CNPM
  10. payload的使 常用xss_Sony某个深度子域上的XSS
  11. 谷歌修复又一枚遭在野利用的 Chrome 0day
  12. python通讯录的录入与测试_python实现手机通讯录搜索功能
  13. windows远程连接linux中mysql数据库
  14. 基于jsp+java+ssm的大学生缴费系统
  15. vue手机端打开高德地图app
  16. linux vi文件提示swp,如何解决非正常关闭vi编辑器时生成.swp文件问题
  17. 创业公司怎样才能有效的进行员工股权激励
  18. C#调用DLL的几种方法
  19. 高性能计算机部件有,高性能计算及高性能计算机-超级计算中心.ppt
  20. 【清华牛人】Stanford, Caltech双料博士

热门文章

  1. C语言结构体的存储空间分配
  2. VLSI数字信号处理系统——第二章迭代边界
  3. erp5开源制造业erp短信发送接收机制
  4. k8s还能这么玩?快速上手物联网应用的容器开发
  5. 2022年货节有什么数码家电推荐的,2022年货节数码家电购物清单
  6. Python自动化小技巧13——批量下载北交所上市公司年报
  7. 使用DeepEarth加载在线Google地图(卫星、街道)
  8. bilibili如何空降
  9. uniapp一键链接指定WiFi功能
  10. VB.NET C#枚举 描述 中文 ComponentModel.Description DescriptionAttribute