使用libtorch将pytorch 部署到移动端去
使用resnet50举例
使用代码将pytorch模型转为libtorch

import torch
import torchvision.models as models
from PIL import Image
import numpy as np
#加载测试图片调整大小
image = Image.open("test.jpg") #图片发在了build文件夹下
image = image.resize((224, 224),Image.ANTIALIAS)
#进行预处理
image = np.asarray(image)
image = image / 255
image = torch.Tensor(image).unsqueeze_(dim=0)
#变换维度
image = image.permute((0, 3, 1, 2)).float()
#加载使用pytorch自带resnet50模型
model = models.resnet50(pretrained=True)
model = model.eval()resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
# output=resnet(torch.ones(1,3,224,224))
#使用测试模型转换
output = resnet(image)
max_index = torch.max(output, 1)[1].item()
print(max_index) # ImageNet1000类的类别序
#保存转化后的模型
resnet.save('resnet.pt')

使用c++调用写好的模型
再c++调用模型之前要先写好CMakeLists文件

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)project(example_torch)
set(CMAKE_PREFIX_PATH "XXX/libtorch") //注意这里填自己解压libtorch时的路径find_package(Torch REQUIRED)
find_package(OpenCV 3.0 QUIET)
if(NOT OpenCV_FOUND)find_package(OpenCV 2.4.3 QUIET)if(NOT OpenCV_FOUND)message(FATAL_ERROR "OpenCV > 2.4.3 not found.")endif()
endif()
add_executable(${PROJECT_NAME} "main.cpp")
target_link_libraries(${PROJECT_NAME} ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 11)```

其中要设置好CMAKE_PREFIX_PATH路径,这个路径就是我们解压libtorch的路径,不然无法链接到libtorch库,其中也设置了OpenCV的配置,具体OpenCV的安装这里介绍了。
然后就是C++调用PyTorch模型的代码


#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
#include <vector>
#include <opencv2/highgui.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/opencv.hpp>void TorchTest(){std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../resnet.pt");assert(module != nullptr);std::cout << "Load model successful!" << std::endl;std::vector<torch::jit::IValue> inputs;inputs.push_back(torch::zeros({1,3,224,224}));at::Tensor output = module->forward(inputs).toTensor();auto max_result = output.max(1, true);auto max_index = std::get<1>(max_result).item<float>();std::cout << max_index << std::endl;
}void Classfier(cv::Mat &image){torch::Tensor img_tensor = torch::from_blob(image.data, {1, image.rows, image.cols, 3}, torch::kByte);img_tensor = img_tensor.permute({0, 3, 1, 2});img_tensor = img_tensor.toType(torch::kFloat);img_tensor = img_tensor.div(255);std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../Train/resnet.pt");torch::Tensor output = module->forward({img_tensor}).toTensor();auto max_result = output.max(1, true);auto max_index = std::get<1>(max_result).item<float>();std::cout << max_index << std::endl;}int main() {
//    TorchTest();cv::Mat image = cv::imread("airliner.jpg");cv::resize(image,image, cv::Size(224,224));std::cout << image.rows <<" " << image.cols <<" " << image.channels() << std::endl;Classfier(image);return 0;
}

其中TorchTest函数只是做了简单的演示,而Classfier通过OpenCV读取图片,并通过libtorch的函数将Mat格式转换成Tensor(注意:这里转换了维度,因为OpenCV的维度是[H,W,C], 而PyTorch模型需要的是[C,H,W]),最后依然能够输出和Python代码一样的答案。

这里比较重要的几个函数有:

torch::from_blob(): 这个函数将Mat类型转换成Tensor类型。

torch::jit::load(): 该函数顾名思义就是加载模型的函数。

module->forward(): 模型前向传播的函数,输入值建议使用vector类型

max(): 这个函数是libtorch中的max,返回c++中的tuple类型(第一个值为维度上最大值,第二个值为最大值的序号)所以使用std::get<1>(max_result)来取出序号,这是tuple类型取出方式。

模型转换

libtorch不依赖于python,python训练的模型,需要转换为script model才能由libtorch加载,并进行推理。在这一步官网提供了两种方法:

方法一:Tracing

这种方法操作比较简单,只需要给模型一组输入,走一遍推理网络,然后由torch.ji.trace记录一下路径上的信息并保存即可。示例如下:

import torch
import torchvision# An instance of your model.
model = torchvision.models.resnet18()# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

缺点是如果模型中存在控制流比如if-else语句,一组输入只能遍历一个分支,这种情况下就没办法完整的把模型信息记录下来。
意思就是说这个方法只能转换一波流的网络

方法二:Scripting

直接在Torch脚本中编写模型并相应地注释模型,通过torch.jit.script编译模块,将其转换为ScriptModule。示例如下:

class MyModule(torch.nn.Module):def __init__(self, N, M):super(MyModule, self).__init__()self.weight = torch.nn.Parameter(torch.rand(N, M))def forward(self, input):if input.sum() > 0:output = self.weight.mv(input)else:output = self.weight + inputreturn outputmy_module = MyModule(10,20)
sm = torch.jit.script(my_module)

forward方法会被默认编译,forward中被调用的方法也会按照被调用的顺序被编译
如果想要编译一个forward以外且未被forward调用的方法,可以添加 @torch.jit.export.
如果想要方法不被编译,可使用@torch.jit.ignore 或者 @torch.jit.unused
##例子如下:

# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():return 2# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():return 2# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():import pdb; pdb.set_trace()return 4# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():return 2

不是所有方法都支持的
https://pytorch.org/docs/master/jit_unsupported.html#jit-unsupported

1. 不支持的操作TorchScript支持的操作是python的子集,大部分torch中用到的操作都可以找到对应实现,但也存在一些尴尬的不支持操作,详细列表可见unsupported-ops,下面列一些我自己遇到的操作:1)参数/返回值不支持可变个数,例如def __init__(self, **kwargs):
或者if output_flag == 0:return reshape_logits
else:loss = self.loss(reshape_logits, term_mask, labels_id)return reshape_logits, loss
2)各种iteration操作eg1.layers = [int(a) for a in layers]
报错torch.jit.frontend.UnsupportedNodeError: ListComp aren’t supported可以改成:for k in range(len(layers)):layers[k] = int(layers[k])
eg2.seq_iter = enumerate(scores)
try:_, inivalues = seq_iter.__next__()
except:_, inivalues = seq_iter.next()
eg3.line = next(infile)
3)不支持的语句eg1. 不支持continuetorch.jit.frontend.UnsupportedNodeError: continue statements aren’t supportedeg2. 不支持try-catchtorch.jit.frontend.UnsupportedNodeError: try blocks aren’t supportedeg3. 不支持with语句4)其他常见op/moduleeg1. torch.autograd.Variable解决:使用torch.ones/torch.randn等初始化+.float()/.long()等指定数据类型。eg2. torch.Tensor/torch.LongTensor etc.解决:同上eg3. requires_grad参数只在torch.tensor中支持,torch.ones/torch.zeros等不可用eg4. tensor.numpy()eg5. tensor.bool()解决:tensor.bool()用tensor>0代替eg6. self.seg_emb(seg_fea_ids).to(embeds.device)解决:需要转gpu的地方显示调用.cuda()总之一句话:除了原生python和pytorch以外的库,比如numpy什么的能不用就不用,尽量用pytorch的各种API。2. 指定数据类型1)属性,大部分的成员数据类型可以根据值来推断,空的列表/字典则需要预先指定from typing import Dictclass MyModule(torch.nn.Module):my_dict: Dict[str, int]def __init__(self):super(MyModule, self).__init__()# This type cannot be inferred and must be specifiedself.my_dict = {}# The attribute type here is inferred to be `int`self.my_int = 20def forward(self):passm = torch.jit.script(MyModule())
2)常量,使用Final关键字try:from typing_extensions import Final
except:# If you don't have `typing_extensions` installed, you can use a# polyfill from `torch.jit`.from torch.jit import Finalclass MyModule(torch.nn.Module):my_constant: Final[int]def __init__(self):super(MyModule, self).__init__()self.my_constant = 2def forward(self):passm = torch.jit.script(MyModule())
3)变量。默认是tensor类型且不可变,所以非tensor类型必须要指明def forward(self, batch_size:int, seq_len:int, use_cuda:bool):

#方法三:Tracing and Scriptin混合

1)属性,大部分的成员数据类型可以根据值来推断,空的列表/字典则需要预先指定from typing import Dictclass MyModule(torch.nn.Module):my_dict: Dict[str, int]def __init__(self):super(MyModule, self).__init__()# This type cannot be inferred and must be specifiedself.my_dict = {}# The attribute type here is inferred to be `int`self.my_int = 20def forward(self):passm = torch.jit.script(MyModule())
2)常量,使用Final关键字try:from typing_extensions import Final
except:# If you don't have `typing_extensions` installed, you can use a# polyfill from `torch.jit`.from torch.jit import Finalclass MyModule(torch.nn.Module):my_constant: Final[int]def __init__(self):super(MyModule, self).__init__()self.my_constant = 2def forward(self):passm = torch.jit.script(MyModule())
3)变量。默认是tensor类型且不可变,所以非tensor类型必须要指明def forward(self, batch_size:int, seq_len:int, use_cuda:bool):
方法三:Tracing and Scriptin混合一种是在trace模型中调用script,适合模型中只有一小部分需要用到控制流的情况,使用实例如下:import torch@torch.jit.script
def foo(x, y):if x.max() > y.max():r = xelse:r = yreturn rdef bar(x, y, z):return foo(x, y) + ztraced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
另一种情况是在script module中用tracing生成子模块,对于一些存在script module不支持的python feature的layer,就可以把相关layer封装起来,用trace记录相关layer流,其他layer不用修改。使用示例如下:import torch
import torchvisionclass MyScriptModule(torch.nn.Module):def __init__(self):super(MyScriptModule, self).__init__()self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1))self.resnet = torch.jit.trace(torchvision.models.resnet18(),torch.rand(1, 3, 224, 224))def forward(self, input):return self.resnet(input - self.means)my_script_module = torch.jit.script(MyScriptModule())
2.保存序列化模型
如果上一步的坑都踩完,那么模型保存就非常简单了,只需要调用save并传递一个文件名即可,需要注意的是如果想要在gpu上训练模型,在cpu上做inference,一定要在模型save之前转化,再就是记得调用model.eval(),形如gpu_model.eval()
cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(traced_cpu, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pth")traced_gpu = torch.jit.trace(traced_gpu, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pth")
3.C++ load训练好的模型
要在C ++中加载序列化的PyTorch模型,必须依赖于PyTorch C ++ API(也称为LibTorch)。libtorch的安装非常简单,只需要在pytorch官网下载对应版本,解压即可。会得到一个结构如下的文件夹。libtorch/bin/include/lib/share/
然后就可以构建应用程序了,一个简单的示例目录结构如下:example-app/CMakeLists.txtexample-app.cpp
example-app.cpp和CMakeLists.txt的示例代码分别如下:#include <torch/script.h> // One-stop header.#include <iostream>
#include <memory>int main(int argc, const char* argv[]) {if (argc != 2) {std::cerr << "usage: example-app <path-to-exported-script-module>\n";return -1;}torch::jit::script::Module module;try {// Deserialize the ScriptModule from a file using torch::jit::load().module = torch::jit::load(argv[1]);}catch (const c10::Error& e) {std::cerr << "error loading the model\n";return -1;}std::cout << "ok\n";
}cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)find_package(Torch REQUIRED)add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
至此,就可以运行以下命令从example-app/文件夹中构建应用程序啦:mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
其中/path/to/libtorch是之前下载后的libtorch文件夹所在的路径。这一步如果顺利能够看到编译完成100%的提示,下一步运行编译生成的可执行文件,会看到“ok”的输出,可喜可贺!4. 执行Script Module
终于到最后一步啦!下面只需要按照构建输入传给模型,执行forward就可以得到输出啦。一个简单的示例如下:// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
前两行创建一个torch::jit::IValue的向量,并添加单个输入. 使用torch::ones()创建输入张量,等效于C ++ API中的torch.ones。 然后,运行script::Module的forward方法,通过调用toTensor()将返回的IValue值转换为张量。C++对torch的各种操作还是比较友好的,通过torch::或者后加_的方法都可以找到对应实现,例如torch::tensor(input_list[j]).to(at::kLong).resize_({batch, 128}).clone()
//torch::tensor对应pytorch的torch.tensor; at::kLong对应torch.int64;resize_对应resize
最后check一下确保c++端的输出和pytorch是一致的就大功告成啦~踩了无数坑,薅掉了无数头发,很多东西也是自己一点点摸索的,如果有错误欢迎指正!参考资料:PyTorch C++ API - PyTorch master documentTorch Script - PyTorch master documentation

#实际例子
在python代码中直接将tensor保存为pt文件,然后使用如下代码将其转化为c++可以读取的格式:

class Container(torch.nn.Module):def __init__(self, my_values):super().__init__()for key in my_values:setattr(self, key, my_values[key])my_values = {'img': torch.load('img.pt'),'proposal_list': torch.load('proposal_list.pt'),'cls_score': torch.load('cls_score.pt'),'bbox_pred': torch.load('bbox_pred.pt'),'det_bboxes': torch.load('det_bboxes.pt'),'det_labels': torch.load('det_labels.pt'),
}

Save arbitrary values supported by TorchScript
https://pytorch.org/docs/master/jit.html#supported-type

container = torch.jit.script(Container(my_values))
container.save("results.pt")

在c++使用如下代码加载results.pt文件并与c++结果比对:

torch::jit::script::Module results = torch::jit::load("results.pt");
Tensor std_img = results.attr("img").toTensor().to(device);
//auto feature_maps = results.attr("img").toTuple();
//auto cls_out = results.attr("img").toTensorVector();
std::cout << torch::sum(torch::abs(std_img - img)) << std::endl;

参考
https://zhuanlan.zhihu.com/p/263626686
https://zhuanlan.zhihu.com/p/146453159
https://zhuanlan.zhihu.com/p/72750321

pytorch转libtorch,全网最全资料相关推荐

  1. Flink 全网最全资源(视频、博客、PPT、入门、原理、实战、性能调优、源码解析、问答等持续更新)

    Flink 学习 https://github.com/zhisheng17/flink-learning 麻烦路过的各位亲给这个项目点个 star,太不易了,写了这么多,算是对我坚持下来的一种鼓励吧 ...

  2. Flink 全网最全资源(视频、博客、PPT、入门、实战、源码解析、问答等持续更新)...

    Flink 学习 github.com/zhisheng17/- 麻烦路过的各位亲给这个项目点个 star,太不易了,写了这么多,算是对我坚持下来的一种鼓励吧! 本项目结构 博客 1.Flink 从0 ...

  3. GitHub:TensorFlow、PyTorch最全资料集锦

    给各位小伙伴们推出几个深度学习框架的资料集锦,统一命名为:XXX-From-Zero-To-One.下面po一幅深度学习框架发展的重要历史点: 从上图可知,TensorFlow和PyTorch是目前深 ...

  4. 【免费下载】全网最全5G资料包(报告、白皮书、方案、政策等1300余份,持续更新)...

    大家好,我是文文(微信:sscbg2020),今天给大家盘点一下今年大热的5G相关的行业报告. 移动通信从面向个人通信的1G.2G.3G.4G以十年一代的速度发展到现在,面向产业互联网和智慧城市应用的 ...

  5. 可能是全网最全,JAVA日志框架适配/冲突解决方案,可以早点下班了

    点击关注公众号,Java干货及时送达 你是否遇到过配置了日志,但打印不出来的情况? 你是否遇到过配置了logback,启动时却提示log4j错误的情况?像下面这样: log4j:WARN No app ...

  6. GitHub:人群密度估计最全资料集锦

    GitHub:人群密度估计最全资料集锦 文章目录: 整理过的awesome系列项目: GitHub:车道线检测最全资料集锦 GitHub:目标检测最全论文集锦 GitHub:TensorFlow最全资 ...

  7. 全网最全斗音短视频新老账号起号技巧

    大家好,我是我赢助手,专注于自媒体短视频去水印.去重和文案提取运营. 今天给大家分享下全网最全斗音短视频新老账号起号技巧 1.清理手机(手机登录新斗音号才需要清理) 安卓:打开设置-应用管理-斗音短视 ...

  8. 全网最全-超大模型+分布式训练架构和经典论文

    如何利用计算中心成千上百的AI加速芯片的集群,训练参数量超过百亿的大规模模型?并行计算是一种行之有效的方法,除了分布式并行计算相关的技术之外,其实在训练大模型的过程还会融合更多的技术,如新的算法模型架 ...

  9. 黑猫带你学UFS协议第1篇:全网最全UFS协议中文详讲,这份学习框架图,你值得拥有!!!(持续更新中...)

    文/黑猫学长 1 作者想说 笔者本人从事于存储芯片行业多年,对eMMC/UFS/SD等芯片有深入研究,协议尤甚.而今看来,UFS协议在整个存储产品中(包括U盘.SPI.SD卡,NM卡.emmc.SSD ...

最新文章

  1. java 登录下线_java 实现 一个账号只能在一个地方登陆,其他地方被下线
  2. 利用yarn capacity scheduler在EMR集群上实现大集群的多租户的集群资源隔离和quota限制...
  3. 程序是在RAM里还是flash里执行
  4. 每次执行java命令 都要source_跟着平台混了四年,现在要单飞了!
  5. 关于Netty的一些理解、实践与陷阱
  6. C语言:输入两个数,输出最大公约数,最小公倍数
  7. 等午饭吃过后的dwzjzx
  8. 凌云一周看点 | 混合云多Region架构;云上用户定制化网络;边缘云全站加速;什么是操作系统的云原生...
  9. Halcon学习笔记之OCR系列-喷码字体识别
  10. Shiro面试题(二十道)
  11. 批量打印word文档_如何安排打印Word 2007+文档
  12. win7下chm打不开
  13. 平面设计之PS(前)
  14. Excel学习笔记 - 查找表格数据
  15. 如何在微信中下载APP
  16. Oracle常见问题一千问
  17. 8051单片机(STC89C52)定时器实现10ms精准定时
  18. centos 安装 pcre
  19. Docker下elasticsearch8部署、扩容、基本操作实战(含kibana)
  20. python中getattr()和setattr()的使用

热门文章

  1. 【java8】中stream的.findAny().orElse (null) 是什么意思?
  2. Navicat中查询哪些表有指定的字段名(技巧)
  3. oracle三种分区的方式,Oracle 分区表 总结大全(3)
  4. python爬虫提取教学_python爬虫的基本抓取
  5. storyboard搭建项目_Storyboard 快速搭建UICollectionView
  6. Linux系统备份树莓派,全平台备份树莓派的方法
  7. 小程序在输入npm命令_微信小程序使用npm包步骤
  8. HDLBits 系列(28)PS/2 mouse protocol(PS/2 packet parser)
  9. Cadence入门笔记(2):分裂元件的制作方法
  10. 【 Sublime Text 】如何使用Sublime Text快速生成代码模板