当PyTorch模型需要部署到服务时,为了提升访问速度,需要转换为TRT模型,再进行部署。在转换为TRT模型之前,需要将PyTorch参数模型(如pth.tar)转换为pt模型,使用jit形式。pt模型 = 参数模型(pth.tar) + 网络结构(如resnet50)。使用pt模型,可以简化使用方式,同时也方便转换为trt模型,进行轻量级部署。在转换函数中,包含验证逻辑,保证转换前后的模型效果一致,即输出不变。

以图像分类框架pytorch-image-models-my为例,将PyTorch的pth.tar模型转换为PT模型。

转换流程如下:

  1. 加载pth.tar模型model,model达到可以预测的标准,即:
# 加载模型
model = timm.create_model(model_name=base_net, pretrained=False,checkpoint_path=model_path, num_classes=num_classes)
if torch.cuda.is_available():print('[Info] cuda on!!!')model = model.cuda()
model.eval()# 预测结果
print('[Info] 预测图像尺寸: {}'.format(img_rgb.shape))
img_tensor = self.preprocess_img(img_rgb, self.transform)
print('[Info] 模型输入: {}'.format(img_tensor.shape))
with torch.no_grad():out = self.model(img_tensor)
  1. 将已加载的模型model,通过torch.jit.trace()模拟输入dummy_input,调用traced.save()存储成pt模型,即:

    • 注意输入尺寸dummy_shape,用于生成模拟的input数据,需要与模型输入保持一致

    • 注意是否支持GPU,即orch.cuda.is_available(),判断环境是cuda还是cpu。

dummy_shape = (1, 3, 336, 336)  # 不影响模型
print('[Info] dummy_shape: {}'.format(dummy_shape))
if torch.cuda.is_available():model_type = "cuda"
else:model_type = "cpu"
print('[Info] model_type: {}'.format(model_type))
dummy_input = torch.empty(dummy_shape,dtype=torch.float32,device=torch.device(model_type))
traced = torch.jit.trace(self.model, dummy_input)
pt_path = os.path.join(pt_folder_path, "{}_{}.pt".format(model_name, model_type))
traced.save(pt_path)
  1. 验证pt模型是否与原模型pth.tar的输出是否一致,pt模型调用reload_script(),即:
with torch.no_grad():standard_out = self.model(dummy_input)
print('[Info] standard_out: {}'.format(standard_out))reload_script = torch.jit.load(pt_path)
with torch.no_grad():script_output = reload_script(dummy_input)
print('[Info] script_output: {}'.format(script_output))
print('[Info] 验证 is equal: {}'.format(F.l1_loss(standard_out, script_output)))print('[Info] 存储完成: {}'.format(pt_path))

全部转换和验证PT模型的逻辑,都位于save_pt()函数中,调用即可生成,输出位于pt_models文件夹中,即:

me.save_pt(os.path.join(DATA_DIR, "pt_models"))

输出的模型是:model_best_c2_20210915_cpu.pt,GPU版本是:model_best_c2_20210915_cuda.pt

在pytorch-image-models-my工程中,pth.tar模型转换为PT模型的转换脚本,源码如下,参考model_2_pt_script.py:

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2021. All rights reserved.
Created by C. L. Wang on 15.9.21
"""
import argparse
import os
import sysp = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:sys.path.append(p)from root_dir import DATA_DIR
from myscripts.img_predictor import ImgPredictordef parse_args():"""处理脚本参数"""parser = argparse.ArgumentParser(description='PyTorch模型转换PT模型')parser.add_argument('-m', dest='model_path', required=True, help='模型路径', type=str)parser.add_argument('-n', dest='base_net', required=False, help='basenet', type=str, default="resnet50")parser.add_argument('-c', dest='num_classes', required=False, help='类别个数', type=int, default=2)parser.add_argument('-o', dest='out_dir', required=False, help='输出文件夹', type=str,default=os.path.join(DATA_DIR, "pt_models"))args = parser.parse_args()arg_model_path = args.model_pathprint("[Info] 模型路径: {}".format(arg_model_path))arg_base_net = args.base_netprint("[Info] basenet: {}".format(arg_base_net))arg_num_classes = args.num_classesprint("[Info] 类别数: {}".format(arg_num_classes))arg_out_dir = args.out_dirprint("[Info] 输出文件夹: {}".format(arg_out_dir))return arg_model_path, arg_base_net, arg_num_classes, arg_out_dirdef main():"""入口函数"""print('[Info] ' + "-" * 100)print('[Info] 转换PT模型开始')arg_model_path, arg_base_net, arg_num_classes, arg_out_dir = parse_args()me = ImgPredictor(arg_model_path, arg_base_net, arg_num_classes)pt_path = me.save_pt(arg_out_dir)  # 存储PT模型print('[Info] 存储完成: {}'.format(pt_path))print('[Info] ' + "-" * 100)if __name__ == "__main__":main()

PyTorch参数模型转换为PT模型相关推荐

  1. pytorch训练的pt模型转换为onnx(nn.DataParallel()、model、model.state_dict())

    pt转onnx流程与常见问题 pt转onnx流程 pt转onnx流程 1.读取pt模型文件,文件既可以是torch.save(model,path)整体保存的模型,也可以是保存的字典文件. // An ...

  2. pytorch将pt模型转onnx模型

    pytorch将pt模型转onnx模型 一 导出ONNX模型 torch.onnx.export( model, # 要导出的模型 args, # 模型的输入参数,输入参数只需满足shape正确 on ...

  3. [Pytorch].pth转.pt文件

    Pytorch的模型文件一般会保存为.pth文件,C++接口一般读取的是.pt文件,因此,C++在调用Pytorch训练好的模型文件的时候就需要进行一个转换,转换为.pt文件,才能够读取. 所以在转换 ...

  4. PyTorch | 保存和加载模型教程

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...

  5. Pytorch搭建自己的模型

    前言 PyTorch.TensorFlow都是主流的深度学习框架,今天主要讲解一下如何快速使用pytorch搭建自己的模型.至于为什么选择讲解pytorch,这里我就简单说明一下自己的使用感受(相对T ...

  6. PyTorch的生态和模型部署

    PyTorch的生态和模型部署 1. PyTorch生态 前几章,我们学习了PyTorch的基本使用.能够定义和修改自己的模型.常用的训练技巧和PyTorch的可视化. PyTorch的强大,跟PyT ...

  7. 【项目实战课】从零掌握安卓端Pytorch原生深度学习模型部署

    欢迎大家来到我们的项目实战课,本期内容是<从零掌握安卓端Pytorch原生深度学习模型部署>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战 ...

  8. 基于Pytorch的Transformer翻译模型前期数据处理方法

    基于Pytorch的Transformer翻译模型前期数据处理方法 Google于2017年6月在arxiv上发布了一篇非常经典的文章:Attention is all you need,提出了解决s ...

  9. Pytorch完成基础的模型-线性回归

    Pytorch完成基础的模型-线性回归 1. Pytorch完成模型常用API 在前一部分博文中,实现了通过torch的相关方法完成反向传播和参数更新,在pytorch中预设了一些更加灵活简单的对象, ...

最新文章

  1. linux源码包卸载方式
  2. 计算机辅助园林设计ps,计算机辅助园林设计III
  3. MVC/MVP/MVVM区别——MVVM就是angular,视图和数据双向绑定
  4. 小猿圈讲解Java可以做什么?
  5. LeetCode-剑指 Offer 11. 旋转数组的最小数字
  6. NHibernate中的SchemaExport
  7. Auto-Configuration Error: Cannot find gcc or CC
  8. 前端学习(1957)vue之电商管理系统电商系统之创建新分支
  9. oenwrt 进不了bios_为什么进不bios_进不了bios怎么解决?
  10. async await 的前世今生
  11. SQL Server “Denali” ---SQL 2012 新特性
  12. 多点Dmall发布系统Mini OS 宣称要五年覆盖百万门店
  13. c++ xml 解析“后直接跟值问题
  14. eCos configtool 在ubuntu 10.10以后菜单消失的解决
  15. 关于计算机四级网络工程师的考试
  16. 《图解网络硬件》网络硬件通用基础知识
  17. 2021广东高考成绩排名如何查询,2021广东省地区高考成绩排名查询,广东省高考各高中成绩喜报榜单...
  18. Excel VBA 操作键盘(如:移动方向键,上下左右等)
  19. 天黑请闭眼,我这次还能抽到杀手吗
  20. 写需求规格说明书/产品定义的个人总结

热门文章

  1. C++ 与、或、异或、取反等运算
  2. 网络规划和设计 - 关键路径法 CPM(关键路径、松弛时间)
  3. 一个测试人员的工作该怎么开展
  4. RL(十三)深度Q网络(DQN)
  5. 一个基于Android开发的简单的音乐播放器
  6. 说话前你是话的主人,说话后你是话的仆人
  7. 工控机与商用计算机的区别,两招教你分辨工控机与普通电脑的区别
  8. 众利币开发与模式设计
  9. python问题解决----把\xe6\xa8\xa1\这种字符转为普通汉字?
  10. 网页设计与制作的学习(一)