Pytorch Mobile 之Android Demo源码分析
现如今,在边缘设备上运行机器学习/深度学习变得越来越流行,它需要更低的时延。
而从Pytorch 1.3开始,我们就可以使用Pytorch将模型部署到Android或者ios设备中。
Pytorch官方文档中提供两个关于Pytorch-mobile的Demo: Github地址
主要包含了两个APP应用,一个简单的在神经网络领域中的“hello world"项目,另一个就更复杂了一些,有图形识别和语言识别。
我们接下来研究一下Pytorch Mobile的项目流程。
Demo 1 HelloWorldApp
1 模型准备
首先我们需要先训练好的模型保存好。比如我在Pycharm写了经典CNN模型AlexNet。
在 checkpoints/
文件夹中保存了 AlexNet.pt
,有了这个模型,我们就可以进行Android的部署了。
2 源码分析
2.1 Clone 源码
我们先在本地clone一下github上的源码(吐槽一下git clone的速度,龟速!):
git clone https://github.com/pytorch/android-demo-app.git
然后便得到这个项目。
前提先确保一下Android安装好了SDK和NDK。
2.2 向 Gradle 添加依赖
然后我们会在 app
下的 build.gradle
中发现这样的依赖:
最下面两行中的
org.pytorch:pytorch_android
: Pytorch Android API 的主要依赖,包含为4个Android abis (armeabi-v7a, arm64-v8a, x86, x86_64) 的 libtorch 本地库。org.pytorch:pytorch_android_torchvision
:它是具有将android.media.image
和android.graphics.bitmap
转换为 Tensor 的附加库。
2.3 读取图片数据
在 MainActivity.java
文件中,有这么一行:
bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
Bitmap 为位图,其包括像素以及长、宽、颜色等描述信息。长、宽、像素位数用来描述图片,并可以通过这些信息计算出图片的像素占用内存的大小。
通过 BitmapFactory.decodeStream( )
这一函数加载图像。
2.4 读取模型
同样在 MainActivity.java
文件中,有这么一行:
module = Module.load(assetFilePath(this, "model.pt"));
当然我们需要 import org.pytorch.Module
然后通过Module定义一个对象后使用 Module.load()
来读取模型。
2.5 将图像转化为Tensor
在这么一行中:
org.pytorch.torchvision.TensorImageUtils
就是org.pytorch:pytorch_android_torchvision
库中的一部分,TensorImageUtils.bitmapToFloat32Tensor
创建一个Tensor类型。
inputTensor 的 大小为 1x3xHxW
, 其中 H 和 W 分别为 Bitmap 的高和宽。
2.6 运行模型
将 inputTensor 放到模型中运行,通过 module.forward()
得到一个 outputTensor。
2.7 处理结果
// getting tensor content as java array of floatsfinal float[] scores = outputTensor.getDataAsFloatArray();// searching for the index with maximum scorefloat maxScore = -Float.MAX_VALUE;int maxScoreIdx = -1;for (int i = 0; i < scores.length; i++) {if (scores[i] > maxScore) {maxScore = scores[i];maxScoreIdx = i;}}String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];// showing className on UITextView textView = findViewById(R.id.text);textView.setText(className);
判断最高分数,并将结果显示到textView中。
Demo2 PytorchDemoApp
这是另一个Demo App,它可以进行图像分类和文字分类。而图像分类就需要利用摄像头。
摄像头API通过使用 org.pytorch.demo.vision.AbstractCameraXActivity
类。
在 AbstractCameraXActivity.java 中的具体源码如下:
private void setupCameraX() {final TextureView textureView = getCameraPreviewTextureView();// 实现摄像头预览final PreviewConfig previewConfig = new PreviewConfig.Builder().build();final Preview preview = new Preview(previewConfig);preview.setOnPreviewOutputUpdateListener(output -> textureView.setSurfaceTexture(output.getSurfaceTexture()));// 实现数据分析并回调final ImageAnalysisConfig imageAnalysisConfig =new ImageAnalysisConfig.Builder().setTargetResolution(new Size(224, 224)).setCallbackHandler(mBackgroundHandler).setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE).build();final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);imageAnalysis.setAnalyzer((image, rotationDegrees) -> {if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) {return;}final R result = analyzeImage(image, rotationDegrees);if (result != null) {mLastAnalysisResultTime = SystemClock.elapsedRealtime();runOnUiThread(() -> applyToUiAnalyzeImageResult(result));}});CameraX.bindToLifecycle(this, preview, imageAnalysis);}// analyzeImage函数是用来处理摄像头输出void analyzeImage(android.media.Image, int rotationDegrees)
而在 ImageClassificationActivity.java 中的源码如下:
protected AnalysisResult analyzeImage(ImageProxy image, int rotationDegrees) {if (mAnalyzeImageErrorState) {return null;}try {if (mModule == null) {final String moduleFileAbsoluteFilePath = new File(Utils.assetFilePath(this, getModuleAssetName())).getAbsolutePath();// 导入模型mModule = Module.load(moduleFileAbsoluteFilePath);mInputTensorBuffer =Tensor.allocateFloatBuffer(3 * INPUT_TENSOR_WIDTH * INPUT_TENSOR_HEIGHT);mInputTensor = Tensor.fromBlob(mInputTensorBuffer, new long[]{1, 3, INPUT_TENSOR_HEIGHT, INPUT_TENSOR_WIDTH});}final long startTime = SystemClock.elapsedRealtime();// 将以YUV420形式的Image类型转化为输入TensorTensorImageUtils.imageYUV420CenterCropToFloatBuffer(image.getImage(), rotationDegrees,INPUT_TENSOR_WIDTH, INPUT_TENSOR_HEIGHT,TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,TensorImageUtils.TORCHVISION_NORM_STD_RGB,mInputTensorBuffer, 0);final long moduleForwardStartTime = SystemClock.elapsedRealtime();// 利用模型进行运算final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;// 从模型中得到预测分数final float[] scores = outputTensor.getDataAsFloatArray();// 找到得分最高的前k个类final int[] ixs = Utils.topK(scores, TOP_K);final String[] topKClassNames = new String[TOP_K];final float[] topKScores = new float[TOP_K];for (int i = 0; i < TOP_K; i++) {final int ix = ixs[i];topKClassNames[i] = Constants.IMAGENET_CLASSES[ix];topKScores[i] = scores[ix];}final long analysisDuration = SystemClock.elapsedRealtime() - startTime;return new AnalysisResult(topKClassNames, topKScores, moduleForwardDuration, analysisDuration);} catch (Exception e) {Log.e(Constants.TAG, "Error during image analysis", e);mAnalyzeImageErrorState = true;runOnUiThread(() -> {if (!isFinishing()) {showErrorDialog(v -> ImageClassificationActivity.this.finish());}});return null;}}
最后将得到的前k个类加载到UI上。
protected void applyToUiAnalyzeImageResult(AnalysisResult result) {mMovingAvgSum += result.moduleForwardDuration;mMovingAvgQueue.add(result.moduleForwardDuration);if (mMovingAvgQueue.size() > MOVING_AVG_PERIOD) {mMovingAvgSum -= mMovingAvgQueue.remove();}for (int i = 0; i < TOP_K; i++) {final ResultRowView rowView = mResultRowViews[i];rowView.nameTextView.setText(result.topNClassNames[i]);rowView.scoreTextView.setText(String.format(Locale.US, SCORES_FORMAT,result.topNScores[i]));rowView.setProgressState(false);}mMsText.setText(String.format(Locale.US, FORMAT_MS, result.moduleForwardDuration));if (mMsText.getVisibility() != View.VISIBLE) {mMsText.setVisibility(View.VISIBLE);}mFpsText.setText(String.format(Locale.US, FORMAT_FPS, (1000.f / result.analysisDuration)));if (mFpsText.getVisibility() != View.VISIBLE) {mFpsText.setVisibility(View.VISIBLE);}if (mMovingAvgQueue.size() == MOVING_AVG_PERIOD) {float avgMs = (float) mMovingAvgSum / MOVING_AVG_PERIOD;mMsAvgText.setText(String.format(Locale.US, FORMAT_AVG_MS, avgMs));if (mMsAvgText.getVisibility() != View.VISIBLE) {mMsAvgText.setVisibility(View.VISIBLE);}}}
最后显示结果如下:
Pytorch Mobile 之Android Demo源码分析相关推荐
- Android HandlerThread 源码分析
HandlerThread 简介: 我们知道Thread线程是一次性消费品,当Thread线程执行完一个耗时的任务之后,线程就会被自动销毁了.如果此时我们又有一 个耗时任务需要执行,我们不得不重新创建 ...
- Android ADB 源码分析(三)
前言 之前分析的两篇文章 Android Adb 源码分析(一) 嵌入式Linux:Android root破解原理(二) 写完之后,都没有写到相关的实现代码,这篇文章写下ADB的通信流程的一些细节 ...
- 【Android SDM660源码分析】- 02 - UEFI XBL QcomChargerApp充电流程代码分析
[Android SDM660源码分析]- 02 - UEFI XBL QcomChargerApp充电流程代码分析 一.加载 UEFI 默认应用程序 1.1 LaunchDefaultBDSApps ...
- 【Android SDM660源码分析】- 03 - UEFI XBL GraphicsOutput BMP图片显示流程
[Android SDM660源码分析]- 03 - UEFI XBL GraphicsOutput BMP图片显示流程 1. GraphicsOutput.h 2. 显示驱动初化 DisplayDx ...
- 【Android SDM660源码分析】- 01 - 如何创建 UEFI XBL Protocol DXE_DRIVER 驱动及UEFI_APPLICATION 应用程序
[Android SDM660源码分析]- 01 - 如何创建 UEFI XBL Protocol DXE_DRIVER 驱动及UEFI_APPLICATION 应用程序 一.创建DXE_DRIVER ...
- 【Android SDM660源码分析】- 04 - UEFI ABL LinuxLoader 代码分析
[Android SDM660源码分析]- 04 - UEFI ABL LinuxLoader 代码分析 1. LinuxLoader.c 系列文章: <[Android SDM660开机流程] ...
- Android 音频源码分析——AndroidRecord录音(一)
Android 音频源码分析--AndroidRecord录音(一) Android 音频源码分析--AndroidRecord录音(二) Android 音频源码分析--AndroidRecord音 ...
- Android框架源码分析——从设计模式角度看 Retrofit 核心源码
Android框架源码分析--从设计模式角度看 Retrofit 核心源码 Retrofit中用到了许多常见的设计模式:代理模式.外观模式.构建者模式等.我们将从这三种设计模式入手,分析 Retrof ...
- 人人网官方Android客户端源码分析(1)
ContentProvider是不同应用程序之间进行数据交换的标准API,ContentProvider以某种Uri的形式对外提供数据,允许其他应用访问或修改数据;其他应用程序使用ContentRes ...
最新文章
- 看完这些能控制大脑的寄生虫,你会怀疑人类!
- kafka多分区只有一个在消费_kafka多个消费者只有一个消费
- GNN笔记:傅里叶变换
- 高性能mysql看不懂_高性能mysql笔记1
- python中a%b_Python中的a+=b和a=a+b之间的区别是什么?
- C++之++操作符重载
- pthread线程传递数据回主线程_操作系统4:线程(1)
- 大数据如何应用在企业人力资源管理
- uds 诊断协议的bootloader开发
- Excel数据可视化表盘模板
- oracle的临时表
- 卸载抖音和微博的一天……
- 看好699指纹手机暴露任泉的商业野心
- AIDE手机编程初级教程(零基础向) 3.2.1 设计欢迎页(主体)
- 【Seen看世界】:像高智商人群看齐
- 2018计算机网络MOOC第一章作业1
- [LiteratureReview]PointNet Deep Learning on Point Sets for 3D Classification and Segmentation
- 使用nvml获取n卡温度
- 【词目】:勿谓言之不预也
- 【数据挖掘之关联规则实战】关联规则智能推荐算法
热门文章
- 小米无线路由器 + u盘
- 区块链益智冒险游戏NFT游戏开发
- mac 版微信音频设备启动失败
- 热门智能手表!OPPO华为苹果强强对决,哪款值得买?
- Spring-Data-Redis--解决java.lang.ClassCastException: java.util.LinkedHashMap cannot be cast to xxx
- javascript字符串拼接引号转义
- PVE安装Openwrt 旁路由
- php写动物的属性,写动物作文技巧与方法
- 计算机应用怎么样多项选择,【2017年整理】12春学期《计算机应用基础》在线作业第三次多项选择.doc...
- 计算机作业秋天的路,18秋西南大学《计算机基础》在线作业.doc