Pytorch的模型文件一般会保存为.pth文件,C++接口一般读取的是.pt文件,因此,C++在调用Pytorch训练好的模型文件的时候就需要进行一个转换,转换为.pt文件,才能够读取。

所以在转换的时候,首先就需要先将模型文件读取进来,然后利用pytorch提供的函数torch.jit.trace进行转换,这个函数的声明为:

def trace(func,
          example_inputs,
          optimize=True,
          check_trace=True,
          check_inputs=None,
          check_tolerance=1e-5,
          _force_outplace=False,
          _module_class=None):
也就是,第一个参数为输入的模型,第二个参数为输入的带测试数据,通常其数据形式要跟模型的输入数据的形式是一样的。

转换的代码例子如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchsummary import summary
 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, 1)
        self.conv2 = nn.Conv2d(32, 64, 5, 1)
        self.fc1 = nn.Linear(4*4*64, 512)
        self.fc2 = nn.Linear(512, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
 
model = torch.load("mnist_cnn.pth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
summary(model, input_size=(1, 28, 28))
model = model.to(device)
traced_script_module = torch.jit.trace(model, torch.ones(1, 1, 28, 28).to(device))
traced_script_module.save("mnist_cnn_cc1.pt")
 
————————————————
版权声明:本文为CSDN博主「熊叫大雄」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/yz2zcx/article/details/100609210

[Pytorch].pth转.pt文件相关推荐

  1. PyTorch参数模型转换为PT模型

    当PyTorch模型需要部署到服务时,为了提升访问速度,需要转换为TRT模型,再进行部署.在转换为TRT模型之前,需要将PyTorch参数模型(如pth.tar)转换为pt模型,使用jit形式.pt模 ...

  2. python中h5文件和pt文件

    python中h5文件和pt文件 h5文件 pt文件 h5文件 h5文件中有两个核心的概念:组"group"和数据集"dataset". 一个h5文件就是 &q ...

  3. 如何打开.pt文件?

    .pt文件是PyTorch中保存模型和数据的文件格式,可以使用PyTorch库中的函数来加载.pt文件.具体可以按照以下步骤进行操作: import torch 使用torch.load()函数加载. ...

  4. 用python将pt文件画图

    单个文件直接画图: # 将pt文件画图展示.已经过测试import torch import matplotlib.pyplot as plt# 从.pt文件中加载数据 data = torch.lo ...

  5. Pytorch GPU版本whl文件安装

    Pytorch GPU版本whl文件安装 安装pytorch的时候,用pip安装时网速实在太慢,换源也不太行,1.2G的文件,一个网络波动就开始疯狂红字.因此使用whl文件进行安装 用whl文件进行安 ...

  6. pytorch .pth模型转tensorflow .pb模型

    训练好的pytorch模型如何转化为tensorflow的pb模型? 本人初步使用的是onnx框架: pytorch --->  onnx ----> tensorflow 使用onnx转 ...

  7. 从Pytorch源码看.pt文件

    Pytorch中张量的保存与加载 保存张量 在Pytorch中,一个约定俗成的方法是使用.pt扩展的文件格式来保存张量,使用的方法为torch.save(). 函数原型与参数说明 import tor ...

  8. 【笔记】pth、pt、pkl的区别:pt 常做数据集的数据存储形式

    调试PyTorch代码保存训练模型的时候有些时候保存的格式是 .pt,有些时候是.pth与.pkl,不禁好奇它们之间的区别 问题描述         我们经常会看到后缀名为.pt,.pth,.pkl的 ...

  9. 【小白学PyTorch】17.TFrec文件的创建与读取

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 小白学PyTorch | 16 TF2读取图片的方法 小白学PyTorch | 15 TF2实现一个简单的服装分 ...

最新文章

  1. 全国大学智能车竞赛完全模型组中的赛道标志
  2. html调试和js脚本调试
  3. 非spring环境中配置文件工具
  4. apache配置反向代理(通过不同端口访问不同目录)
  5. 去死吧!USB转串口!!!
  6. 7个C语言小程序让你快速入门程序世界
  7. java基础之线程(1)
  8. linux运行不了.sh文件,linux下不能执行/bin/sh脚本的原因:command not found
  9. linux环境apache,php的安装目录
  10. 计算机六级好考吗,计算机六级考什么?
  11. 18岁智商低的表现_孩子反应慢并不是智商低,三个原因很关键,第一个跟父母有关...
  12. 智慧城轨信息技术架构及信息安全规范_会员信息 | 中国铁设:在深圳,我们打造智慧地铁的“最强大脑”...
  13. HTML+CSS+JavaScript网页特效源代码(复制代码保存即可使用)
  14. PLC控制系统设计的基本内容
  15. python根据excel数据生成柱状图并导出成图片格式
  16. (翻译)社会认同模式(Social proof)
  17. 电影天堂二级页面抓取案例
  18. 服务器隔离虚拟机,筛选Hyper-V提供的虚拟机隔离选项
  19. Linux下Vim的常用命令操作大全
  20. 设计一可控同步四进制可逆计数器, 其由输入X1,X2控制, 用D触发器和74153及必要的门电路实现

热门文章

  1. boost::type_traits模块用法的一些示例
  2. boost::math模块计算二项式随机变量的概率和分位数的简单示例
  3. boost::boost::maximum_adjacency_search用法的测试程序
  4. boost::graph模块实现内部pmap捆绑的测试程序
  5. ITK:将BinaryMorphologicalClosingFilter应用于给定LabelMap的一个LabelObject
  6. OpenCV cv::split用法的实例(附完整代码)
  7. QDoc特殊内容special content
  8. php 内容转换dom,php – 防止DOMDocument :: loadHTML()转换实体
  9. ajax+php跨域请求数据库,基于jQuery的ajax跨域请求,PHP作为服务器端代码
  10. java快递下单模块,Java开发快递物流项目(7)