前面已经记录过了,流程就是这么个流程:

  • 配置libtorch --->python训练的模型怎么在C++使用?_己亥谷雨-CSDN博客
  • pytorch模型转化
  • 编写C++调用程序

这里就来记录一下模型转化和C++调用程序。

关于模型的训练就不多说了,不会训练模型还扯什么调用。

好在之前记录过一点:


========================分割线============================

好。我现在有了一个训练好的模型了,这个模型就是随意跑了一两个epoch,不考虑准确率,这里只考虑通整个流程 。

最初的模型就是 cnn.pth,我直接读这个模型好像不大行,于是还是将其按照文档那样,转成cnn.pt。怎么转的呢?教程就是这样的

import torch
import torchvision# An instance of your model.
model = torchvision.models.resnet18()# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("traced_resnet_model.pt")

他这个例子是加载的已有的模型库,然后随意给了一个输入,让这个输入进到模型里面打个样,观察一下地形(我的理解),然后保存为pt文件。

那我测的话是自己的模型,怎么搞呢。和这个例子差不多,只是模型不一样,获得model后,就一样了

class VGG16(nn.Module):def __init__(self, num_classes=10):super(VGG16, self).__init__()self.features = nn.Sequential(# 1nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64),nn.ReLU(True),# 2nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64),nn.ReLU(True),nn.MaxPool2d(kernel_size=2, stride=2),# 3nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128),nn.ReLU(True),# 4nn.Conv2d(128, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128),nn.ReLU(True),nn.MaxPool2d(kernel_size=2, stride=2),# 5nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.BatchNorm2d(256),nn.ReLU(True),# 6nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.BatchNorm2d(256),nn.ReLU(True),# 7nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.BatchNorm2d(256),nn.ReLU(True),nn.MaxPool2d(kernel_size=2, stride=2),# 8nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.BatchNorm2d(512),nn.ReLU(True),# 9nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.BatchNorm2d(512),nn.ReLU(True),# 10nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.BatchNorm2d(512),nn.ReLU(True),nn.MaxPool2d(kernel_size=2, stride=2),# 11nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.BatchNorm2d(512),nn.ReLU(True),# 12nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.BatchNorm2d(512),nn.ReLU(True),# 13nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.BatchNorm2d(512),nn.ReLU(True),nn.MaxPool2d(kernel_size=2, stride=2),nn.AvgPool2d(kernel_size=1, stride=1),)self.classifier = nn.Sequential(# 14nn.Linear(512, 4096),nn.ReLU(True),nn.Dropout(),# 15nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(),# 16nn.Linear(4096, num_classes),)# self.classifier = nn.Linear(512, 10)def forward(self, x):out = self.features(x)#        print(out.shape)out = out.view(out.size(0), -1)#        print(out.shape)out = self.classifier(out)#        print(out.shape)return out'''创建model实例对象,并检测是否支持使用GPU'''
model = VGG16()use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速
if use_gpu:model = model.cuda()model.eval()'''测试'''
classes=('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 转换模型
model.load_state_dict(torch.load("./cnn.pth"))
torch.no_grad()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 32, 32)if use_gpu:example = Variable(example).cuda()# label = Variable(label, volatile=True).cuda()
else:example = Variable(example)# label = Variable(label)# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("cnn.pt")

上面就是我测试的,先把我之前训练好的pth模型读进去,这样就有了model,之后就和例子一样了。不过这里加了个GPU的检测,我是用的GPU。好,到这里就将pth文件转换为了pt文件。

【当然,我相信这肯定是个笨方法,一定有更简单的,比如我训练好之后直接就保存成pt或是其他方法,暂时先不管】

下一步就是编写C++调用代码了

#include "torch/script.h" // One-stop header.
#include <iostream>
#include <opencv2\opencv.hpp>
#include <opencv2\imgproc\types_c.h>
using namespace cv;
using namespace std;int main(int argc, const char* argv[])
{/*******load*********/if (argc != 2) {std::cerr << "usage: example-app <path-to-exported-script-module>\n";return -1;}torch::DeviceType device_type;device_type = torch::kCPU;//这里我没有检测了,直接用CPU做推理torch::Device device(device_type);torch::jit::script::Module module;//std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1], device);try {// Deserialize the scriptmodule from a file using torch::jit::load().module = torch::jit::load(argv[1], device);//这里一定要加device,不然加载失败//module = torch::jit::load("cnn.pt", device);}catch (const c10::Error& e) {std::cerr << "error loading the model\n";return -1;}vector<string> out_list = { "plane", "ca", "bird", "cat","deer", "dog", "frog", "horse", "ship", "truck" };auto image = imread("dog3.jpg");if (!image.data){cout << "image imread failed" << endl;}cvtColor(image, image, CV_BGR2RGB);Mat img_transfomed;resize(image, img_transfomed, Size(32, 32));/*cout << img_transfomed.data;*///img_transfomed.convertTo(img_transfomed, CV_16FC3, 1.0f / 255.0f);  //Mat to tensor,   torch::Tensor tensor_image = torch::from_blob(img_transfomed.data, { img_transfomed.rows, img_transfomed.cols, img_transfomed.channels() }, torch::kByte);tensor_image = tensor_image.permute({ 2, 0, 1 });tensor_image = tensor_image.toType(torch::kFloat);tensor_image = tensor_image.div(255);tensor_image = tensor_image.unsqueeze(0);//增加一维,拓展维度,在最前面std::vector<torch::jit::IValue> inputs;inputs.push_back(tensor_image);torch::Tensor output = module.forward(inputs).toTensor();torch::Tensor output_max = output.argmax(1);int a = output_max.item().toInt();cout << "分类预测的结果为:"<< out_list[a] << endl;return 0;//下面是输出为图像的例子tensor to Mat  //output_max = output_max.squeeze();//output_max = output_max.mul(255).to(torch::kU8);//output_max = output_max.to(torch::kCPU);//Mat result_img(Size(480, 320), CV_8UC1);//memcpy((void*)result_img.data, output_max.data_ptr(), sizeof(torch::kU8) * output_max.numel());//imshow("result", result_img);//imwrite("result.bmp", result_img);//system("pause");
}

这里把模型名字写到命令行参数里

然后就可以出结果了(这个“然后”。。。其实我踩了挺多坑,后面慢慢再记录吧)

还有一个事情要做,这里的结果和python下跑的结果是一样的吗?

对比一下python下相同图片的结果

C++下和python下的结果一致,但都是错的,因为我用的测试图是小狗的。。。



没测试之前总感觉会比较复杂,搞不定。但是很多事情都是自己唬自己,试一下呢?说不定会发现很简单,或者不过如此之类的。(当然不是说这个简单了,我还是感觉挺复杂的,只是这个小测试还是比较简单的)

最后贴个搞笑的图片,哈哈哈哈哈哈哈哈笑死了,别人的标题还真学不来

C++调pytorch模型的全过程记录相关推荐

  1. 基于C++的PyTorch模型部署

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 引言 PyTorch作为一款端到端的深度学习框架,在1.0版本之后 ...

  2. 学习记录——Pytorch模型移植Android小例子

    提示:注意文章时效性,2022.04.02. 目录 前言 零.使用的环境 一.模型准备 1.导出模型 2.错误记录 2.1要载入完整模型(网络结构+权重参数) 2.2导出的模型文件格式 二.Andro ...

  3. Python Apex YOLO V7 main 目标检测 全过程记录

    博文目录 文章目录 环境准备 YOLO V7 main 分支 TensorRT 环境 工程源码 假人权重文件 toolkit.py 测试.实时检测.py grab.for.apex.py label. ...

  4. TensorRT和PyTorch模型的故事

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨伯恩legacy 来源丨https://zhuanlan.zh ...

  5. 在C++平台上部署PyTorch模型流程+踩坑实录

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 本文主要讲解如何将pytorch的模型部署到c++平台上的模 ...

  6. 如何使用TensorRT对训练好的PyTorch模型进行加速?

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨伯恩legacy@知乎 来源丨https://zhuanlan.zhihu.com/p/8831 ...

  7. 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用

    首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...

  8. Intel发布神经网络压缩库Distiller:快速利用前沿算法压缩PyTorch模型

    Intel发布神经网络压缩库Distiller:快速利用前沿算法压缩PyTorch模型 原文:https://blog.csdn.net/u011808673/article/details/8079 ...

  9. 经验 | 在C++平台上部署PyTorch模型流程+踩坑实录

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨火星少女@知乎 来源丨https://zhuanlan ...

最新文章

  1. mysql show schema_快速入门 · xiaoboluo768/mysql-system-schema Wiki · GitHub
  2. unslider制作轮播图
  3. [PAT乙级]1033 旧键盘打字(getline()读入)
  4. 暴力修改SElinux权限
  5. Spring Cloud Alibaba —— Sentinel 详细使用
  6. MyBatis——13用mybatis实现银行转账
  7. realtek是什么意思_Realtek高清晰音频管理器 全解析
  8. 提高电脑开机速度的方法
  9. 一个架构师谈什么是架构以及怎么成为一个架构师
  10. 地理坐标系:WGS84和BD09互转
  11. M3U8视频解密下载
  12. ue4 中动画控制,利用conduit节点
  13. 欧姆龙485通讯示例程序_PLC程序结构设计和技巧
  14. SD卡、记忆棒等内存卡的数据恢复方法
  15. html里怎么旋转视频文件,如何旋转视频文件(方法三)
  16. /给你一个由 n 个整数组成的数组 nums ,和一个目标值 target 。请你找出并返回满足下述全部条件且不重复的四元组 [nums[a], nums[b], nums[c], nums[
  17. Kubernetes安装EFK日志收集
  18. (附源码)计算机毕业设计SSM快递代收系统
  19. 币圈的8大女神都是谁?-千氪
  20. 基于51单片机的万年历可显示农历带闹钟整点报送功能proteus仿真原理图PCB

热门文章

  1. 贝壳找房技术总监肖鹏:高速成长下的技术团队怎么带?
  2. 爱奇艺数据中台负责人马金韬:数据中台建设与应用
  3. TortoiseGit上传代码报错error:1407742E
  4. Java中主线程如何捕获子线程抛出 ...
  5. 深入掌握JMS(四):实战Queue
  6. JSF 源代码赏析之FacesServlet
  7. 平均负载及CPU上下文切换
  8. 简述 JavaScript 作用域与词法分析
  9. iOS 瀑布流布局实现详解
  10. JavaScript基础知识必知!