pytorch转为onnx格式,以及加载模型的params和GFLOPs方法
pytorch转为onnx格式:
def Torch2Onnx(model,input_size,output_name,istrained=True):''':param: model:param: input_size .e.t. (244,244):param: output_name .e.t. "test_output":param: if convert a trained model or not. default: True'''x = Variable(torch.randn(1,3,input_size[0],input_size[1])).cuda()if istrained:torch_out = torch.onnx.export(model,x,output_name,verbose=True)else:torch_out = torch.onnx.export(model,x,output_name,export_params=False,verbose=True) # Only export a untrained model.
使用举例:
model = model()
model.load_state_dict(torch.load(weight_path))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
input_size = (384,288)
Torch2Onnx(model,input_size,"test.onnx")
获取model中的params:
请注意:不同的方法默认model在cpu还是在cuda上是不一样的,如果出现类似RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
的报错,请检查weight是否应该在cuda上。
方法一:使用torchsummary
使用pip安装torchsummary:
pip install torchsummary
代码片段:
from torchsummary import summary model = model() model.load_state_dict(torch.load(weight_path)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) summary(model,(3,384,288))
方法二:使用torchstat
使用pip安装torchstat:
pip install torchstat
代码片段(和summary差不多)
from torchstat import statmodel = model()model.load_state_dict(torch.load(weight_path))device = torch.device("cuda" if torch.cuda.is_available() else "cpu")stat(model,(3,384,288))
方法三:使用thop(不太推荐)
使用pip安装thop:
pip install thop
代码片段:
from thop import profile,clever_format model = model()model.load_state_dict(torch.load(weight_path))device = torch.device("cuda" if torch.cuda.is_available() else "cpu")flops, params = profile(model,inputs=())flops,params = clever_format(flops,params,"%.3f")
pytorch转为onnx格式,以及加载模型的params和GFLOPs方法相关推荐
- mxnet加载模型的params和json文件来预测
导读 有时候我们在使用别人的mxnet预训练模型时,会有两个文件params和json文件,其中params文件中包含的是模型的网络参数,json文件包含的是网络的结构.这里我们以ImageNet的预 ...
- Pytorch加载模型并进行图像分类预测
目录 1. 整体流程 1)实例化模型 2)加载模型 3)输入图像 4)输出分类结果 5)完整代码 2. 处理图像 1) How can i convert an RGB image into gray ...
- PyTorch加载模型model.load_state_dict()问题,Unexpected key(s) in state_dict: “module.features..,Expected .
希望将训练好的模型加载到新的网络上.如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题. Unexpected key(s) in state_dict: "mod ...
- PyTorch 保存模型结构参数及加载模型
PyTorch 保存模型结构参数及加载模型 保存模型与加载 保存模型分为两种方式: 保存整个网络结构和参数 保存整个网络的参数 # 1.保存并加载整个网络结构和参数 # 保存模型 torch.save ...
- pytorch 保存、加载模型
一般保存为.pt格式,保存模型使用: torch.save(model, '保存位置') 加载模型使用: model_load = torch.load('加载模型的位置') 完整代码 import ...
- Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法
需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...
- pytorch model.to(device) 加载模型特别慢
问题:pytorch model.to(device) 加载模型特别慢 解决方案:卸载掉conda安装的pytorch 采用pytorch官网的pip指令下载方式.
- pytorch加载模型报错Unexpected key(s) in state_dict: module.conv1.weight, module.bn1
文章目录 背景 报错 原因 解决 背景 Pytorch在加载模型参数的时候,有两种情况可能出现这种问题: 自己写的网络结构,例如: 代码 import models arch = 'resnet50' ...
- Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率
前言 在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型.本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用. 权重初始化 常见的初始化方法 PyTorc ...
- pytorch加载模型时出现.....ckpt_100.pth is a zip archive (did you mean to use torch.jit.load()?)
在测试加载训练好的模型时出现上方问题,参考这篇文章,原因是训练和测试的torch版本不一致. 训练的时候是1.6,测试的时候是1.2,因此需要先在1.6版本下加载模型,重新保存,在保存的时候设置use ...
最新文章
- 任意长度的字典生成算法
- Https 与 SSl证书 概要
- 世界上最美的40个小镇,每一个都犹如仙境!
- 打基础一定要吃透这12类 Python 内置函数
- [ImportNew]Java线程面试题
- 如何通过JS获取元素宽高
- 在WinForm程序中读写系统配置
- linux内核简介和进程管理
- GML C++ Camera Calibration Toolbox 相机标定畸变矫正
- lumen安装后输出hello world
- 【WIN10】清除图标缓存
- Python:科赫曲线绘制雪花
- 【JS】常用效果总结
- python连接sftp下载文件及文件夹
- 【软件测试】:“用户登录”功能测试用例设计方法
- 《剑与电——角色扮演游戏设计艺术》读书笔记(二)
- VSCode报错“gcc不是内部或外部命令......”(自用)
- 关于SSD HMB与CMB
- 关于java实例方法可以访问类变量的一种解释
- windows10 安装msdatlst.ocx控件
热门文章
- Excel怎么隐藏指定文本单元格整行
- 基于 smart-config技术实现
- scratch编程钟表
- SpringBoot实现Excel导入导出,好用到爆,POI可以扔掉了
- 【ps-course】layer 图层
- 网页无法复制文字,一个插件解决问题!!!!
- Java集成建行龙支付接口(详细)
- 在python中的占位符中、请你选出不属于占位符的选项_2020年超星尔雅微表情识别·读脸读心 作业答案...
- java:中文汉语数字和阿拉伯数字互相转换,人民币大小写转换
- php 登陆微博,用新浪微博账号登录(第三方登录)