一、torch.onnx.export()详细介绍

1.torch.onnx.export()

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, do_constant_folding=False, example_outputs=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None)

2. 功能:

将pth模型转为onnx文件导出。

3.参数

  • model (torch.nn.Module) :pth模型文件;
  • args (tuple of arguments) :模型的输入, 模型的尺寸;
  • export_params (bool, default True) – 如果指定为True或默认, 参数也会被导出. 如果你要导出一个没训练过的就设为 False;
  • verbose (bool, default False) :导出轨迹的调试描述;
  • training (bool, default False) :在训练模式下导出模型。目前,ONNX导出的模型只是为了做推断,通常不需要将其设置为True;
  • input_names (list of strings, default empty list) :onnx文件的输入名称;
  • output_names (list of strings, default empty list) :onnx文件的输出名称;
  • opset_version:默认为9;
  • dynamic_axes – {‘input’ : {0 : ‘batch_size’}, ‘output’ : {0 : ‘batch_size’}}) 。

二、pth的保存方式

torch.save(model,'save_path')

torch.save(model,path) 会将model的参数、框架都保存到路径path中,但是在加载model的时候可能会因为包版本的不同报错,所以当保存所有模型参数时,需要将模型构造相关代码文件放在相同路径,否则在load的时候无法索引到model的框架。

torch.save(model.state_dict(),model_path)

建议:使用state_dict()模式保存model,torch.save(model.state_dict(),path),这样保存为字典模式,可以直接load。

三、pth转onnx代码

1.使用torch.save(model,'save_path')方式保存

x = torch.randn(1, 3, 224, 224, device=device)

输入测试数据  数据格式[batch, channl, height, width]

model.eval()

不启用 BatchNormalization 和 Dropout,保证BN和dropout不发生变化,pytorch框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层影响结果。

:一定要写上这句话,不然可能会影响onnx的输出结果,经验所知。

import torch
import torch.nn
import onnx# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')model = torch.load('***.pth', map_location=device)
model.eval()input_names = ['input']
output_names = ['output']x = torch.randn(1, 3, 224, 224, device=device)torch.onnx.export(model, x, '***.onnx', input_names=input_names, output_names=output_names, verbose='True')

2.使用torch.save(model.state_dict(),model_path)方式保存

该方式保存需要提供网络结构文件。

import torch.onnx
import onnxruntime as ort
from model import Net# 创建.pth模型model = Net
# 加载权重
model_path = '***.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_statedict = torch.load(model_path, map_location=device)
model.load_state_dict(model_statedict)model.to(device)
model.eval()input_data = torch.randn(1, 3, 224, 224, device=device)# 转化为onnx模型
input_names = ['input']
output_names = ['output']torch.onnx.export(model, input_data, '***.onnx', opset_version=9, verbose=True, input_names=input_names, output_names = output_names)

pth文件转为onnx格式相关推荐

  1. 使用onnx包将pth文件转换为onnx文件

    本文对比一下两种pth文件转为onnx的区别以及onnx文件在NETRON中的图 只有参数的pth文件:cat_dog.pth 既有参数又有模型结构的pth文件:cat_dog_model_args. ...

  2. pythoncsv格式_python实现csv格式文件转为asc格式文件的方法

    一.背景描述 csv格式文件是一种类似于excel的文件格式 asc格式文件是一种可以用text打开的文本文件 csv转asc本来可以用arcgis顺利完成,但由于csv数据量太大(744万行),ar ...

  3. python 读取csv文件转成字符串,python实现csv格式文件转为asc格式文件的方法

    一.背景描述 csv格式文件是一种类似于excel的文件格式 asc格式文件是一种可以用text打开的文本文件 csv转asc本来可以用arcgis顺利完成,但由于csv数据量太大(744万行),ar ...

  4. Python之ffmpeg:利用python编程基于ffmpeg将m4a格式音频文件转为mp3格式文件

    Python之ffmpeg:利用python编程基于ffmpeg将m4a格式音频文件转为mp3格式文件 目录 利用python编程基于ffmpeg将m4a格式音频文件转为mp3格式文件 1.先下载ff ...

  5. ppt生成eps文件_如何将AI/EPS格式文件转为ppt格式

    平时见到很多AI或EPS文件都想直接拿来放到PPT上,对于很多图片来说,只需转PNG就可以导入到PPT了,但对于一些图表,在导入PPT的时候还想要编辑下,这个时候就要用到下面的将AI/EPS格式文件转 ...

  6. 批量将json文件转为jpg格式

    批量将json文件转为jpg格式 1.简介 labelme软件自带有将json文件转为jpg的代码,不过只限单张. 位于./cli/json_to_dataset.py. 2.批量生成代码 impor ...

  7. python3 将eps文件转为jpg格式

    文章目录 1. 按 2. 准备 3. 代码 3.1. 保存成eps格式 3.2. 将eps文件转为jpg格式 1. 按 用Turtle画的图无法直接保存成jpg格式的,只能先保存成eps,再将eps转 ...

  8. python 使用字节流bytes格式读取文件转为int格式,再转为0,1字符串格式

    python 使用字节流bytes格式读取文件转为int格式,再转为0,1字符串格式 with open('test.jpg', 'rb') as src:t = src.read(1) # 读进1B ...

  9. 批量将NC文件转为tif格式

    批量将多年的NC文件转为tif格式进行处理(来源:https://www.geek-share.com/detail/2763962738.html) 所参考和借鉴的文章的链接如下: https:// ...

  10. 利用Python将WEBVTT格式的视频字幕文件转为SRT格式

    1 WebVTT & SRT 格式 WebVTT字幕格式与SRT字幕格式主要区别在于时间格式的区分. 下面是一个WebVTT格式的字幕文件 WEBVTT1 00:00:20.000 --> ...

最新文章

  1. Factory模式与Prototype模式的异同
  2. android源码阅读笔记1-配置源码路径/阅读源码方法讨论
  3. CCNA的全套标准实验
  4. k8s组件说明:ETCD存储组件
  5. 你的行为合理吗?看看社会心理学给我们的启示。
  6. AI:IPPR的数学表示-CNN方法
  7. 在Google Cloud Platform的K8上运行Fn函数
  8. 第二百四十八天 how can I 坚持
  9. jquery跨域请求示例
  10. Linux命令之进程的管理
  11. 工行基于MySQL构建分布式架构的转型之路
  12. java增加内容辅助_Eclipse自定义内容辅助基于默认Java内容辅助结果
  13. JVM 内存分析工具MAT
  14. python仿360界面_高仿360界面的实现(用纯XML和脚本实现)
  15. 拼多多的砍价免费拿是真的吗?
  16. 为什么每次圣斗士出招前都要大喊一下大招?
  17. Halcon九点及旋转标定流程
  18. 用最科学的方法展示最形象的图表——前端数据可视化实践
  19. 北交计算机考研保护一志愿吗,考研er注意了~这些学校不歧视本科!而且保护一志愿!...
  20. 记录Git Unable to negotiate with xxx... 问题

热门文章

  1. hdoj 4747 线段树
  2. 电脑各种故障排除集锦
  3. NOIP 2015 推销员
  4. Symbol 数据类型
  5. 如何使用视频格式转换器将flv转换成MP4
  6. 外贸B2C系列:google企业邮箱设置
  7. 免费企业邮箱: Google企业邮箱的申请
  8. 福州太冷?那就快来这些地方!不仅有威廉王子的蜜月圣地,也有贝克汉姆的度假天堂!全部免签or落地签...
  9. 德国AI“算个球”:西班牙是冠军,只要别让德国进八强(严谨推理)
  10. 考研英语 - word-list-26