标贝科技 https://ai.data-baker.com/#/?source=qwer12

填写邀请码fwwqgs,每日免费调用量还可以翻倍

1、pytorch和libtorch安装(标贝科技)

PyTorch 是Torch7 团队开发的,从它的名字就可以看出,其与Torch 的不同之处在于PyTorch 使用了Python 作为开发语言。所谓“Python first”,同样说明它是一个以Python 优先的深度学习框架,不仅能够实现强大的GPU 加速,同时还支持动态神经网络,这是现在很多主流框架比如Tensorflow 等都不支持的。
  PyTorch 既可以看做加入了GPU 支持的numpy,同时也可以看成一个拥有自动求导功能的强大的深度神经网络,除了Facebook 之外,它还已经被Twitter、CMU 和Salesforce 等机构采用。
pytorch是一个强大的机器学习库,其中集成了很多方法,但从python本身角度讲,它的速度还不够快,虽然对于许多需要动态性和易迭代性的场景来说,Python是一种合适且首选的语言,但在同样的情况下,Python的这些特性恰恰是不利的。它常常应用于生产环境,这是一个低延迟和有严格部署要求的领域,一般选择C++。

1)安装pytorch

两种方式安装pytorch:根据实际cuda版本和需求安装对应版本pytorch,这里安装的是1.5.0版本。

a.查看cuda版本

cat /usr/local/cuda/version.txt

得到cuda版本,安装合适版本的pytorch。

b.使用pip安装

pip install torch==1.5.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

c.使用conda安装

下载anaconda
wget https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-5.2.0-Linux-x86_64.sh
安装anconda
bash Anaconda3-5.2.0-Linux-x86_64.sh
创建test_torch虚拟环境,python版本=3.6
conda create -n test_torch python=3.6
激活test_torch虚拟环境
conda activate test_torch
安装pytorch
conda install torch=1.5.0

d.torch测试

import torch
torch.version

2)安装libtorch

a.确定libtorch版本

使用libtorch调用c++接口,要保证下载的libtorch的版本和pytorch的版本对应,使用低版本的pytorch和高版本的libtorch是没法成功的。根据pytorch和cuda版本确认libtorch版本
查看libtorch版本
https://blog.csdn.net/lxx4610/article/details/105806017/
https://pytorch.org/get-started/locally/

b.从官网下载编译好的文件

获取libtorch有两种方式:
• 从官网下载编译好的文件
https://pytorch.org/ 下载对应版本

c.自己进行源码编译

码云下载
git clone https://gitee.com/mirrors/pytorch.git
查看libtorch版本
git tag
查看当前分支
git branch
根据cuda和pytorch版本切换到适配的版本
git checkout v1.2.0
更新第三方库
git submodule update --init --recursive
编译
mkdir build
cd build
python …/tools/build_libtorch.py

2、使用pytorch训练模型

这里就不展开介绍

3、将Pytorch模型转化为Torch Script

Torch Script可以完好的表达pytorch模型,而且也能被C++头文件所理解。有两种方法可以将pytorch模型转换成TorchScript,Tracing和Annotation。

1)Tracing

这种方法需要你给模型传入一个sample input,它会跟踪在模型的forward方法中的过程。
例如,加载一个torchvision.models.resnet18()模型
model = torchvision.models.resnet18()
使用 torch.rand(),生成一个随机样例输入
example = torch.rand(1, 3, 224, 224)
torch.jit.trace()方法对根据样例输入跟踪模型的forward方法中的过程
traced_script_module = torch.jit.trace(model, example)
最后导出TorchScript模型。
traced_script_module.save(“traced_resnet_model.pt”)
完整过程:

import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example) traced_script_module.save("traced_resnet_model.pt")

2)Annotation

如果forward方法中具有判断语句,Tracing方法就行不通了,Annotation方法则可以处理模型里有判断语句的情形,使用torch.jit.script。
模型定义:

class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__()self.weight = torch.nn.Parameter(torch.rand(N, M)) def forward(self, input): if input.sum() > 0: output = self.weight.mv(input) else: output = self.weight + input return output

定义了一个模型结构,在forward方法中使用了判断语句,这种模型在转化为Torch Script时,不能Tracing方法,这时可以使用 torch.jit.script()方法:

my_module = MyModule(10,20)
traced_script_module = torch.jit.script(my_module)
traced_script_module.save("traced_resnet_model.pt")

4、在C++中加载Model

将pytorch训练好的模型导出成torch script形式并保存,C++能够理解,编译并序列化torch script格式的模型。
使用libtorch中torch::jit::load()加载导出的模型。

#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) { if (argc != 2) {std::cerr << "usage: example-app <path-to-exported-script-module>\n";return -1; } torch::jit::script::Module module; try { // Deserialize the ScriptModule from a file using torch::jit::load(). module = torch::jit::load(argv[1]); } catch (const c10::Error& e) {std::cerr << "error loading the model\n"; return -1; } std::cout << "ok\n";}

torch::jit::load()函数用来加载模形,参数为模型文件名,返回torch::jit::script::Module类,<torch/script.h>头文件包含了需要的类和方法,这个文件通过安装libtorch得到。

5、运行模型

模型已经导入成功,使用libtorch中的一些方法,你就可以像在python中一样去跑你的模型了,并根据c++模型的输出与python模型的输出,对比结果。

// Create a vector of inputs. std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));// Execute the model and turn its output into a tensor. at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

欢迎体验标贝语音开放平台
地址:https://ai.data-baker.com/#/?source=qaz123
(注:填写邀请码hi25d7,每日免费调用量还可以翻倍)
​​​​​​

pytorch模型从训练到LibTorch部署(标贝科技)相关推荐

  1. 标贝科技推出「留声机」TTS方案,高还原、个性化声效提升交互意愿

    3 月 5 日,标贝科技推出全新 「留声机」,该方案具有高原度复刻效果,用户只需 5 分钟左右即可完成录制,训练过程全自动化处理,大约 2 小时后,即可拥有媲美原声声音的个性化模型,轻松实现文本转语音 ...

  2. 直击标贝科技WAIC2019:深耕语音合成与数据服务 助力语音场景完美落地

    8月29日,WAIC2019世界人工智能大会于上海世博中心&上海世博展览馆举办.本届大会以"智联世界 无限可能"为主题,以"高端化.国际化.专业化.市场化.智能化 ...

  3. 标贝科技声音克隆技术赋能 定制语音功能让陪伴触手可及

    "常回家看看,回家看看,哪怕帮妈妈刷刷筷子洗洗碗......"这首脍炙人口的老歌道出了多少父母的期盼,又是多少儿女的遗憾.因为生活,因为工作,我们总是很忙,忙得没有时间回家,甚至打 ...

  4. 手把手教你用JAVA实现“声音复刻”功能(复刻你的声音)标贝科技

    手把手教你用JAVA实现"声音复刻"功能(复刻你的声音)标贝科技 前言 什么是声音复刻? 使用少量的用户声音,短时间内快速为用户量身打造个人定制音色 一.内容太长不愿意看,直接使用 ...

  5. kaldi新手入门及语音识别的流程(标贝科技)

    kaldi新手入门及语音识别的流程(标贝科技) 欢迎体验标贝语音开放平台 地址:https://ai.data-baker.com/#/?source=qaz123 (注:填写邀请码hi25d7,每日 ...

  6. GMM-HMM声学模型实例详解(标贝科技)

    欢迎测试标贝科技AI开放平台 https://ai.data-baker.com/#/?source=qwer12 GMM-HMM声学模型实例详解 GMM-HMM为经典的声学模型,基于深度神经网络的语 ...

  7. 标贝科技语音论文入选全球顶级语音学术大会INTERSPEECH2019

    全球知名语音学术大会INTERSPEECH2019于9月15日至19日在奥地利格拉茨城市举行. 作为全球智能语音及AI数据发展的推动者,标贝科技受邀成为大会黄金级赞助厂商亮相现场.其中,由标贝语音团队 ...

  8. 标贝科技亮相2019中国互联网大会 解决语音合成定制需求痛点

    2019中国互联网大会于7月9日-11日在北京国家会议中心举行.本次大会以"创新求变再出发"为主题,开设物联网.人工智能等系列分支论坛,汇聚国内重点科研机构及众多知名互联网及人工智 ...

  9. Gowild狗尾草推出HE琥珀,标贝科技为其提供更“温柔”的声音

    ​​8月22日,"2018Gowild狗尾草品牌发布会"在中国电影导演中心举行.会上Gowild狗尾草公布了"AI虚拟生命"大战略,并发布了基于大战略之下的新一 ...

最新文章

  1. 【Spring Security】五、自定义过滤器
  2. 深度学习之误差反向传播法
  3. 【搜遍互联网,集百家之长】环境配置从入门到放弃之Mac环境下,安装XAMPP,给phpstorm安装Xdebug调试工具...
  4. Java:抽象方法和抽象类,抽象类应用模板方法模式,接口及使用
  5. 中国通货膨胀率2.8%,数据分析买房风险直线上升
  6. Linux安装/升级pip
  7. Django 自定义表名
  8. 基于openstack搭建百万级并发负载均衡器的解决方案
  9. java switch case怎么判断范围_【转】Java期末复习攻略!
  10. Linux下java/bin目录下的命令集合
  11. java -jar 内存溢出_JAVA系统启动栈内存溢出-StackOverflowError
  12. mysql 存储汉字_MySQL存储汉字
  13. java 交互式 shell_Java9 Shell工具(JShell)
  14. YGC 问题排查,又涨姿势了!
  15. case / switch语句的Python等价物是什么? [重复]
  16. 微程序控制器的组成及原理总结
  17. 华为轮值董事长郭平:美国在5G方面已落后
  18. 2021秋季软件工程实践总结
  19. 学好数据结构的重要性
  20. Apache是干什么的?

热门文章

  1. Web开发基础-新闻页面-老九门
  2. 纯 js 实现跨域接口调用 jsonp
  3. 实用拜占庭容错算法 (PBFT)
  4. 小而美的Nginx日志分析利器GoAccess
  5. PHP json_decode()报错 json_last_error()判断错误类型 解决
  6. PHP函数json_decode的用法,PHP json_decode()用法及代码示例
  7. Arch Linux 记录
  8. 微信机器人接入Midjourney
  9. 特征点提取opencv
  10. ssh配置免密登录、scp文件传输免密