基于Pytorch Mobile在安卓手机端部署深度估计模型
基于Pytorch Mobile在安卓手机端部署深度估计模型
- 1.选取torch版本的深度估计模型
- 2.修改模型实现代码
- 3.Pytorch生成ptl模型
- 4.安卓端部署代码
- 5.实验配置
- 6.手机端效果展示
1.选取torch版本的深度估计模型
深度估计模型这里选择torch版本的Monodepth,代码地址:https://github.com/OniroAI/MonoDepth-PyTorch,文章链接:https://arxiv.org/abs/1609.03677。
建议在实现本文之前,先跑通torch的官方教程,https://github.com/pytorch/android-demo-app,本文建立在能跑通示例中语义分割模型的基础上。
Monodepth代码中需要使用的部分:
2.修改模型实现代码
整个网络设计中只使用pytorch定义的方法或python原生的语法,不能使用其他第三方框架如Numpy,Opencv。该例中,模型定义在models_resnet.py中,以Resnet18_md为例,需要修改的部分为:
1.代码中使用numpy实现的操作用原生的python库进行代替:
class conv(nn.Module):def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):super(conv, self).__init__()self.kernel_size = kernel_sizeself.conv_base = nn.Conv2d(num_in_layers, num_out_layers, kernel_size=kernel_size, stride=stride)self.normalize = nn.BatchNorm2d(num_out_layers)def forward(self, x):p = int(np.floor((self.kernel_size-1)/2)) #使用Numpy实现需要修改 np.floor ==》 math.floor ,即int(math.floor((self.kernel_size-1)/2))p2d = (p, p, p, p)x = self.conv_base(F.pad(x, p2d))x = self.normalize(x)return F.elu(x, inplace=True)
class maxpool(nn.Module):def __init__(self, kernel_size):super(maxpool, self).__init__()self.kernel_size = kernel_sizedef forward(self, x):p = int(np.floor((self.kernel_size-1) / 2)) #使用Numpy实现需要修改 np.floor ==》 math.floorp2d = (p, p, p, p)return F.max_pool2d(F.pad(x, p2d), self.kernel_size, stride=2)
2.代码中过期的pytorch函数重新实现,因为Pytorch Mobile需要的pytorch版本很新,因此有些旧的实现已经在新版本中被修改:
self.udisp4 = nn.functional.interpolate(self.disp4, scale_factor=2, mode='bilinear', align_corners=True)
修改为:
udisp4 = nn.functional.interpolate(disp4, scale_factor=2., mode='bilinear', align_corners=True)
scale_factor在新版本中只能是浮点数,并且udisp4和disp4并没有在__init__()中定义为属性,因此在这里去掉self,否则Pytorch Mobile编译会报错。
3.修改输出:
return self.disp1, self.disp2, self.disp3, self.disp4
修改为
return disp1
这里输出四个视差是为了在多尺度下做Loss,在迁移到手机上时我们只用选择最大的尺度输出即可。
3.Pytorch生成ptl模型
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.mobile_optimizer import optimize_for_mobile
from utils import get_modelimage = Image.open("O:\\xxx\\0.jpg") #读取一张图片用来测试输出尺寸是否满足预期
preprocess = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input = preprocess(image) #转换成Tensor
model = get_model('resnet18_md', 3, True) #获取模型,模型的定义代码为models_resnet.py。
model.load_state_dict(torch.load("O:\\xxx\\monodepth_resnet18_001.pth")) #读取模型的预训练参数,预训练文件下载地址https://github.com/OniroAI/MonoDepth-PyTorch
input = input.unsqueeze(0)
output = model(input) #使用模型处理一张图片
print(output.shape) #测试尺度是否正常
model.eval()scripted_module = torch.jit.script(model) #模型的转换!!!此处是重点,转换后的ptl模型就可以在安卓端运行
optimized_scripted_module = optimize_for_mobile(scripted_module) #针对移动端的特殊优化可以加快推理速度# Export full jit version model (not compatible with lite interpreter)
scripted_module.save("monodepth.pt")
# Export lite interpreter version model (compatible with lite interpreter)
scripted_module._save_for_lite_interpreter("monodepth_scripted.ptl")
# using optimized lite interpreter model makes inference about 60% faster than the non-optimized lite interpreter model, which is about 6% faster than the non-optimized full jit model
optimized_scripted_module._save_for_lite_interpreter("monodepth_scripted_optimized.ptl") #根据官网描述,这种方式得到的模型推理速度最快比monodepth.pt快60%比monodepth_scripted.ptl快6%
4.安卓端部署代码
package org.pytorch.imagesegmentation;import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.os.SystemClock;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.ProgressBar;import androidx.appcompat.app.AppCompatActivity;import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;public class MainActivity extends AppCompatActivity implements Runnable {private ImageView mImageView;private Button mButtonSegment;private ProgressBar mProgressBar;private Bitmap mBitmap = null;private Module mModule = null;private int mImagename = 0;public static String assetFilePath(Context context, String assetName) throws IOException {File file = new File(context.getFilesDir(), assetName);if (file.exists() && file.length() > 0) {file.delete();}try (InputStream is = context.getAssets().open(assetName)) {try (OutputStream os = new FileOutputStream(file)) {byte[] buffer = new byte[4 * 1024];int read;while ((read = is.read(buffer)) != -1) {os.write(buffer, 0, read);}os.flush();}return file.getAbsolutePath();}}@Overrideprotected void onCreate(Bundle savedInstanceState) {super.onCreate(savedInstanceState);setContentView(R.layout.activity_main);try {mBitmap = BitmapFactory.decodeStream(getAssets().open(mImagename + ".jpg"));} catch (IOException e) {Log.e("DepthEstimation", "Error reading assets", e);finish();}mImageView = findViewById(R.id.imageView);mImageView.setImageBitmap(mBitmap);final Button buttonRestart = findViewById(R.id.restartButton);buttonRestart.setOnClickListener(new View.OnClickListener() {public void onClick(View v) {if(mImagename > 8) mImagename = 0;try {mBitmap = BitmapFactory.decodeStream(getAssets().open(mImagename + ".jpg"));mImagename++;mImageView.setImageBitmap(mBitmap);} catch (IOException e) {Log.e("DepthEstimation", "Error reading assets", e);finish();}}});mButtonSegment = findViewById(R.id.segmentButton);mProgressBar = (ProgressBar) findViewById(R.id.progressBar);mButtonSegment.setOnClickListener(new View.OnClickListener() {public void onClick(View v) {mButtonSegment.setEnabled(false);mProgressBar.setVisibility(ProgressBar.VISIBLE);mButtonSegment.setText(getString(R.string.run_model));Thread thread = new Thread(MainActivity.this);thread.start();}});try {mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "monodepth_scripted_optimized.ptl"));} catch (IOException e) {Log.e("DepthEstimation", "Error reading assets", e);finish();}}@Overridepublic void run() {final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(mBitmap,TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);final float[] inputs = inputTensor.getDataAsFloatArray();final long startTime = SystemClock.elapsedRealtime();IValue outTensors = mModule.forward(IValue.from(inputTensor));final long inferenceTime = SystemClock.elapsedRealtime() - startTime;Log.d("DepthEstimation", "inference time (ms): " + inferenceTime);System.out.println(inferenceTime);final Tensor outputTensor = outTensors.toTensor();final float[] intValues = outputTensor.getDataAsFloatArray();int width = mBitmap.getWidth();int height = mBitmap.getHeight();ArrayList<Float> arralist = new ArrayList<>();for (int i = 0 ; i< intValues.length ; i++){arralist.add(intValues[i]);}final Bitmap bitmap = arrayFlotToBitmap(arralist, width, height);runOnUiThread(new Runnable() {@Overridepublic void run() {mImageView.setImageBitmap(bitmap);mButtonSegment.setEnabled(true);mButtonSegment.setText(getString(R.string.segment));mProgressBar.setVisibility(ProgressBar.INVISIBLE);}});}private static Bitmap arrayFlotToBitmap(List<Float> floatArray, int width, int height){byte alpha = (byte) 255 ;Bitmap bmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888) ;ByteBuffer byteBuffer = ByteBuffer.allocate(width*height*4*3) ;float Maximum = Collections.max(floatArray);float minmum = Collections.min(floatArray);float delta = Maximum - minmum ;int i = 0 ;for (float value : floatArray){byte temValue = (byte) ((byte) ((((value-minmum)/delta)*255)));byteBuffer.put(4*i, temValue) ;byteBuffer.put(4*i+1, temValue) ;byteBuffer.put(4*i+2, temValue) ;byteBuffer.put(4*i+3, alpha) ;i++ ;}bmp.copyPixelsFromBuffer(byteBuffer) ;return bmp ;}}
实现参考了pytorch官方的语义分割实例,https://github.com/pytorch/android-demo-app/tree/master/ImageSegmentation。其中有两个主要修改:
public static String assetFilePath(Context context, String assetName) throws IOException {File file = new File(context.getFilesDir(), assetName);if (file.exists() && file.length() > 0) {file.delete();//return file.getAbsolutePath(); //这行需要删除,因为只要读取一次模型就会在当前系统中将该模型暂存,//如果你第一次加载的模型有问题,之后没有都会加载有问题的模型,应该将模型缓存的模型删除。 }try (InputStream is = context.getAssets().open(assetName)) {try (OutputStream os = new FileOutputStream(file)) {byte[] buffer = new byte[4 * 1024];int read;while ((read = is.read(buffer)) != -1) {os.write(buffer, 0, read);}os.flush();}return file.getAbsolutePath();}}
private static Bitmap arrayFlotToBitmap(List<Float> floatArray, int width, int height){byte alpha = (byte) 255 ;Bitmap bmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888) ;ByteBuffer byteBuffer = ByteBuffer.allocate(width*height*4*3) ;float Maximum = Collections.max(floatArray);float minmum = Collections.min(floatArray);float delta = Maximum - minmum ;int i = 0 ;for (float value : floatArray){byte temValue = (byte) ((byte) ((((value-minmum)/delta)*255)));byteBuffer.put(4*i, temValue) ;byteBuffer.put(4*i+1, temValue) ;byteBuffer.put(4*i+2, temValue) ;byteBuffer.put(4*i+3, alpha) ;i++ ;}bmp.copyPixelsFromBuffer(byteBuffer) ;return bmp ;}
这是一个将float数组转换为Bitmap的函数,截取自pytorch官方的issue里,https://github.com/pytorch/pytorch/issues/30655
5.实验配置
- pytorch=1.10.0
- android studio Arctic Fox(2020.3.1 Patch 4)
- 手机型号:VivoX60tPro+(android 11) (只要满足build.gradle里要求的最低安卓版本应该都可以跑通,用小米9也成功部署)
6.手机端效果展示
看到最后如果还是觉得不够详细的话,可以回复我,考虑在B站上传完整的部署视频。
基于Pytorch Mobile在安卓手机端部署深度估计模型相关推荐
- 手机office办公——微软推出安卓手机端Office Mobile应用
微软于4月19日在中国北京首次发布全新的安卓手机端Office Mobile应用.微软此次将 Word.Excel.PowerPoint三者完美合一,为中国的消费者带来完整的移动办公体验. 微软大中华 ...
- pdf文件如何在安卓手机端不用下载在线预览
由于H5手机端页面,苹果ios手机端支持在线预览,而安卓手机端不行. 解决方案: 使用pdf.js插件. 官网地址:https://mozilla.github.io/pdf.js/ 第一步:下载整个 ...
- 【uni-app】什么是uni-app?如何进行开发?如何连接微信开发者工具与安卓手机端?
文章目录 1 什么是 uni-app 2 如何使用与开发 3 用 HBuilderX 创建项目 4 如何连接微信开发者工具 5 如何连接安卓手机端 1 什么是 uni-app (此部分内容出自 uni ...
- 修改电量android,安卓手机端修改电池电量图标的教程
安卓手机端是可以给电池电量图标进行修改的,恐怕大家都不知道吧,不同的系统基本都有比一样的电量图标,这就导致了不是每一个人都喜欢同一个图标,现在我就来为大家讲解如何在手机端修改电量图标的教程. 第一步: ...
- 微信功能版(可用于电脑、安卓手机端)微信电脑版 使用说明
好外号外!!!!本文是微信功能版(可用于电脑.安卓手机端)使用说明[转自网络,本文仅限于个人交流] 经过开发,此版本能够自行设置当前地址位置,可以随意改变地点,方便你有针对性的找任何地区的周边好友,交 ...
- android毕业设计——基于Android+Java+Python的手机端办公自动化OA系统设计与实现(毕业论文+程序源码)——办公自动化OA系统
基于Android+Java+Python的手机端办公自动化OA系统设计与实现(毕业论文+程序源码) 大家好,今天给大家介绍基于Android+Java+Python的手机端办公自动化OA系统设计与实 ...
- 安卓手机端网页,开启输入法时页面内容被压缩的解决方法
在安卓手机端的网页中,打开输入框会使页面的整体内容压缩(可能因为我使用了百分比的布局) 而在iphone浏览网页时则不会出现这种问题. 对于这种情况需要添加部分js代码来防止页面的压缩. var h1 ...
- 安卓手机端SSH远程连接云服务器
安卓手机端SSH远程连接云服务器 使用软件:JuiceSSH(在手机应用市场搜索即可安装) 安装完成后,直接运行 选择快速创建连接 输入目的主机IP和用户名进行远程连接 这里没有输入密码,肯定认证失败 ...
- Wifi热点创建工具配合电脑与安卓手机端实现秒传文件的方法
有什么方法能在有Wifi网或无Wifi网的环境,但没有数据线如何用笔记本或带无线网卡的台式电脑更方便的给安卓手机传送文件呢?我想到之前电脑之间都是拿飞鸽传书传文件,有没有安卓版的?一找,还真有,哈~~ ...
最新文章
- 华为手机像素密度排行_2020拍照手机十大排行:华为128分破纪录,苹果无一上榜...
- 不做“浮冰”,深挖AI技术和场景
- python数据结构与算法(11)
- SVN服务器从Windows迁移到Linux
- 浅谈Android系统进程间通信(IPC)机制Binder中的Server和Client获得Service Manager接口之路
- 先序,中序,后序线索二叉树
- 前端笔记-jquery
- Swift傻傻分不清楚系列(一)常量与变量
- git之you can‘t overwrite the remote branch问题解决
- 上机练习 实现消费单的打印 需求不明确要补充
- 字符串 不是有效的 AllXsd 值。
- Microsoft PetShop 3.0 设计与实现 分析报告―――数据访问层
- 【Kafka】Kafka Record for partition topic at offset xx is invalid, cause: Record corrupt
- 荣耀鸿蒙系统内测,官宣!荣耀 Magic UI 4.0 与 EMUI 11 同步内测:后续支持升级为鸿蒙操作系统...
- 18位身份证验证(Java)
- Mac上qmc0文件转码为mp3
- Pspice for TI取消默认打开方式
- 一个UE4崩溃问题以及解决方案
- 移动应用的引导模式设计
- 外贸客户邮箱用什么?外贸哪个邮箱好?
热门文章
- java 登录验证码_java实现登录验证码
- 一个“精神病”人的世界观——我看完了,然后陷入深深的不安中……
- 开源BI平台软件特性对比
- Python——读取xlsx格式的Excel表格
- 基于ArduinoNano的LED点阵时钟探索(1)四合一MAX7219+DS3231
- 后端常用数据库的使用MongoDB, Redis, Mysql
- 天龙八部为什么得到角色信息失败 服务器繁忙《302》,每日最大化获取活跃值的方法分享:卡到499点是关键...
- MATLAB中的impixel函数——获取图像像素值
- FileZilla搭建FTP服务器图解教程,并允许外网访问NAT内网
- 【C++ 科学计算】矩阵元素绝对值小于设定值时,元素值变为零