pytorch转为onnx格式:

def Torch2Onnx(model,input_size,output_name,istrained=True):''':param: model:param: input_size .e.t. (244,244):param: output_name .e.t. "test_output":param: if convert a trained model or not. default: True'''x = Variable(torch.randn(1,3,input_size[0],input_size[1])).cuda()if istrained:torch_out = torch.onnx.export(model,x,output_name,verbose=True)else:torch_out = torch.onnx.export(model,x,output_name,export_params=False,verbose=True) # Only export a untrained model.

使用举例:

model = model()
model.load_state_dict(torch.load(weight_path))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
input_size = (384,288)
Torch2Onnx(model,input_size,"test.onnx")

获取model中的params:

请注意:不同的方法默认model在cpu还是在cuda上是不一样的,如果出现类似RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same的报错,请检查weight是否应该在cuda上。

方法一:使用torchsummary

  1. 使用pip安装torchsummary:
    pip install torchsummary

  2. 代码片段:

    from torchsummary import summary
    model = model()
    model.load_state_dict(torch.load(weight_path))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    summary(model,(3,384,288))
    

方法二:使用torchstat

  1. 使用pip安装torchstat:
    pip install torchstat

  2. 代码片段(和summary差不多)

    from torchstat import statmodel = model()model.load_state_dict(torch.load(weight_path))device = torch.device("cuda" if torch.cuda.is_available() else "cpu")stat(model,(3,384,288))
    

方法三:使用thop(不太推荐)

  1. 使用pip安装thop:
    pip install thop

  2. 代码片段:

    from thop import profile,clever_format
    model = model()model.load_state_dict(torch.load(weight_path))device = torch.device("cuda" if torch.cuda.is_available() else "cpu")flops, params = profile(model,inputs=())flops,params = clever_format(flops,params,"%.3f")
    

pytorch转为onnx格式,以及加载模型的params和GFLOPs方法相关推荐

  1. mxnet加载模型的params和json文件来预测

    导读 有时候我们在使用别人的mxnet预训练模型时,会有两个文件params和json文件,其中params文件中包含的是模型的网络参数,json文件包含的是网络的结构.这里我们以ImageNet的预 ...

  2. Pytorch加载模型并进行图像分类预测

    目录 1. 整体流程 1)实例化模型 2)加载模型 3)输入图像 4)输出分类结果 5)完整代码 2. 处理图像 1) How can i convert an RGB image into gray ...

  3. PyTorch加载模型model.load_state_dict()问题,Unexpected key(s) in state_dict: “module.features..,Expected .

    希望将训练好的模型加载到新的网络上.如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题. Unexpected key(s) in state_dict: "mod ...

  4. PyTorch 保存模型结构参数及加载模型

    PyTorch 保存模型结构参数及加载模型 保存模型与加载 保存模型分为两种方式: 保存整个网络结构和参数 保存整个网络的参数 # 1.保存并加载整个网络结构和参数 # 保存模型 torch.save ...

  5. pytorch 保存、加载模型

    一般保存为.pt格式,保存模型使用: torch.save(model, '保存位置') 加载模型使用: model_load = torch.load('加载模型的位置') 完整代码 import ...

  6. Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法

    需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...

  7. pytorch model.to(device) 加载模型特别慢

    问题:pytorch model.to(device) 加载模型特别慢 解决方案:卸载掉conda安装的pytorch 采用pytorch官网的pip指令下载方式.

  8. pytorch加载模型报错Unexpected key(s) in state_dict: module.conv1.weight, module.bn1

    文章目录 背景 报错 原因 解决 背景 Pytorch在加载模型参数的时候,有两种情况可能出现这种问题: 自己写的网络结构,例如: 代码 import models arch = 'resnet50' ...

  9. Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率

    前言 在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型.本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用. 权重初始化 常见的初始化方法 PyTorc ...

  10. pytorch加载模型时出现.....ckpt_100.pth is a zip archive (did you mean to use torch.jit.load()?)

    在测试加载训练好的模型时出现上方问题,参考这篇文章,原因是训练和测试的torch版本不一致. 训练的时候是1.6,测试的时候是1.2,因此需要先在1.6版本下加载模型,重新保存,在保存的时候设置use ...

最新文章

  1. 任意长度的字典生成算法
  2. Https 与 SSl证书 概要
  3. 世界上最美的40个小镇,每一个都犹如仙境!
  4. 打基础一定要吃透这12类 Python 内置函数
  5. [ImportNew]Java线程面试题
  6. 如何通过JS获取元素宽高
  7. 在WinForm程序中读写系统配置
  8. linux内核简介和进程管理
  9. GML C++ Camera Calibration Toolbox 相机标定畸变矫正
  10. lumen安装后输出hello world
  11. 【WIN10】清除图标缓存
  12. Python:科赫曲线绘制雪花
  13. 【JS】常用效果总结
  14. python连接sftp下载文件及文件夹
  15. 【软件测试】:“用户登录”功能测试用例设计方法
  16. 《剑与电——角色扮演游戏设计艺术》读书笔记(二)
  17. VSCode报错“gcc不是内部或外部命令......”(自用)
  18. 关于SSD HMB与CMB
  19. 关于java实例方法可以访问类变量的一种解释
  20. windows10 安装msdatlst.ocx控件

热门文章

  1. Excel怎么隐藏指定文本单元格整行
  2. 基于 smart-config技术实现
  3. scratch编程钟表
  4. SpringBoot实现Excel导入导出,好用到爆,POI可以扔掉了
  5. 【ps-course】layer 图层
  6. 网页无法复制文字,一个插件解决问题!!!!
  7. Java集成建行龙支付接口(详细)
  8. 在python中的占位符中、请你选出不属于占位符的选项_2020年超星尔雅微表情识别·读脸读心 作业答案...
  9. java:中文汉语数字和阿拉伯数字互相转换,人民币大小写转换
  10. php 登陆微博,用新浪微博账号登录(第三方登录)