train.py

总体分为三部分(不算import 库)
初始的一些设定 + train函数 + main函数
源码地址:
https://github.com/ultralytics/yolov3

一 .import 相关


torch.distributed 分布式训练
torch.optim.lr_scheduler 学习率衰减

二.初始设定

1使用混合精度训练

2预训练权重路径

Python可跨平台使用,但win和linux 的路径分隔符是不同的(斜杠)。
这里的os.sep会根据你所处的平台,自动采用相应的分隔符号。
results.txt 记录loss

3获取本地的超参数文件并解析

使用glob.glob来获取文件
4是否使用focal loss


三.main函数:

1参数的设定(之后涉及会说明)
这里是程序的默认参数

关于parser.add_argument的action=‘store_true’ 的解释是如果命令行执行程序时 设定对应的参数 那么这个参数的布尔值会变为True。
以这个为例:
parser.add_argument(’–nosave’, action=‘store_true’, help=‘only save final checkpoint’)
如果执行
train.py nosave
那么
opt.nosave 布尔值为True 一般涉及if判断 即 if opt.nosave :
2

p1 预训练权重的选择 如果设定了opt.resume就使用上次保存的权重 即命名带last.pt。else使用opt.weights路径下的权重

p2 一些check操作
check_git_status()函数 check一些版本 和 进程相关的内容。

p3 原本opt.img_size 设定的是一个img最大值和最小值 这里再添加一个test中img的size

p4 程序运行device的设定
如果opt.device 设定为GPU就使用混合精度计算。
设定GPU一般用数字id表示 0表示在第一个GPU上执行程序

3

if not:如果为False 就使用之前预设好的超参数进行训练 train(hyp)是进行正常的训练
之前设定的hyp是超参数搜索之后的超参数值。
使用opt.evolve判断是否使用超参数搜索的参数来训练

else:
以# Evolve hyperparameters (optional)方式来训练
以下暂略

四.train normal:


初始的一些设定
accumulate是用于判断模型是否进行更新参数

这个32是指模型的32前向传播会进行32倍的下采样率
也就是指32下采样后的featmap 的一个像素点 表示原图 32像素点的特征信息。详见原理
判断imgsz_min 是否可以整除32,刚好整除就不会出现 featmap 大小32.5 取整 为32 导致的信息损失。
assert函数 是判断后面是否出现异常。


解析运行程序时所对应的data文件
我感觉如果是训练自己的数据集 这个80要修改成对应的类别数,不知道是否是这样的。

移除之前保存的图片

初始化模型


attempt_download 用来获取权重 weights 是之前的opt.weights

if 权重文件是以pt结尾 也就是使用pytorch state_dict方式保存的形式。
这个pt文件包含四个部分(有的部分可能为空,就是写入的时候没有写入)
ckpt[weights]模型权重参数,ckpt[optimizer]优化器的参数,
ckpt[training_results]是这个权重文件中之前保存的结果 如果有就把里面的值写入我们这个程序的result.txt文件
ckpt[epoch] 这个权重文件中训练的epoch
def ckpt 释放一些内存

如果模型是YOLOV3 源码相关的格式保存的就用这种方式加载,比如我跑的预训练权重文件就叫
YOLOV3.weights


这段函数的含义就是是否进行freeze_layers操作,字面意思就是冻结一些层的权重参数,不去训练这些层,而是训练一些特定的层。这里训练yolo层 和 yolo层之后的一层。其他的层冻结操作。


关于混合精度训练的设定

学习率


分布式训练设定

初始化Dataset和Dataloader 用于训练


初始化Testloader 用于test

模型的参数设定
nc模型的类别数
hyp超参数
GIoU
gr 是涉及到giou的权重系数
原本使用iou来判别回归损失的时候有自身的缺点。1如果两个框没有相交,根据定义,IoU=0,不能反映两者的距离大小(重合度)。同时因为loss=0,没有梯度回传,无法进行学习训练。2不能判别两个标定框的重合程度。

c是指A和B的最小封闭形状。
GIoU公式 是指IoU 减去 ( |c中不包含A和B的区域| :|c的区域|)
这个公式可以表示两个标定框的重合程度。在IoU为0的时候也不会对更新参数有不好的影响。

EMA指数加权平均操作
也叫滑动平均值


计算公式 β×之前时刻数值 + (1-β)×此时的数字
当β越大时,滑动平均得到的值越和v的历史值相关


训练
一次迭代为一个周期

这里是判断是否要更改权重 这里没有使用。

把迭代dataloader的操作 使用tqdm 来封装 迭代的时候会显示进度条
ni是总体上已迭代的batch数
归一化操作

n_burn 是之前设定的参数
用于是否进行x interp操作

opt.multi_scale 是上面求得的,判断是否对图片进行随机缩放

前向传播的到模型的预测值


计算loss

backward

accumulate 之前设定的值 判断是否更新参数

把mloss,mem 等信息set到进度条显示中

ni = 0 时 也就是第一个batch 会保存一张带有模型预测信息的图片
到此迭代完一个batch

迭代完整体的dataloader时,也就是是一个epoch:


记录最好的 results

保存最好的模型权重 并释放ckpt

opt.name = ’ ’ 为空
len(opt.name) = 0
if 0 相当于 if False 不执行

时间不够 有的地方不具体 之后补充 欢迎留言

涉及的论文:
https://arxiv.org/pdf/1812.01187.pdf
https://arxiv.org/pdf/1812.01187.pdf
网站:
论文相关博客:
https://zhuanlan.zhihu.com/p/51870052
https://blog.csdn.net/diligent_321/article/details/87885418
超参数搜索
https://www.cnblogs.com/pprp/p/12432549.html
round()函数
https://blog.csdn.net/lly1122334/article/details/80596026
math.fmod()
https://blog.csdn.net/sunline_wanghj/article/details/79490986
check out()
https://www.cnblogs.com/hubavyn/p/8467329.html
GIoU
https://zhuanlan.zhihu.com/p/57992040
https://zhuanlan.zhihu.com/p/94799295
EMA:
https://blog.csdn.net/mikelkl/article/details/85227053
https://www.jianshu.com/p/f99f982ad370
分布式训练:
https://blog.csdn.net/m0_38008956/article/details/86559432?utm_source=blogxgwz4

pytorch YoLOV3 源码解析 train.py相关推荐

  1. YOLOv3源码解析2-数据预处理Dataset()

    YOLOv3源码解析1-代码整体结构 YOLOv3源码解析2-数据预处理Dataset() YOLOv3源码解析3-网络结构YOLOV3() YOLOv3源码解析4-计算损失compute_loss( ...

  2. Attention is all you need pytorch实现 源码解析01 - 数据预处理、词表的构建

    我们今天开始分析著名的attention is all you need 论文的pytorch实现的源码解析. 由于项目很大,所以我们会分开几讲来进行讲解. 先上源码:https://github.c ...

  3. MTCNN-tensorflow源码解析-gen_landmark_aug_12.py;gen_imglist_pnet.py

    gen_landmark_aug_12.py生成用于PNet网络的训练数据(用于人脸特征点).此外对于RNet,ONet(用于人脸特征点)的训练数据生成与其类似,不再赘述. 主函数: if __nam ...

  4. MTCNN-tensorflow源码解析-gen_12net_data.py

    prepare_data/gen_12net_data.py 生成训练样本,用于训练 PNet. anno_file = "wider_face_train.txt" #存放wid ...

  5. 3D点云深度学习PointNet源码解析——pointnet_cls.py.py

    参考博客: #这个文件实现了网络的分类结构.输出为B*40,是每个样本对于每个类别的概率.网络结构在get_model()中定义,loss则在get_loss中定义 import tensorflow ...

  6. yolov3之pytorch源码解析_springmvc源码架构解析之view

    说在前面 前期回顾 sharding-jdbc源码解析 更新完毕 spring源码解析 更新完毕 spring-mvc源码解析 更新完毕 spring-tx源码解析 更新完毕 spring-boot源 ...

  7. [源码解析] PyTorch 分布式(2) ----- DataParallel(上)

    [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 文章目录 [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 0x00 摘要 ...

  8. [源码解析] PyTorch 流水线并行实现 (6)--并行计算

    [源码解析] PyTorch 流水线并行实现 (6)–并行计算 文章目录 [源码解析] PyTorch 流水线并行实现 (6)--并行计算 0x00 摘要 0x01 总体架构 1.1 使用 1.2 前 ...

  9. [源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎

    [源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎 文章目录 [源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎 0x00 摘要 0x01 前文回顾 1.1 ...

最新文章

  1. R包ComplexHeatmap绘制个性化热图
  2. 管道符和作业控制、shell变量、环境变量配置文件
  3. opencv 4快速入门_茶知识|茶道核心4元素,看懂你也可以快速入门茶道!先收藏...
  4. windows编程,消息函数中拦截消息的问题
  5. P3157-[CQOI2011]动态逆序对【CDQ分治,树状数组】
  6. 域名和服务器销售WHMCS和HTML5模板 – Hostlar
  7. 系统描述符类型,段描述符类型和段描述符表
  8. Bailian1182 POJ1182 食物链【并查集】
  9. 完全公平调度 c语言,使用完全公平调度程序(CFS)进行多任务处理
  10. python2和python3中的map()
  11. matlab 透镜设计,一种用于均匀照明的LED透镜设计方法
  12. python 基于金字塔的图像融合
  13. web前端工程师等级分布
  14. java 背单词系统_背单词系统
  15. Windows开机启动项/自启动项文件夹位置
  16. 数的“平方”速算秘诀,超实用的技巧,3秒出答案
  17. 利用pyboardCN V2播放Bad apple
  18. leetcode 868. 二进制间距
  19. 【Bootstrap】<前端框架>Bootstrap布局容器栅格网格系统
  20. 【历史上的今天】12 月 16 日:晶体管问世;IBM 停售 OS/2;科幻小说巨匠诞生

热门文章

  1. 5809: 傻瓜电梯(思维模拟)
  2. 计算机画大熊猫教案,幼儿园大班教案:水墨画《大熊猫》
  3. 新股发行基本流程及网下投资者参与过程介绍
  4. SpringBoot项目集成全文搜索引擎Elasticsearch
  5. openid和unionid的区别
  6. 自己总结的HTMLcss
  7. 一个简单的购物类网站
  8. 量化投资学习-8:庄家与散户是一场猫抓老鼠的游戏
  9. 树莓派安装Opencv+Contrib和 QT 详细教程
  10. 红蓝对抗中的钓鱼技术