目录

  • YOLOv8剪枝
    • 前言
    • 1.Overview
    • 2.Pretrain(option)
    • 3.Constrained Training
    • 4.Prune
      • 4.1 检查BN层的bias
      • 4.2 设置阈值和剪枝率
      • 4.3 最小剪枝Conv单元的TopConv
      • 4.4 最小剪枝Conv单元的BottomConv
      • 4.5 Seq剪枝
      • 4.6 Detect-FPN剪枝
      • 4.7 完整示例代码
    • 5.YOLOv8剪枝总结
    • 总结

YOLOv8剪枝

前言

手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。

本次课程主要讲解YOLOv8剪枝。

课程大纲可看下面的思维导图

1.Overview

YOLOV8剪枝的流程如下:

结论:在VOC2007上使用yolov8s模型进行的实验显示,预训练和约束训练在迭代50个epoch后达到了相同的mAP(:0.5)值,约为0.77。剪枝后,微调阶段需要65个epoch才能达到相同的mAP50。修建后的ONNX模型大小从43M减少到36M。

注意:我们需要将网络结构和网络权重区分开来,YOLOv8的网络结构来自yaml文件,如果我们进行剪枝后保存的权重文件的结构其实是和原始的yaml文件不符合的,需要对yaml文件进行修改满足我们的要求。

2.Pretrain(option)

步骤如下:

  • git clone https://github.com/ultralytics/ultralytics.git
  • use VOC2007, and modify the VOC.yaml(去除VOC2012的相关内容)
  • disable amp(禁用amp混合精度)
# FILE: ultralytics/yolo/engine/trainer.py
...
def check_amp(model):# Avoid using mixed precision to affect finetunereturn False # <============== modified(修改部分)device = next(model.parameters()).device  # get model deviceif device.type in ('cpu', 'mps'):return False  # AMP only used on CUDA devicesdef amp_allclose(m, im):# All close FP32 vs AMP results...

3.Constrained Training

约束训练是为了筛选哪些channel比较重要,哪些channel没有那么重要,也就是我们上节课所说的稀疏训练

  • prune the BN layer by adding L1 regularizer.
# FILE: ultralytics/yolo/engine/trainer.py
...
# Backward
self.scaler.scale(self.loss).backward()# <============ added(新增)
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:self.optimizer_step()last_opt_step = ni
...

注意1:在剪枝时,我们选择加载last.pt而非best.pt,因为由于迁移学习,模型的泛化性比较好,在第一个epoch时mAP值最大,但这并不是真实的,我们需要稳定下来的一个模型进行prune

注意2:我们在对Conv层进行剪枝时,我们只考虑1v1(如BottleNeck,C2f and SPPF)、1vm(如Backbone,Detect)的情形,并不考虑mv1的情形。

思考:Constrained Training的必要性?

约束训练可以使得模型更易于剪枝。在约束训练中,模型会学习到一些通道或者权重系数比较不重要的信息,而这些信息在剪枝过程中得到应用,从而达到模型压缩的效果。而如果直接进行剪枝操作,可能会出现一些问题,比如剪枝后的模型精度大幅下降、剪枝不均匀等。因此,在进行剪枝操作之前,通过稀疏训练的方式,可以更好地准确地确定哪些通道或者权重系数可以被剪掉,从而避免上述问题的发生。

4.Prune

4.1 检查BN层的bias

  • 剪枝后,确保BN层的大部分bias足够小(接近于0),否则重新进行稀疏训练
for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())

4.2 设置阈值和剪枝率

  • threshold:全局或局部
  • factor:保持率,裁剪太多不推荐
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)

4.3 最小剪枝Conv单元的TopConv

Top-Bottom Conv如下图所示:

TopConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):gamma = conv1.bn.weight.data.detach()beta  = conv1.bn.bias.data.detach()keep_idxs = []    local_threshold = thresholdwhile len(keep_idxs) < 8:keep_idxs = torch.where(gamma.abs() >= local_threshold)[0] local_threshold = local_threshold * 0.5n = len(keep_idxs)print(n / len(gamma) * 100)  # 打印我们保留了多少的channel# pruneconv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data   = beta[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.running_var.data  = conv1.bn.running_var.data[keep_idxs]conv1.bn.num_features   = nconv1.conv.weight.data  = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = nif conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]# pattern to prune
# 1. prune all 1 vs 1 TB pattern e.g. bottleneck
for name, m in model.named_modules():if isinstance(m, Bottleneck):prune_conv(m.cv1, m.cv2)

注意:由于NVIDIA的硬件加速的原因,我们保留的channels应该大于等于8,我们可以通过设置local_threshold,尽量小点,让更多的channel保留下来。

4.4 最小剪枝Conv单元的BottomConv

BottomConv剪枝的示例代码如下:

def prune_conv(conv1: Conv, conv2: Conv):...if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is not None:if isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]

注意BottomConv存在两种情形,一种是单个Conv,还有一种是Conv列表。TopConv需要考虑conv2d+bn+act,而BottomConv只需要考虑conv2d

4.5 Seq剪枝

Seq剪枝的示例代码如下:

def prune(m1, m2):if isinstance(m1, C2f):m1 = m1.cv2if not isinstance(m2, list):m2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1prune_conv(m1, m2)# 2. prune sequential
seq = model.model
for i in range(3, 9):if i in [6, 4, 9]: continueprune(seq[i], seq[i+1])

注意:我们不考虑1vm的情形,因此在4,6,9的module我们是不进行剪枝的,后续head进行Concat时是对4,6,9的module进行拼接的。考虑到前几个conv的特征提取的重要性,我们也不剪枝它们。(那感觉没剪几个module呀

剪枝与重参第七课:YOLOv8剪枝相关推荐

  1. 剪枝与重参第三课:常用剪枝工具

    目录 常用剪枝工具 前言 1.torch.nn.utils.prune 1.1 API简单示例 1.2 拓展之钩子函数 2.pytorch pruning functions 3.custom pru ...

  2. 剪枝与重参第二课:修剪方法和稀疏训练

    目录 修剪方法和稀疏训练 前言 1.修剪方法 1.1 经典框架:训练-剪枝-微调 1.2 训练时剪枝(rewind) 1.3 removing剪枝 2.dropout and dropconnect ...

  3. 《幸福就在你身边》第七课、工作着,快乐着【哈佛大学幸福课精华】

    一.开心工作 愚人向远方寻找快乐,智者则在身旁培养快乐. 泰戈尔在<人生的亲证>中写道:"我们的工作日不是我们的欢乐日--因此,我们要求节日,我们在自己的工作中不能找到节日,所以 ...

  4. 初二计算机简单动画,浙教版八年级下册信息技术:第七课《简单的动画补间动画》教案...

    ID:10051834 分类: 全国 , 2019 资源大小:219KB 资料简介: 第七课<简单的动画补间动画> 课题 第六课  简单的动画补间动画 目标 1.通过设置舞台背景和角色,学 ...

  5. 第七课 大数据技术之Fink1.13的实战学习-Fink CEP

    第七课 大数据技术之Fink1.13的实战学习-Fink CEP 文章目录 第七课 大数据技术之Fink1.13的实战学习-Fink CEP 第一节 Fink CEP介绍 1.1 Flink CEP背 ...

  6. 第七课.简单的图像分类(一)

    第七课目录 图像分类基础 卷积神经网络 Pooling layer BatchNormalization BatchNormalization与归一化 torch.nn.BatchNorm2d MNI ...

  7. C#之windows桌面软件第七课:(下集)串口工具实现数据校验、用灯反应设备状态

    C#之windows桌面软件第七课:(下集)串口工具实现数据校验.用灯反应设备状态 using System; using System.Collections.Generic; using Syst ...

  8. Coursera公开课笔记: 斯坦福大学机器学习第七课“正则化(Regularization)”

     Coursera公开课笔记: 斯坦福大学机器学习第七课"正则化(Regularization)" +13投票 斯坦福大学机器学习第七课"正则化"学习笔记, ...

  9. Asp.Net Web API 2第七课——Web API异常处理

    Asp.Net Web API 2第七课--Web API异常处理 原文:Asp.Net Web API 2第七课--Web API异常处理 前言 阅读本文之前,您也可以到Asp.Net Web AP ...

最新文章

  1. getURLParameters - 网址参数
  2. 怎样使破解网页的禁止复制黏贴
  3. mipi LCD 的CLK时钟频率与显示分辨率及帧率的关系
  4. 彻底理解JAVA动态代理
  5. 设定printf在终端输出的颜色
  6. 卷积神经网络CNNs 为什么要用relu作为激活函数?
  7. python嵌入到C++中
  8. LeTax如何多行注释
  9. controller属于哪一层_别急着换5G,4G手机同样值得考虑!哪几款安卓手机称得上4G机皇?...
  10. JavaScript————FormData实现多文件上传
  11. 基于Spring MVC的ECharts动态数据实时展示
  12. 2016全国计算机二级题,2016全国计算机二级考生试题及答案
  13. nagios配置文件说明
  14. React学习笔记二 通过柯里化函数实现带参数的事件绑定
  15. ORACLE多表关联的update语句
  16. iptables详解 1 -- iptables概念
  17. 用长按键重复输入 - Mac OS X Lion
  18. python爬贴吧回复_Python爬虫实践,获取百度贴吧内容
  19. 黑苹果 MacOS 10.15 Catalina 安装详细教程带工具资料
  20. List集合和ArrayList集合源码

热门文章

  1. selenium防爬和模拟手机浏览器
  2. Django做一个简单的博客系统(10)----最热文章
  3. 解决手机连接上wifi可以上网,电脑不上不了网的问题
  4. HR做好背景调查的五大节点!
  5. 浅谈QQ营销之月收入上万不是梦
  6. php透视图,第五十七课 利用透视尺绘制透视图-透视尺基本篇2-
  7. Python学习:字符串的深入浅出
  8. android手机防盗图片,android手机防盗措施介绍【图文】
  9. [教程]HTML5+Bootstrap4+Spring Boot+Mysql 图书管理系统 (附源码)
  10. 星环科技发布工业互联网解决方案,场景化赋能制造业转型升级