Pytorch1.1版本已经提供了相对稳定的c++接口,网上也有了众多的资料供大家参考,进行c++的接口的初步尝试。

可以按照对应的选项下载,下面我们要说的是:

如何利用已经编译好的官方libtorch库和其他的opencv库等联合编写应用?

其实很简单,大概的步骤有三步:

第一步:在python环境下将模型导出为jit的模型

第二步:编写对应的c++ inference 程序。

第三步:直接在VS上(已经成功实验VS2015,高版本的应该也可以)配置相应的libtorch环境,主要是:

dll路径:

PATH=H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\lib%3bD:\opencv\build\x64\vc14\bin%3b$(PATH)  相应地去修改即可,不需要在PC的path环境下加入libtorch的路径,而是在这里加更加简单。

include路径:

H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\include\torch\csrc\api\include;H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\include;D:\opencv\build\include\opencv2;D:\opencv\build\include\opencv;D:\opencv\build\include;%(AdditionalIncludeDirectories)

主要是加粗线那两个。

注意一定要去掉SDL的检查项,否则会出现错误警告。

lib路径:

H:\deeplearning_framework\Pytorch\libtorch\libtorch-win-shared-with-deps-debug-latest_cpu\libtorch\lib;D:\opencv\build\x64\vc14\lib;%(AdditionalLibraryDirectories)

详细的工程见:https://download.csdn.net/download/xiamentingtao/11486608

这里我们主要改编自:《Win10+VS2017+PyTorch(libtorch) C++ 基本应用》

主要代码参考: https://github.com/zhpmatrix/load-pytorch-model-with-c-

一些 常见的问题:

1. opencv的mat读入libtorch

根据我的实践,这里的最佳写法是:

src = imread(s, cv::IMREAD_COLOR);  //读图// 图像预处理 注意需要和python训练时的预处理一致
int org_w = src.cols;
int org_h = src.rows;torch::Tensor img_tensor = torch::from_blob(src.data, { org_h, org_w,3 }, torch::kByte); //将cv::Mat转成tensor,大小为448,448,3
img_tensor = img_tensor.permute({ 2, 0, 1 });  //调换顺序变为torch输入的格式 3,448,448
img_tensor = img_tensor.toType(torch::kFloat32).div_(255);

注意要先将uint8的图像先读入,再转换成float型。

2. Tensor 转换成cv::Mat

cv::Mat input(img_tensor.size(1), img_tensor.size(2), CV_32FC1, img_tensor.data<float>());

注意这里一定是CV_32FC1而不是CV_32FC3

另外的方式见:https://discuss.pytorch.org/t/convert-torch-tensor-to-cv-mat/42751/2

torch::Tensor out_tensor = module->forward(inputs).toTensor();
assert(out_tensor.device().type() == torch::kCUDA);
out_tensor=out_tensor.squeeze().detach().permute({1,2,0});
out_tensor=out_tensor.mul(255).clamp(0,255).to(torch::kU8);
out_tensor=out_tensor.to(torch::kCPU);
cv::Mat resultImg(512, 512,CV_8UC3);
std::memcpy((void*)resultImg.data,out_tensor.data_ptr(),sizeof(torch::kU8)*out_tensor.numel());

3. model的输出处理

如果只有一个返回值,可以直接转tensor:auto outputs = module->forward(inputs).toTensor();如果有多个返回值,需要先转tuple:auto outputs = module->forward(inputs).toTuple();
torch::Tensor out1 = outputs->elements()[0].toTensor();
torch::Tensor out2 = outputs->elements()[1].toTensor();

4.Tracing fails because of “parameter sharing”?

看这个案例:https://discuss.pytorch.org/t/help-tracing-fails-because-of-parameter-sharing/40324

其中的部分代码如上,问题就出现在这些画框的地方,主要是这里初始化重复使用了相同的模块进行赋值,例如self.encoder与self.conv1。

解决的办法就是在构造slef.conv1时,对self.encoder[0]加入deepcopy修饰。

即:

from copy import deepcopy
self.conv1 = nn.Sequential(deepcopy(self.encoder[0]),deepcopy(self.relu),deepcopy(self.encoder[2]),deepcopy(self.relu))

参考:https://github.com/pytorch/pytorch/issues/8392#issuecomment-431863763

5. 关于python导出模型的问题

如果训练的pytorch模型保存在cpu上,想在测试时使用gpu模式,则我们需要设置python端保存模型在gpu上,然后才能c++上使用gpu测试。

主要的方法就是:

    checkpoint = torch.load(model_path, map_location="cuda:0")  #very important# create modelmodel = TheModelClass(*args, **kwargs)model.load_state_dict(checkpoint)model.to(device)model.eval()x = torch.rand(1, 3, 448, 448)x = x.to(device)  # very importanttraced_script_module = torch.jit.trace(model.model, x)traced_script_module.save("**.pt")

然后才能在c++上使用gpu模式,方法为:

    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);module->to(at::kCUDA);assert(module != nullptr);std::cout << "ok\n";// 建立一个输入,维度为(1,3,224,224),并移动至cudastd::vector<torch::jit::IValue> inputs;inputs.push_back(torch::ones({1, 3, 224, 224}).to(at::kCUDA));// Execute the model and turn its output into a tensor.at::Tensor output = module->forward(inputs).toTensor();

参考:

pytorch跨设备保存和加载模型(变量类型(cpu/gpu)不匹配原因之一)

https://pytorch.org/tutorials/beginner/saving_loading_models.html

https://blog.csdn.net/IAMoldpan/article/details/85057238

参考文献:

1.利用Pytorch的C++前端(libtorch)读取预训练权重并进行预测

2.Pytorch的C++端(libtorch)在Windows中的使用

3. https://pytorch.org/tutorials/advanced/cpp_frontend.html

4. https://zhpmatrix.github.io/2019/03/01/c++-with-pytorch/

5. Windows使用C++调用Pytorch1.0模型

6. 用cmake构建基于qt5,opencv,libtorch项目

7. c++调用pytorch模型并使用GPU进行预测 (较好的例子)

8. Ptorch 与libTorch 使用过程中问题记录

9. c++ load pytorch 的数据转换

Pytorch的C++接口实践相关推荐

  1. 好书分享——《深度学习框架PyTorch:入门与实践》

    内容简介 : <深度学习框架PyTorch:入门与实践>从多维数组Tensor开始,循序渐进地带领读者了解PyTorch各方面的基础知识.结合基础知识和前沿研究,带领读者从零开始完成几个经 ...

  2. 【视频课】言有三每天答疑,38课深度学习+超60小时分类检测分割数据算法+超15个Pytorch框架使用与实践案例助你攻略CV...

    计算机视觉中大大小小可以包括至少30个以上的方向,在基于深度学习的计算机视觉研究方向中,图像分类,图像分割,目标检测无疑是最基础最底层的任务,掌握好之后可以很快的迁移到其他方向,比如目标识别,目标跟踪 ...

  3. 学习笔记:深度学习(8)——基于PyTorch的BERT应用实践

    学习时间:2022.04.26~2022.04.30 文章目录 7. 基于PyTorch的BERT应用实践 7.1 工具选取 7.2 文本预处理 7.3 使用BERT模型 7.3.1 数据输入及应用预 ...

  4. gRPC amp; Protocol Buffer 构建高性能接口实践

    介绍如何使用 gRPC 和 ProtoBuf,快速了解 gRPC 可以参考这篇文章第一段:gRPC quick Start. 接口开发是软件开发占据举足轻重的地位,是现代软件开发之基石.体现在无论是前 ...

  5. 数据挖据---机器学习平台之H2O架构/接口/实践

    上一章介绍了H2O的使用,这次来学习学习H2O架构接口和实践. 1,H2O架构 关于H2O架构,很多资料也有说明,这里我们一起来看看官网上的介绍. 最上面的是客户层,即接口交互层,H2O支持JavaS ...

  6. 【PyTorch】深度学习实践之CNN高级篇——实现复杂网络

    本文目录 1. 串行的网络结构 2. GoogLeNet 2.1 结构分析 2.2 代码实现 2.3 结果 3. ResNet 3.1 网络分析 3.2 代码实现 3.3 结果 课后练习1:阅读并实现 ...

  7. java对象序列化java.io.Serializable 接口实践

    java.io.Serializable 接口没有任何方法和字段,仅仅表示实现它的类的对象可以被序列化.实现了这个接口的所有类及其子类都可以对象序列化. 序列化前,虚拟机要清楚每个对象的结构,所以序列 ...

  8. WebApiClient百度地图服务接口实践

    1. 文章目的 随着WebApiClient的不断完善,越来越多开发者选择WebApiClient替换原生的HttpClient,然而在应用到实际项目中多多少少会遇到一些项目结合上的疑问和困难,本文将 ...

  9. [PHP]微信红包接口实践说明 CA证书出错 签名错误

    1. 在微信支付的商户平台,在[API安全]中下载API证书,将下载的证书(apiclient _cert.pem/apiclient_key.pem/roota.pem)放在服务器上,确定并记录存放 ...

最新文章

  1. numpy.random.randn()与numpy.random.rand()的区别(转)
  2. java 读取webservice_java 调用webService的各种方法
  3. sql where 1=1和 0=1 的作用
  4. 柴静《认识的人 了解的事》
  5. [禅悟人生]心平气和, 慢慢修行
  6. 软件开源是如何赚钱?
  7. 工具类与工具函数 —— fatal.h
  8. 【BZOJ】3238: [Ahoi2013]差异
  9. Android开发-无法新建Activity及新建后编译错误
  10. EasyUI:Layout 布局
  11. java课堂点名和提问程序_Java程序设计作业.md
  12. java虚无世界_我的世界虚无世界2.5
  13. 为交付Semi卡车做准备 特斯拉招募技术服务人员
  14. js将秒转为天时分秒格式
  15. SpringBoot之整合Redis分析和实现-基于Spring Boot2.0.2版本
  16. Unreal Engine 4 初学者教程:开始
  17. codeforces 76A Gift 最小生成树
  18. COOX培训材料 — PMT(1.Phase)
  19. 震惊!TYPE-C 转OTG(USB2.0传输数据)+PD充电协议芯片
  20. 合工大计算机有哪些好的老师,合肥工业大学计算机与信息学院导师教师师资介绍简介-☆郑利平...

热门文章

  1. FormsAuthentication 和 Session 超时时间不一的问题
  2. React基础学习(第三天)
  3. 我对创业和管理的一些看法
  4. (五)EasyUI使用——datagrid数据表格
  5. IOS开发基础知识--碎片13
  6. 图像连通域标记算法研究
  7. 【java】错误 找不到或无法加载主类
  8. throw()使用小结
  9. codeforces 1045 D. Interstellar battle
  10. 机器学习笔记1(K-近邻算法)