如何将pytorch模型部署到安卓上

这篇文章演示如何将训练好的pytorch模型部署到安卓设备上。我也是刚开始学安卓,代码写的简单。

环境:

pytorch版本:1.10.0

模型转化

pytorch_android支持的模型是.pt模型,我们训练出来的模型是.pth。所以需要转化才可以用。先看官网上给的转化方式:

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobilemodel = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")

这个模型在安卓对应的包:

repositories {jcenter()
}dependencies {implementation 'org.pytorch:pytorch_android_lite:1.9.0'implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}

注:pytorch_android_lite版本和转化模型用的版本要一致,不一致就会报各种错误。

目前用这种方法有点问题,我采用的另一种方法。

转化代码如下:

import torch
import torch.utils.data.distributed# pytorch环境中
model_pth = 'model_31_0.96.pth' #模型的参数文件
mobile_pt ='model.pt' # 将模型保存为Android可以调用的文件model = torch.load(model_pth)
model.eval() # 模型设为评估模式
device = torch.device('cpu')
model.to(device)
# 1张3通道224*224的图片
input_tensor = torch.rand(1, 3, 224, 224) # 设定输入数据格式mobile = torch.jit.trace(model, input_tensor) # 模型转化
mobile.save(mobile_pt) # 保存文件

对应的包:

//pytorch
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'

定义模型文件和转化后的文件路径。

load模型。这里要注意,如果保存模型

torch.save(model,'models.pth')

加载模型则是

model=torch.load('models.pth')

如果保存模型是

torch.save(model.state_dict(),"models.pth")

加载模型则是

model.load_state_dict(torch.load('models.pth'))

定义输入数据格式。

模型转化,然后再保存模型。

安卓部署

新建项目

新建安卓项目,选择Empy Activity,然后选择Next

然后,填写项目信息,选择安卓版本,我用的4.4,点击完成

导入包

导入pytorch_android的包

//pytorch
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'

如果有参数报错请参照我的完整的配置,代码如下:

plugins {id 'com.android.application'
}android {compileSdk 32defaultConfig {applicationId "com.example.myapplication"minSdk 21targetSdk 32versionCode 1versionName "1.0"testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"}buildTypes {release {minifyEnabled falseproguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'}}compileOptions {sourceCompatibility JavaVersion.VERSION_1_8targetCompatibility JavaVersion.VERSION_1_8}
}dependencies {implementation 'androidx.appcompat:appcompat:1.3.0'implementation 'com.google.android.material:material:1.4.0'implementation 'androidx.constraintlayout:constraintlayout:2.0.4'testImplementation 'junit:junit:4.13.2'androidTestImplementation 'androidx.test.ext:junit:1.1.3'androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'//pytorchimplementation 'org.pytorch:pytorch_android:1.10.0'implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'}

页面文件

页面的配置如下:

<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"xmlns:tools="http://schemas.android.com/tools"android:layout_width="match_parent"android:layout_height="match_parent"tools:context=".MainActivity"><ImageViewandroid:id="@+id/image"android:layout_width="match_parent"android:layout_height="match_parent"android:scaleType="fitCenter" /><TextViewandroid:id="@+id/text"android:layout_width="match_parent"android:layout_height="wrap_content"android:layout_gravity="top"android:textSize="24sp"android:background="#80000000"android:textColor="@android:color/holo_red_light" /></FrameLayout>

这个页面只有两个空间,一个展示图片,一个显示文字。

模型推理

新增assets文件夹,然后将转化的模型和待测试的图片放进去。

新增ImageNetClasses类,这个类存放类别名字。

代码如下:

package com.example.myapplication;public class ImageNetClasses {public static String[] IMAGENET_CLASSES = new String[]{"Black-grass","Charlock","Cleavers","Common Chickweed","Common wheat","Fat Hen","Loose Silky-bent","Maize","Scentless Mayweed","Shepherds Purse","Small-flowered Cranesbill","Sugar beet",};
}

在MainActivity类中,增加模型推理的逻辑。完成代码如下:

package com.example.myapplication;import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;import org.pytorch.IValue;import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.MemoryFormat;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;import androidx.appcompat.app.AppCompatActivity;public class MainActivity extends AppCompatActivity {@Overrideprotected void onCreate(Bundle savedInstanceState) {super.onCreate(savedInstanceState);setContentView(R.layout.activity_main);Bitmap bitmap = null;Module module = null;try {// creating bitmap from packaged into app android asset 'image.jpg',// app/src/main/assets/image.jpgbitmap = BitmapFactory.decodeStream(getAssets().open("1.png"));// loading serialized torchscript module from packaged into app android asset model.pt,// app/src/model/assets/model.ptmodule = Module.load(assetFilePath(this, "models.pt"));} catch (IOException e) {Log.e("PytorchHelloWorld", "Error reading assets", e);finish();}// showing image on UIImageView imageView = findViewById(R.id.image);imageView.setImageBitmap(bitmap);// preparing input tensorfinal Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);// running the modelfinal Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();// getting tensor content as java array of floatsfinal float[] scores = outputTensor.getDataAsFloatArray();// searching for the index with maximum scorefloat maxScore = -Float.MAX_VALUE;int maxScoreIdx = -1;for (int i = 0; i < scores.length; i++) {if (scores[i] > maxScore) {maxScore = scores[i];maxScoreIdx = i;}}System.out.println(maxScoreIdx);String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];// showing className on UITextView textView = findViewById(R.id.text);textView.setText(className);}/*** Copies specified asset to the file in /files app directory and returns this file absolute path.** @return absolute file path*/public static String assetFilePath(Context context, String assetName) throws IOException {File file = new File(context.getFilesDir(), assetName);if (file.exists() && file.length() > 0) {return file.getAbsolutePath();}try (InputStream is = context.getAssets().open(assetName)) {try (OutputStream os = new FileOutputStream(file)) {byte[] buffer = new byte[4 * 1024];int read;while ((read = is.read(buffer)) != -1) {os.write(buffer, 0, read);}os.flush();}return file.getAbsolutePath();}}
}

然后运行。

如何将pytorch模型部署到安卓相关推荐

  1. TensorFlow与PyTorch模型部署性能比较

    TensorFlow与PyTorch模型部署性能比较 前言 2022了,选 PyTorch 还是 TensorFlow?之前有一种说法:TensorFlow 适合业界,PyTorch 适合学界.这种说 ...

  2. 基于C++的PyTorch模型部署

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 引言 PyTorch作为一款端到端的深度学习框架,在1.0版本之后 ...

  3. PyTorch模型部署:pth转onnx跨框架部署详解+代码

    文章目录 引言 基础概念 onnx:跨框架的模型表达标准 onnxruntime:部署模型的推理引擎 示例代码 0)安装onnx和onnxruntime 1)pytorch模型转onnx模型 2)on ...

  4. pyTorch模型部署--高并发web服务c++移动端ncnn

    文章目录 1 综述 2 以python web服务的形式进行部署 3 c++调用方式进行部署 3.1 torchscript 3.1.1 模型转换 3.1.1.1 torch.jit.trace 3. ...

  5. pytorch模型部署

    1. C++调用python训练的pytorch模型(一)--makefile编写基础 https://blog.csdn.net/xiake001/article/details/84838249 ...

  6. 春节小游戏之图片分类(Pytorch模型部署)

    文章目录 前言 环境 项目结构 前端 图片上传 结果显示 后端 模型部署 路由 业务代码 总结 本博文优先在掘金社区发布! 前言 啥也不说了,来先看效果图 本来我是打算去把昨天在实验平台训练的模型拿到 ...

  7. 学习记录——Pytorch模型移植Android小例子

    提示:注意文章时效性,2022.04.02. 目录 前言 零.使用的环境 一.模型准备 1.导出模型 2.错误记录 2.1要载入完整模型(网络结构+权重参数) 2.2导出的模型文件格式 二.Andro ...

  8. Sanic框架下部署Pytorch模型

    前言 本文针对业余范围的Pytorch模型部署,类似各位想把自己开发的深度学习模型上线web端demo等等. 大家比较熟悉的Python框架主要有flask,使用flask部署上线深度学习模型过程简单 ...

  9. 【Pytorch基础教程33】算法模型部署(MLFlow/ONNX/tf serving)

    内容概况 服务器上训练好模型后,需要将模型部署到线上,接受请求.完成推理并且返回结果. 保存模型结构和参数最简单的是torch.save保存为checkpoint,但一般用于训练时记录过程,训练中断可 ...

最新文章

  1. USACO Section1.3 Combination Lock 解题报告
  2. C++判断exe是32位还是64位
  3. 香港四大天王影帝情况(截止2016)
  4. RocketMQ(二):参数配置大全
  5. ZOJ3785 What day is that day? 快速幂+找规律
  6. mysql支持ASCII_MySQLASCII()函数返回字符的ASCII码值
  7. SpringBoot指南(五)——拦截器、原生组件
  8. 问题四十:对ray tracing圆环图形进行debug(2)——C++,用“笛卡尔”方法解一元四次方程
  9. ccf会议等级划分_Python计算山东新高考选考科目卷面原始成绩为等级成绩
  10. Hutool实现Excel导入导出
  11. 2021年美国联邦法定假日表
  12. Typora崩溃 与 设置备份
  13. 打开Flutter动画的另一种姿势——Flare,android面试题选择题
  14. 曼哈顿距离和欧氏距离
  15. Openvas的安装调试
  16. 计算机专业课改理念,课改新理念
  17. Web3.0 元宇宙 区块链
  18. 使用C++ OpenCV实现椭圆区域检测与Aruco码的生成与检测并估计位姿
  19. python严格使用缩进来体现代码的逻辑从属关系_Python 全国考级二级
  20. 云胶片(云影像)- 占用资源及费用估算

热门文章

  1. 我的架构梦:(二)MyBatis的一级、二级、分布式缓存的应用以及源码分析
  2. 对java封装特性的一些浅薄认识
  3. 1、OpenCV——图片的读、改、显、存操作函数
  4. 用于夜视和监控的图像增强方法
  5. android监控虚拟键盘,android虚拟键盘的监控,显示和隐藏
  6. unity lua C# 这边 new 了一个GameObject 对象并发给Lua那边, 这时C# 这边在通过GC释放掉这个对象;lua 那边会报错;遇到这种问题的解决方案
  7. matlab 版 数独小游戏 GUI界面设计
  8. 服务器主机本地系统6,服务器主机本地系统开机
  9. 兀键和6键怎么判断_你们不会的大π键(高三党,基础较好)
  10. 力扣第185场周赛总结——字节跳动专场