现如今,在边缘设备上运行机器学习/深度学习变得越来越流行,它需要更低的时延。

而从Pytorch 1.3开始,我们就可以使用Pytorch将模型部署到Android或者ios设备中。

Pytorch官方文档中提供两个关于Pytorch-mobile的Demo: Github地址


主要包含了两个APP应用,一个简单的在神经网络领域中的“hello world"项目,另一个就更复杂了一些,有图形识别和语言识别。

我们接下来研究一下Pytorch Mobile的项目流程。


Demo 1 HelloWorldApp

1 模型准备

首先我们需要先训练好的模型保存好。比如我在Pycharm写了经典CNN模型AlexNet。

checkpoints/ 文件夹中保存了 AlexNet.pt,有了这个模型,我们就可以进行Android的部署了。


2 源码分析

2.1 Clone 源码

我们先在本地clone一下github上的源码(吐槽一下git clone的速度,龟速!):

git clone https://github.com/pytorch/android-demo-app.git

然后便得到这个项目。

前提先确保一下Android安装好了SDK和NDK。

2.2 向 Gradle 添加依赖

然后我们会在 app 下的 build.gradle 中发现这样的依赖:


最下面两行中的

  • org.pytorch:pytorch_android : Pytorch Android API 的主要依赖,包含为4个Android abis (armeabi-v7a, arm64-v8a, x86, x86_64) 的 libtorch 本地库。
  • org.pytorch:pytorch_android_torchvision:它是具有将 android.media.imageandroid.graphics.bitmap 转换为 Tensor 的附加库。

2.3 读取图片数据

MainActivity.java文件中,有这么一行:

bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));

Bitmap 为位图,其包括像素以及长、宽、颜色等描述信息。长、宽、像素位数用来描述图片,并可以通过这些信息计算出图片的像素占用内存的大小。

通过 BitmapFactory.decodeStream( ) 这一函数加载图像。

2.4 读取模型

同样在 MainActivity.java文件中,有这么一行:

module = Module.load(assetFilePath(this, "model.pt"));

当然我们需要 import org.pytorch.Module
然后通过Module定义一个对象后使用 Module.load() 来读取模型。

2.5 将图像转化为Tensor

在这么一行中:


org.pytorch.torchvision.TensorImageUtils就是org.pytorch:pytorch_android_torchvision库中的一部分,TensorImageUtils.bitmapToFloat32Tensor 创建一个Tensor类型。

inputTensor 的 大小为 1x3xHxW, 其中 H 和 W 分别为 Bitmap 的高和宽。

2.6 运行模型


将 inputTensor 放到模型中运行,通过 module.forward() 得到一个 outputTensor。

2.7 处理结果

    // getting tensor content as java array of floatsfinal float[] scores = outputTensor.getDataAsFloatArray();// searching for the index with maximum scorefloat maxScore = -Float.MAX_VALUE;int maxScoreIdx = -1;for (int i = 0; i < scores.length; i++) {if (scores[i] > maxScore) {maxScore = scores[i];maxScoreIdx = i;}}String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];// showing className on UITextView textView = findViewById(R.id.text);textView.setText(className);

判断最高分数,并将结果显示到textView中。


Demo2 PytorchDemoApp

这是另一个Demo App,它可以进行图像分类和文字分类。而图像分类就需要利用摄像头。

摄像头API通过使用 org.pytorch.demo.vision.AbstractCameraXActivity 类。
在 AbstractCameraXActivity.java 中的具体源码如下:

  private void setupCameraX() {final TextureView textureView = getCameraPreviewTextureView();// 实现摄像头预览final PreviewConfig previewConfig = new PreviewConfig.Builder().build();final Preview preview = new Preview(previewConfig);preview.setOnPreviewOutputUpdateListener(output -> textureView.setSurfaceTexture(output.getSurfaceTexture()));// 实现数据分析并回调final ImageAnalysisConfig imageAnalysisConfig =new ImageAnalysisConfig.Builder().setTargetResolution(new Size(224, 224)).setCallbackHandler(mBackgroundHandler).setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE).build();final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);imageAnalysis.setAnalyzer((image, rotationDegrees) -> {if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) {return;}final R result = analyzeImage(image, rotationDegrees);if (result != null) {mLastAnalysisResultTime = SystemClock.elapsedRealtime();runOnUiThread(() -> applyToUiAnalyzeImageResult(result));}});CameraX.bindToLifecycle(this, preview, imageAnalysis);}// analyzeImage函数是用来处理摄像头输出void analyzeImage(android.media.Image, int rotationDegrees)

而在 ImageClassificationActivity.java 中的源码如下:

protected AnalysisResult analyzeImage(ImageProxy image, int rotationDegrees) {if (mAnalyzeImageErrorState) {return null;}try {if (mModule == null) {final String moduleFileAbsoluteFilePath = new File(Utils.assetFilePath(this, getModuleAssetName())).getAbsolutePath();// 导入模型mModule = Module.load(moduleFileAbsoluteFilePath);mInputTensorBuffer =Tensor.allocateFloatBuffer(3 * INPUT_TENSOR_WIDTH * INPUT_TENSOR_HEIGHT);mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{1, 3, INPUT_TENSOR_HEIGHT, INPUT_TENSOR_WIDTH});}final long startTime = SystemClock.elapsedRealtime();// 将以YUV420形式的Image类型转化为输入TensorTensorImageUtils.imageYUV420CenterCropToFloatBuffer(image.getImage(), rotationDegrees,INPUT_TENSOR_WIDTH, INPUT_TENSOR_HEIGHT,TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,TensorImageUtils.TORCHVISION_NORM_STD_RGB,mInputTensorBuffer, 0);final long moduleForwardStartTime = SystemClock.elapsedRealtime();// 利用模型进行运算final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;// 从模型中得到预测分数final float[] scores = outputTensor.getDataAsFloatArray();// 找到得分最高的前k个类final int[] ixs = Utils.topK(scores, TOP_K);final String[] topKClassNames = new String[TOP_K];final float[] topKScores = new float[TOP_K];for (int i = 0; i < TOP_K; i++) {final int ix = ixs[i];topKClassNames[i] = Constants.IMAGENET_CLASSES[ix];topKScores[i] = scores[ix];}final long analysisDuration = SystemClock.elapsedRealtime() - startTime;return new AnalysisResult(topKClassNames, topKScores, moduleForwardDuration, analysisDuration);} catch (Exception e) {Log.e(Constants.TAG, "Error during image analysis", e);mAnalyzeImageErrorState = true;runOnUiThread(() -> {if (!isFinishing()) {showErrorDialog(v -> ImageClassificationActivity.this.finish());}});return null;}}

最后将得到的前k个类加载到UI上。

protected void applyToUiAnalyzeImageResult(AnalysisResult result) {mMovingAvgSum += result.moduleForwardDuration;mMovingAvgQueue.add(result.moduleForwardDuration);if (mMovingAvgQueue.size() > MOVING_AVG_PERIOD) {mMovingAvgSum -= mMovingAvgQueue.remove();}for (int i = 0; i < TOP_K; i++) {final ResultRowView rowView = mResultRowViews[i];rowView.nameTextView.setText(result.topNClassNames[i]);rowView.scoreTextView.setText(String.format(Locale.US, SCORES_FORMAT,result.topNScores[i]));rowView.setProgressState(false);}mMsText.setText(String.format(Locale.US, FORMAT_MS, result.moduleForwardDuration));if (mMsText.getVisibility() != View.VISIBLE) {mMsText.setVisibility(View.VISIBLE);}mFpsText.setText(String.format(Locale.US, FORMAT_FPS, (1000.f / result.analysisDuration)));if (mFpsText.getVisibility() != View.VISIBLE) {mFpsText.setVisibility(View.VISIBLE);}if (mMovingAvgQueue.size() == MOVING_AVG_PERIOD) {float avgMs = (float) mMovingAvgSum / MOVING_AVG_PERIOD;mMsAvgText.setText(String.format(Locale.US, FORMAT_AVG_MS, avgMs));if (mMsAvgText.getVisibility() != View.VISIBLE) {mMsAvgText.setVisibility(View.VISIBLE);}}}

最后显示结果如下:

Pytorch Mobile 之Android Demo源码分析相关推荐

  1. Android HandlerThread 源码分析

    HandlerThread 简介: 我们知道Thread线程是一次性消费品,当Thread线程执行完一个耗时的任务之后,线程就会被自动销毁了.如果此时我们又有一 个耗时任务需要执行,我们不得不重新创建 ...

  2. Android ADB 源码分析(三)

    前言 之前分析的两篇文章 Android Adb 源码分析(一) 嵌入式Linux:Android root破解原理(二) 写完之后,都没有写到相关的实现代码,这篇文章写下ADB的通信流程的一些细节 ...

  3. 【Android SDM660源码分析】- 02 - UEFI XBL QcomChargerApp充电流程代码分析

    [Android SDM660源码分析]- 02 - UEFI XBL QcomChargerApp充电流程代码分析 一.加载 UEFI 默认应用程序 1.1 LaunchDefaultBDSApps ...

  4. 【Android SDM660源码分析】- 03 - UEFI XBL GraphicsOutput BMP图片显示流程

    [Android SDM660源码分析]- 03 - UEFI XBL GraphicsOutput BMP图片显示流程 1. GraphicsOutput.h 2. 显示驱动初化 DisplayDx ...

  5. 【Android SDM660源码分析】- 01 - 如何创建 UEFI XBL Protocol DXE_DRIVER 驱动及UEFI_APPLICATION 应用程序

    [Android SDM660源码分析]- 01 - 如何创建 UEFI XBL Protocol DXE_DRIVER 驱动及UEFI_APPLICATION 应用程序 一.创建DXE_DRIVER ...

  6. 【Android SDM660源码分析】- 04 - UEFI ABL LinuxLoader 代码分析

    [Android SDM660源码分析]- 04 - UEFI ABL LinuxLoader 代码分析 1. LinuxLoader.c 系列文章: <[Android SDM660开机流程] ...

  7. Android 音频源码分析——AndroidRecord录音(一)

    Android 音频源码分析--AndroidRecord录音(一) Android 音频源码分析--AndroidRecord录音(二) Android 音频源码分析--AndroidRecord音 ...

  8. Android框架源码分析——从设计模式角度看 Retrofit 核心源码

    Android框架源码分析--从设计模式角度看 Retrofit 核心源码 Retrofit中用到了许多常见的设计模式:代理模式.外观模式.构建者模式等.我们将从这三种设计模式入手,分析 Retrof ...

  9. 人人网官方Android客户端源码分析(1)

    ContentProvider是不同应用程序之间进行数据交换的标准API,ContentProvider以某种Uri的形式对外提供数据,允许其他应用访问或修改数据;其他应用程序使用ContentRes ...

最新文章

  1. 看完这些能控制大脑的寄生虫,你会怀疑人类!
  2. kafka多分区只有一个在消费_kafka多个消费者只有一个消费
  3. GNN笔记:傅里叶变换
  4. 高性能mysql看不懂_高性能mysql笔记1
  5. python中a%b_Python中的a+=b和a=a+b之间的区别是什么?
  6. C++之++操作符重载
  7. pthread线程传递数据回主线程_操作系统4:线程(1)
  8. 大数据如何应用在企业人力资源管理
  9. uds 诊断协议的bootloader开发
  10. Excel数据可视化表盘模板
  11. oracle的临时表
  12. 卸载抖音和微博的一天……
  13. 看好699指纹手机暴露任泉的商业野心
  14. AIDE手机编程初级教程(零基础向) 3.2.1 设计欢迎页(主体)
  15. 【Seen看世界】:像高智商人群看齐
  16. 2018计算机网络MOOC第一章作业1
  17. [LiteratureReview]PointNet Deep Learning on Point Sets for 3D Classification and Segmentation
  18. 使用nvml获取n卡温度
  19. 【词目】:勿谓言之不预也
  20. 【数据挖掘之关联规则实战】关联规则智能推荐算法

热门文章

  1. 小米无线路由器 + u盘
  2. 区块链益智冒险游戏NFT游戏开发
  3. mac 版微信音频设备启动失败
  4. 热门智能手表!OPPO华为苹果强强对决,哪款值得买?
  5. Spring-Data-Redis--解决java.lang.ClassCastException: java.util.LinkedHashMap cannot be cast to xxx
  6. javascript字符串拼接引号转义
  7. PVE安装Openwrt 旁路由
  8. php写动物的属性,写动物作文技巧与方法
  9. 计算机应用怎么样多项选择,【2017年整理】12春学期《计算机应用基础》在线作业第三次多项选择.doc...
  10. 计算机作业秋天的路,18秋西南大学《计算机基础》在线作业.doc