[LibTorch] C++ 调用 PyTorch 导出的模型
参考文章
- C++部署pytorch模型
- 利用LibTorch部署PyTorch模型
- 官方文档
问题
pytorch 的神经网络模型有很多,但 libtorch 就特别少。现在面临的问题是要在 C++ 环境下应用神经网络模型,肯定不能直接使用 pytorch 模型。解决办法有两个:
方法一是用 TorchScript 工具导出模型 poolnet.pt,模型中包含网络结构和参数权重,因此可以直接在 C++ 里面生成神经网络。
方法二是用 C++ 复现网络结构,封装为为类对象,再从 poolnet.pt 中导入参数权重。
对于神经网络模型 PoolNet ,将其应用到 C++ 环境下进行视频处理,下面这是前 10 帧画面处理时间。明显看出,方法一前两次运行时间很长,从第三帧开始,两种方法的处理时间几乎相同。但是,方法一相当简单,导出模型即可,方法二需要复现网络结构,工程量巨大。下面重点介绍方法一。
TorchScript 工具介绍
必定要看 官方文档。上面介绍了 trace 和 script 的区别。
PyTorch 导出模型
resnet50
编辑 export.py 文件,以 pytorch 提供的 resnet50 为例,分别使用 trace 和 script 导出模型。trace 需要提供一个输入样例,script 则不需要。但是复杂的模型使用 script 一般会失败,但 trace 可以。trace 和 script 导出的模型几乎没有区别,缺点是前两次处理时间都格外久。
import torch
from torchvision.models import resnet50net = resnet50(pretrained=True)
net = net.cuda()
net.eval()
for key, value in net.named_parameters(): print(key)# trace
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
traced_module = torch.jit.trace(net, x)
traced_module.save("resnet50_trace.pt")# script
scripted_module = torch.jit.script(net)
scripted_module.save("resnet50_script.pt")
在 python3+pytorch 的虚拟环境下执行
python export.py
net
上面的例子是 pytorch 提供的 resnet50,如果是自己写的模型,可以按照下面的方式来。其中,net.pth 是训练后保存的参数,net.pt 则是期望导出的模型,使用 trace 方法。
import torch
import torchvision# 初始化神经网络
net = Net()
net.load_state_dict(torch.load("net.pth"))
net.cuda()
net.eval()# 导出模型
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
m = torch.jit.trace(net, x)
m.save("net.pt")
其中
m.save("net.pt")
也可写为
torch.jit.save(m, "net.pt")
C++ 中调用模型
C++ 使用 libtorch 时,一般使用 CMake 进行管理(参考 Pytorch 官网教程)。下面是在 C++ 环境中调用模型的方法。
#include <torch/torch.h>
#include <torch/script.h>torch::Device device(torch::kCUDA);
// image.rows, image.cols 高在前,宽在后
torch::Tensor img_tensor = torch::from_blob(img.data, {1, image.rows, image.cols, 3}, torch::kByte).to(device);
img_tensor = img_tensor.permute({0, 3, 1, 2});
img_tensor = img_tensor.toType(torch::kFloat);
img_tensor = img_tensor.div(255.0);
torch::jit::script::Module net = torch::jit::load("../models/net.pt");
// 打印模型中的参数
for (const auto& pair : net.named_parameters()) {std::cout << pair.name << " " << pair.value.requires_grad() << std::endl;
}
net.to(device)
torch::NoGradGuard no_grad;
torch::Tensor output = net.forward({img_tensor}).toTensor();
[LibTorch] C++ 调用 PyTorch 导出的模型相关推荐
- PyTorch导出JIT模型并用C++ API libtorch调用
PyTorch导出JIT模型并用C++ API libtorch调用 本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型 ...
- 搭建C++开发图像算法的环境——利用C++调用Pytorch训练后模型
本文主要介绍如何搭建C++开发图像算法的环境,使用到CMake + libtorch + OpenCV + ITK等.旨在构建一个可融合深度学习框架,可开发图像处理算法且易于跨平台编译的环境. 准备条 ...
- C++调用PyTorch模型:LibTorch
转载的文章,挺不错,学习一下! LibTorch学习笔记(一) 前天由于某些原因需要利用C++调用PyTorch,于是接触到了LibTorch,配了两天最终有了一定的效果,于是记录一下. 环境 PyT ...
- Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它
Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它 本教程我们将描述如何将PyTorch中定义的模型转换为ONNX格式,然后使用ONNX运行时运行它. ONNX运行时是一个 ...
- win10 c++调用pytorch模型
1.pytorch模型生成pt模型 """Export a pth model to TorchScript formatsimport time import torc ...
- 【pytorch】将模型部署至生产环境:借助TorchScript跟踪法及注释法生成可供C++调用的模块
(一)思路简介 1.pyTorch会提供TorchScript,它可以生成最有效的C++可读的模型格式. 2.TorchScript的记录法支持Python控制流的子集,由TorchScript创建的 ...
- pyTorch入门(四)——导出Minist模型,C++ OpenCV DNN进行识别
学更好的别人, 做更好的自己. --<微卡智享> 本文长度为2548字,预计阅读8分钟 前言 前三章介绍了pyTorch训练的相关,我们也保存模型成功了,今天这篇就是使用C++ OpenC ...
- Pytorch学习 - 保存模型和重新加载
Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...
- 【pytorch速成】Pytorch图像分类从模型自定义到测试
文章首发于微信公众号<与有三学AI> [pytorch速成]Pytorch图像分类从模型自定义到测试 前面已跟大家介绍了Caffe和TensorFlow,链接如下. [caffe速成]ca ...
最新文章
- SpringBoot02_构建rest工程完成第一个controller类
- videoview 播放视频
- HDU 6136 Death Podracing (堆)
- 超图理论的一点理解(一)
- (转)windows下安装python及第三方库numpy、scipy、matplotlib终极版
- Azure认知服务之表单识别器
- sqlmap使用_sqlmap于sql labs下使用
- php 魔术方法使用说明详细
- 5分钟学会如何玩转云数据库组件(迁移,审计,订阅)
- 2016-2017-2 20155309 南皓芯《java程序设计》第八周学习总结
- Python实现基于TF-IDF抽取文本数据关键词
- SOUI知识点小结2
- matlab 平滑曲线连接_曲线拟合的一些想法
- 开源项目工时系统_GitHub - fjp203/timemaker: 定额工时管理系统
- opc服务器网站,OPC 服务器
- 虚拟机访问本地mysql_本地访问虚拟机oracle数据库的尝试
- 【upc】Water Testing 皮克定理+多边形面积公式
- CentOS7使用mount命令来挂载CDROM
- android gms认证之run host test,Android GMS认证项总结
- 第三十四篇 源极跟随器
热门文章
- 海岸鸿蒙甲醇中8种苯系物,甲醇中8种苯系物混合溶液标准物质-8种VOC
- w25q64 linux,W25Q64Flash芯片STM32操作
- linux命令cd回退_Linux命令一
- vscode 翻译_前端新手 VSCode 入门指南
- 精通开关电源设计第三版pdf_看漫画,学电源(一)丨线性电源与开关电源的构造...
- LeetCode Week 3:第 21 ~ 30 题
- I - Arbitrage(判断是否有无正环 II)
- 算法导论 思考题6-2
- string的operate+=
- 目标检测——如何让模型过拟合