0. 前言

最近尝试着去在SLAM当中使用深度学习,而目前的SLAM基本上是基于C++的,而现有的Pytorch、Tensorflow这类框架均是基于python的。所以如何将Python这类脚本文件来在C++这类可执行文件中运行,这是非常有必要去研究的,而网络上虽然存在有例子,但是很多都比较杂乱,所以本篇文章将网络上常用的方法进行整理,以供后面初学者有迹可循

1. 模型认识

我们知道,目前基于C++存在两种方式,一种是通过Opencv加载训练好的模型和网络,而另一种则是通过TensorRT来进行C++的深度学习开发,TensorRT是Nvidia官方给的C++推理加速工具,如同OpenVINO之于Intel。支持诸多的AI框架,如Tensorflow,Pytorch,Caffe,MXNet等。


对于这类C++程序而言,其最重要是更加通用,同时支持模型自身运算的加速。

2. Opencv

实验流程为:Pytorch -> Onnx -> Opencv。即首先将Pytorch模型转换为Onnx模型,然后通过Opencv解析Onnx模型。

首先,参考pytorch官方文档中训练一个分类器的代码,训练一个简单的图像分类器。代码如下:

import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.onnx
import torchvision
import torchvision.transforms as transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=0)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=0)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# functions to show an imagedef imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()print(images.shape)# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 3)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 12, 3)self.conv3 = nn.Conv2d(12, 32, 3)self.fc1 = nn.Linear(32 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = F.relu(self.conv3(x))x = x.view(-1, 32 * 4 * 4)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()
net.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)for epoch in range(100):  # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data[0].to(device), data[1].to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 2000 == 1999:    # print every 2000 mini-batchesprint(outputs)print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')

上述代码相对于官方文档的代码,仅仅是增加了卷积层和利用GPU进行训练,且输出结果未经处理,只是简单输出各个类别的概率值。
训练完网络之后,将网络保存,代码如下:

# 保存网络结构和参数# 方法1:保存网络结构和参数
PATH = './cifar_net.pth'
torch.save(net, PATH)# 方法2:保存网络参数
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)# 方法3:导出网络到ONNX
dummy_input = torch.randn(1, 3, 32, 32).to(device)
torch.onnx.export(net, dummy_input, "torch.onnx")# 方法4:保存网络位TORCHSCRIPT
dummy_input = torch.randn(1, 3, 32, 32).to(device)
traced_cell = torch.jit.trace(net, dummy_input)
traced_cell.save("tests.pth")

并通过方法3将保存好的ONNX模型输入到opencv中,并通过opencv提供的Net cv::dnn::readNetFromONNX ( const String & onnxFile )函数读取保存好的网络。代码实现如下:

//测试opencv加载pytorch模型
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace cv::dnn;
#include <fstream>
#include <iostream>
#include <cstdlib>
using namespace std;int main()
{String modelFile = "./torch.onnx";String imageFile = "./dog.jpg";dnn::Net net = cv::dnn::readNetFromONNX(modelFile); //读取网络和参数Mat image = imread(imageFile); // 读取测试图片cv::cvtColor(image, image, cv::COLOR_BGR2RGB);Mat inputBolb = blobFromImage(image, 0.00390625f, Size(32, 32), Scalar(), false, false); //将图像转化为正确输入格式net.setInput(inputBolb); //输入图像Mat result = net.forward(); //前向计算cout << result << endl;
}

3. TensorRT

实验流程为:Pytorch -> Onnx -> TensorRT 或者Pytorch-> TensorRT。创建TensorRT引擎及进行前向推理,下面将分成两节来描述不同的方法。

…详情请参照古月居

深度学习之从Python到C++相关推荐

  1. 深度学习入门 基于Python的理论与实现

    作者:斋藤康毅 出版社:人民邮电出版社 品牌:iTuring 出版时间:2018-07-01 深度学习入门 基于Python的理论与实现

  2. 深度学习 自组织映射网络 ——python实现SOM(用于聚类)

    深度学习 自组织映射网络 --python实现SOM(用于聚类) 摘要 python实现代码 计算实例 摘要 SOM(Self Organizing Maps ) 的目标是用低维目标空间的点来表示高维 ...

  3. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  4. 深度学习 + OpenCV,Python实现实时视频目标检测

    选自PyimageSearch 机器之心编译 参与:路雪.李泽南 使用 OpenCV 和 Python 对实时视频流进行深度学习目标检测是非常简单的,我们只需要组合一些合适的代码,接入实时视频,随后加 ...

  5. 深度学习之编程语言Python(Ⅰ)

    深度学习之编程语言Python 编程语言Python 什么是编程语言? 编程语言(Programming Language)就是人与计算机之间交互的方式.简单来说,就是人与计算机都可以理解的一种语言. ...

  6. 《深度学习入门——基于Python的理论与实现》笔记

    PS:写这篇博客主要是记录下自己认为重要的部分以及阅读中遇到的些问题,加深自己的印象. 附上电子书及源代码: 链接:https://pan.baidu.com/s/1f2VFcnXSSK-u3wuvg ...

  7. 深度学习入门-基于Python的理论入门与实现源代码加mnist数据集下载推荐

    深度学习入门-基于Python的理论入门与实现源代码加mnist数据集下载推荐 书籍封面 1-图灵网站下载 书里也说了,可以图灵网站下载https://www.ituring.com.cn/book/ ...

  8. 《深度学习入门-基于Python的理论与实现》学习笔记1

    <深度学习入门-基于Python的理论与实现>学习笔记1 第一章Python入门 Python是一个简单.易读.易记的编程语言,可以用类似于英语的语法进行编写程序,可读性高,且能写出高性能 ...

  9. 《深度学习入门--基于python的理论与实现》——斋藤康毅读书笔记

    <深度学习入门--基于python的理论与实现>读书笔记(第二章) 写在前面 第二章:感知机 2.1感知机是什么 2.2简单的逻辑电路 2.2.1与门(and gate) 2.2.2与非门 ...

  10. 基于深度学习OpenCV与python进行字符识别

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 当我们在处理图像数据集时,总是会想有没有什么办法以简单的文本格式检 ...

最新文章

  1. 【图论专题】单源最短路的综合应用
  2. GAN背后的理论依据,以及为什么只使用GAN网络容易产生
  3. caffe学习(一):开发环境搭建,编译caffe(win10)
  4. 如何绘制类似仓库的平面位置图
  5. 寻找数组变化:树形结构,分治模型
  6. MySQL查询语句中的IN 和Exists 对比分析
  7. ci框架 mysql 超时时间_mysql 字符集和校验规则( CHARSET amp; COLLATE)
  8. Linux 第20天: (09月12日) Linux启动和内核管理
  9. modbus连续读取时数据不正确_维纶触摸屏控制变频器是通过触摸屏与变频器之间的Modbus通信实现...
  10. 网络工程师交换试验手册之二十五:详细讲授利用xmodem来恢复IOS
  11. word根据目录切块php,PHP导出Word文档如何自定义目录?
  12. 消息中间件学习总结(12)——Kafka与RocketMQ的多Topic对性能稳定性的影响比较分析
  13. Mybatis Plus语法+示例
  14. vfp:数据库中表间关系的参照完整性
  15. 关于NLP相关技术全部在这里:预训练模型、信息抽取、文本生成、知识图谱、对话系统...
  16. HDFS基本命令及上传文件API
  17. Java sychronized关键字总结(二)
  18. win10无法装载iso文件_装载Win10 ISO镜像文件的具体方法
  19. 74HC595在【8x8LED点阵】中的运用
  20. 北京3月去哪玩 赏花踏青登山六大推荐

热门文章

  1. k8spod使用gpu
  2. python遗传算法_带有Python的AI –遗传算法
  3. 利用亚运会,读懂 Python装饰器
  4. 怎么用计算机编写文件,怎样在电脑上写作文做文件
  5. K_A02_001 基于单片机驱动4位数码管模块(74HC595) 0-3滚动+ 时钟显示
  6. java8新特性(拉姆达表达式lambda)
  7. ACM投稿版权信息去除问题
  8. 智慧燃气系统基于GIS技术的搭建
  9. Altera 逻辑锁定
  10. 大数据时代,做大数据开发要学Java框架吗?