参考文章

  • C++部署pytorch模型
  • 利用LibTorch部署PyTorch模型
  • 官方文档

问题

pytorch 的神经网络模型有很多,但 libtorch 就特别少。现在面临的问题是要在 C++ 环境下应用神经网络模型,肯定不能直接使用 pytorch 模型。解决办法有两个:

  • 方法一是用 TorchScript 工具导出模型 poolnet.pt,模型中包含网络结构参数权重,因此可以直接在 C++ 里面生成神经网络。

  • 方法二是用 C++ 复现网络结构,封装为为类对象,再从 poolnet.pt 中导入参数权重

对于神经网络模型 PoolNet ,将其应用到 C++ 环境下进行视频处理,下面这是前 10 帧画面处理时间。明显看出,方法一前两次运行时间很长,从第三帧开始,两种方法的处理时间几乎相同。但是,方法一相当简单,导出模型即可,方法二需要复现网络结构,工程量巨大。下面重点介绍方法一

TorchScript 工具介绍

必定要看 官方文档。上面介绍了 trace 和 script 的区别。

PyTorch 导出模型

resnet50

编辑 export.py 文件,以 pytorch 提供的 resnet50 为例,分别使用 trace 和 script 导出模型。trace 需要提供一个输入样例,script 则不需要。但是复杂的模型使用 script 一般会失败,但 trace 可以。trace 和 script 导出的模型几乎没有区别,缺点是前两次处理时间都格外久

import torch
from torchvision.models import resnet50net = resnet50(pretrained=True)
net = net.cuda()
net.eval()
for key, value in net.named_parameters(): print(key)# trace
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
traced_module = torch.jit.trace(net, x)
traced_module.save("resnet50_trace.pt")# script
scripted_module = torch.jit.script(net)
scripted_module.save("resnet50_script.pt")

在 python3+pytorch 的虚拟环境下执行

python export.py

net

上面的例子是 pytorch 提供的 resnet50,如果是自己写的模型,可以按照下面的方式来。其中,net.pth 是训练后保存的参数,net.pt 则是期望导出的模型,使用 trace 方法。

import torch
import torchvision# 初始化神经网络
net = Net()
net.load_state_dict(torch.load("net.pth"))
net.cuda()
net.eval()# 导出模型
x = torch.ones(1, 3, 224, 224)
x = x.cuda()
m = torch.jit.trace(net, x)
m.save("net.pt")

其中

m.save("net.pt")

也可写为

torch.jit.save(m, "net.pt")

C++ 中调用模型

C++ 使用 libtorch 时,一般使用 CMake 进行管理(参考 Pytorch 官网教程)。下面是在 C++ 环境中调用模型的方法。

#include <torch/torch.h>
#include <torch/script.h>torch::Device device(torch::kCUDA);
// image.rows, image.cols 高在前,宽在后
torch::Tensor img_tensor = torch::from_blob(img.data, {1, image.rows, image.cols, 3}, torch::kByte).to(device);
img_tensor = img_tensor.permute({0, 3, 1, 2});
img_tensor = img_tensor.toType(torch::kFloat);
img_tensor = img_tensor.div(255.0);
torch::jit::script::Module net = torch::jit::load("../models/net.pt");
// 打印模型中的参数
for (const auto& pair : net.named_parameters()) {std::cout << pair.name << " " << pair.value.requires_grad() << std::endl;
}
net.to(device)
torch::NoGradGuard no_grad;
torch::Tensor output = net.forward({img_tensor}).toTensor();

[LibTorch] C++ 调用 PyTorch 导出的模型相关推荐

  1. PyTorch导出JIT模型并用C++ API libtorch调用

    PyTorch导出JIT模型并用C++ API libtorch调用 本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型 ...

  2. 搭建C++开发图像算法的环境——利用C++调用Pytorch训练后模型

    本文主要介绍如何搭建C++开发图像算法的环境,使用到CMake + libtorch + OpenCV + ITK等.旨在构建一个可融合深度学习框架,可开发图像处理算法且易于跨平台编译的环境. 准备条 ...

  3. C++调用PyTorch模型:LibTorch

    转载的文章,挺不错,学习一下! LibTorch学习笔记(一) 前天由于某些原因需要利用C++调用PyTorch,于是接触到了LibTorch,配了两天最终有了一定的效果,于是记录一下. 环境 PyT ...

  4. Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它

    Python: 从PYTORCH导出模型到ONNX,并使用ONNX运行时运行它 本教程我们将描述如何将PyTorch中定义的模型转换为ONNX格式,然后使用ONNX运行时运行它. ONNX运行时是一个 ...

  5. win10 c++调用pytorch模型

    1.pytorch模型生成pt模型 """Export a pth model to TorchScript formatsimport time import torc ...

  6. 【pytorch】将模型部署至生产环境:借助TorchScript跟踪法及注释法生成可供C++调用的模块

    (一)思路简介 1.pyTorch会提供TorchScript,它可以生成最有效的C++可读的模型格式. 2.TorchScript的记录法支持Python控制流的子集,由TorchScript创建的 ...

  7. pyTorch入门(四)——导出Minist模型,C++ OpenCV DNN进行识别

    学更好的别人, 做更好的自己. --<微卡智享> 本文长度为2548字,预计阅读8分钟 前言 前三章介绍了pyTorch训练的相关,我们也保存模型成功了,今天这篇就是使用C++ OpenC ...

  8. Pytorch学习 - 保存模型和重新加载

    Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...

  9. 【pytorch速成】Pytorch图像分类从模型自定义到测试

    文章首发于微信公众号<与有三学AI> [pytorch速成]Pytorch图像分类从模型自定义到测试 前面已跟大家介绍了Caffe和TensorFlow,链接如下. [caffe速成]ca ...

最新文章

  1. SpringBoot02_构建rest工程完成第一个controller类
  2. videoview 播放视频
  3. HDU 6136 Death Podracing (堆)
  4. 超图理论的一点理解(一)
  5. (转)windows下安装python及第三方库numpy、scipy、matplotlib终极版
  6. Azure认知服务之表单识别器
  7. sqlmap使用_sqlmap于sql labs下使用
  8. php 魔术方法使用说明详细
  9. 5分钟学会如何玩转云数据库组件(迁移,审计,订阅)
  10. 2016-2017-2 20155309 南皓芯《java程序设计》第八周学习总结
  11. Python实现基于TF-IDF抽取文本数据关键词
  12. SOUI知识点小结2
  13. matlab 平滑曲线连接_曲线拟合的一些想法
  14. 开源项目工时系统_GitHub - fjp203/timemaker: 定额工时管理系统
  15. opc服务器网站,OPC 服务器
  16. 虚拟机访问本地mysql_本地访问虚拟机oracle数据库的尝试
  17. 【upc】Water Testing 皮克定理+多边形面积公式
  18. CentOS7使用mount命令来挂载CDROM
  19. android gms认证之run host test,Android GMS认证项总结
  20. 第三十四篇 源极跟随器

热门文章

  1. 海岸鸿蒙甲醇中8种苯系物,甲醇中8种苯系物混合溶液标准物质-8种VOC
  2. w25q64 linux,W25Q64Flash芯片STM32操作
  3. linux命令cd回退_Linux命令一
  4. vscode 翻译_前端新手 VSCode 入门指南
  5. 精通开关电源设计第三版pdf_看漫画,学电源(一)丨线性电源与开关电源的构造...
  6. LeetCode Week 3:第 21 ~ 30 题
  7. I - Arbitrage(判断是否有无正环 II)
  8. 算法导论 思考题6-2
  9. string的operate+=
  10. 目标检测——如何让模型过拟合