掌握Pytorch模型 压缩 裁剪与量化
在深度学习模型的搭建和部署中,我们需要考虑到模型的权重个数、模型权重大小、模型推理速度和计算量。本文将分享在Pytorch中进行模型压缩、裁剪和量化的教程。
权重压缩
模型在训练时使用的模型权重类型为float32
,而在模型部署时则不需要高的数据精度。可以将类型转换为float16
进行保存,这样可以降低45%左右的权重大小。
- 步骤1:训练并保存模型
import timm
model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
model.load_state_dict(torch.load('model_mobilevit_xxs.pth'))
- 步骤2:转换数据类型并存储
params = torch.load('model_mobilevit_xxs.pth') # float32
for key in params.keys():params[key] = params[key].half() # float16torch.save(params, 'model_mobilevit_xxs_half.pth')
权重裁剪
在模型训练完成后可以考虑对冗余的权重进行裁剪,有以下几种裁剪方法:
- 按照比例随机裁剪
- 按照权重大小裁剪
https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
使用的案例代码如下:
import torch.nn.utils.prune as prune
import numpy as npmodel = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
model.load_state_dict(torch.load('model_mobilevit_xxs.pth'))# 选中需要裁剪的层
module = model.head.fc# random_unstructured裁剪
prune.random_unstructured(module, name="weight", amount=0.3)# l1_unstructured裁剪
prune.l1_unstructured(module, name="weight", amount=0.3)# ln_structured裁剪
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
在使用权重裁剪需要注意:
- 权重裁剪并不会改变模型的权重大小,只是增加了稀疏性;
- 权重裁剪并不会减少模型的预测速度,只是减少了计算量;
- 权重裁剪的参数比例会对模型精度有影响,需要测试和验证;
权重量化
32-bit的乘加变成了8-bit的乘加,模型权重大小减少,对内存的要求降低了。
https://pytorch.org/docs/stable/quantization.html
Eager Mode Quantization
import torch# define a floating point model
class M(torch.nn.Module):def __init__(self):super(M, self).__init__()self.fc1 = torch.nn.Linear(100, 40)self.fc2 = torch.nn.Linear(1000, 400)def forward(self, x):x = self.fc1(x)return x# create a model instance
model_fp32 = M()
torch.save(model_fp32.state_dict(), 'tmp_float32.pth')# create a quantized model instance
model_int8 = torch.quantization.quantize_dynamic(model_fp32, # the original model{torch.nn.Linear}, # a set of layers to dynamically quantizedtype=torch.qint8) # the target dtype for quantized weights# run the model
input_fp32 = torch.randn(4, 4, 4, 4)
res = model_int8(input_fp32)
torch.save(model_int8.state_dict(), 'tmp_int8.pth')
Post Training Static Quantization
import torch# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):def __init__(self):super(M, self).__init__()# QuantStub converts tensors from floating point to quantizedself.quant = torch.quantization.QuantStub()self.conv = torch.nn.Conv2d(1, 100, 1)self.relu = torch.nn.ReLU()self.fc = torch.nn.Linear(100, 10)# DeQuantStub converts tensors from quantized to floating pointself.dequant = torch.quantization.DeQuantStub()def forward(self, x):# manually specify where tensors will be converted from floating# point to quantized in the quantized modelx = self.quant(x)x = self.conv(x)x = self.relu(x)# manually specify where tensors will be converted from quantized# to floating point in the quantized modelx = self.dequant(x)return x# create a model instance
model_fp32 = M()
torch.save(model_fp32.state_dict(), 'tmp_float32.pth')model_fp32.eval()model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)model_int8 = torch.quantization.convert(model_fp32_prepared)
res = model_int8(input_fp32)
torch.save(model_int8.state_dict(), 'tmp_int8.pth')
Pytorch暂时的量化操作还不是很完善,可能存在只能在CPU上运行,且速度变慢的情况。如果有量化需求,推荐使用tensorrt和GPU一起使用。
掌握Pytorch模型 压缩 裁剪与量化相关推荐
- PyTorch 深度学习模型压缩开源库(含量化、剪枝、轻量化结构、BN融合)
点击我爱计算机视觉标星,更快获取CVML新技术 本文为52CV群友666dzy666投稿,介绍了他最近开源的PyTorch模型压缩库,该库开源不到20天已经收获 219 颗星,是最近值得关注的模型压缩 ...
- 腾讯 AI Lab 正式开源PocketFlow自动化深度学习模型压缩与加速框架
11月1日,腾讯AI Lab在南京举办的腾讯全球合作伙伴论坛上宣布正式开源"PocketFlow"项目, 该项目是一个自动化深度学习模型压缩与加速框架,整合多种模型压缩与加速算法并 ...
- 阿里云PAI平台模型压缩技术落地淘宝直播双十一应用:一猜到底
简介:随着移动端应用的兴起,模型压缩作为深度学习模型实现轻量化部署的有效手段,备受关注.机器学习也从理论研究阶段,有了明显的工程化.应用落地的趋势,那么模型压缩在淘宝直播游戏场景下,是如何发挥重要作用 ...
- 使用Mindstudio进行Pytorch模型量化压缩
视频教程在模型量化压缩(Pytorch)_哔哩哔哩_bilibili MindStudio介绍与安装流程 1.1基本介绍: MindStudio为用户提供在AI开发所需的一站式开发环境,支持模型开发. ...
- 闲话模型压缩之量化(Quantization)篇
1. 前言 这些年来,深度学习在众多领域亮眼的表现使其成为了如今机器学习的主流方向,但其巨大的计算量仍为人诟病.尤其是近几年,随着端设备算力增强,业界涌现出越来越多基于深度神经网络的智能应用.为了弥补 ...
- PyTorch模型量化工具学习
官方教程(英文): https://pytorch.org/docs/stable/quantization.htmlpytorch.org 官方教程(中文): https://pytorch.ap ...
- Intel发布神经网络压缩库Distiller:快速利用前沿算法压缩PyTorch模型
Intel发布神经网络压缩库Distiller:快速利用前沿算法压缩PyTorch模型 原文:https://blog.csdn.net/u011808673/article/details/8079 ...
- 【视频课】深度掌握模型剪枝+模型量化+知识蒸馏3大核心模型压缩技术理论!...
前言 欢迎大家关注有三AI的视频课程系列,我们的视频课程系列共分为5层境界,内容和学习路线图如下: 第1层:掌握学习算法必要的预备知识,包括Python编程,深度学习基础,数据使用,框架使用. 第2层 ...
- 量化感知训练实践:实现精度无损的模型压缩和推理加速
简介:本文以近期流行的YOLOX[8]目标检测模型为例,介绍量化感知训练的原理流程,讨论如何实现精度无损的实践经验,并展示了量化后的模型能够做到精度不低于原始浮点模型,模型压缩4X.推理加速最高2.3 ...
最新文章
- 常用元素位置与大小总结
- <script>放在head内和body内有什么区别
- 这是我第一题AC的线段树
- 利用scipy包计算表格线的峰值,还原表格得到表格结构
- Yolov5系列AI常见数据集(1)车辆,行人,自动驾驶,人脸,烟雾
- string 方法 java_String 的几个 方法。 (java)
- win8.1 服务器正在运行,Win8.1系统打开IE浏览器提示服务器正在运行中的解决方法图文教程...
- stm32cubeIDE下载无法打开GDB的问题
- PhpSpreadsheet设置单元格常用操作汇总
- 为什么说精益管理模式是适合中国企业的管理方法(zt)
- 真分数化简为最简分数(6/8==3/4)
- 1、目标检测 RCNN(翻译+标注)
- win10系统怎么打开pdf文件
- mysql能够跨平台使用吗_Mysql跨平台(Windows,Linux,Mac)使用与安装
- 关于椰子汁的学问,你知道多少?
- 想要快速绘制3D图纸?这些“私人定制”不可少!
- (BAT批处理)批量文件夹重命名,要求是在原文件夹名前加上英文字母前缀aa
- 1267 'Illegal mix of collations (latin1_swedish_ci,IMPLICIT) and (utf8_gener
- 什么是高内聚与低耦合?
- 使用Python3开发的一款Android截屏神器
热门文章
- DCMM认证评估机构,你都知道吗?
- 递归经典问题:迷宫以及八皇后
- 2020年,冯唐49岁:我给20、30岁IT职场年轻人的建议
- Shiro--解决is not eligible for getting processed by all BeanPostProcessors
- 基于javaweb的进销存管理系统(前后端分离+java+vue+springboot+ssm+mysql+redis)
- Google Android Developer
- 广度优先搜索(BFS)---农夫与牛
- 用户在Eightcap易汇平台可以交易哪些产品?投资选择多吗?
- 苹果云服务icloud_苹果手机icloud手动备份和还原个人用户资料
- 《Python3》读书笔记(上)