Android配置tensorflow lite

按照官方网站的指导在项目的模块的构建文件build.gradle中配置中增加如下配置:

 implementation 'org.tensorflow:tensorflow-lite:2.7.0'implementation 'org.tensorflow:tensorflow-lite-gpu:2.7.0'implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'
android{aaptOptions {noCompress "tflite"}defaultConfig {ndk {abiFilters 'armeabi-v7a', 'arm64-v8a'}}}

导入模型资源资源

创建将文《关于将Tesorflow的SavedModel模型转换成tflite模型》创建的模型model.tflite,导入到Android项目的assets目录中。

定义模型基本配置类BaseModelConfig

/*** 定义模型的基本配置类*/
public abstract class BaseModelConfig{//每通道处理的字节数var numBytesPerChannel:Int = 0//定义批处理的个数var dimBatchSize:Int = 0//定义像素个数var dimPixelSize:Int = 0//定义图片的宽度var dimImgWidth:Int = 0//定义图片的高度var dimImgHeight:Int = 0//定义平均差var imageMean=0//定义图片的标准差var imageSTD:Float = 0.0F//定义模型的名称lateinit var modelName:Stringconstructor() : super() {setConfigs()}/*** 将像素值转换成ByteBuffer* 增加图片的值*/public abstract fun addImgValue(buffer: ByteBuffer,pixel:Int)/*** 配置*/public abstract fun setConfigs()
}

定义FloatSavedModelConfig类

class FloatSavedModelConfig: BaseModelConfig() {public override fun setConfigs() {modelName="model.tflite"numBytesPerChannel = 4dimBatchSize = 1dimPixelSize = 1dimImgWidth = 28dimImgHeight = 28imageMean = 0imageSTD = 255.0f}override fun addImgValue(imgData: ByteBuffer, pixel: Int) {imgData.putFloat(((pixel  and 0xFF) - imageMean) / imageSTD)}
}

创建配置模型参数的工厂类

object ModelConfigFactory {const val FLOAT_SAVED_MODEL = "float_saved_model"const val QUANT_SAVED_MODEL = "quant_saved_model"fun getModelConfig(model: String): BaseModelConfig? =when(model) {FLOAT_SAVED_MODEL-> FloatSavedModelConfig()QUANT_SAVED_MODEL-> QuantSavedModelConfig()else->null}
}

定义图像分类器

class ImageClassifier {private val TAG = "FashionMNIST"private val RESULTS_TO_SHOW = 3lateinit var mTFLite: Interpreterlateinit var mModelPath:Stringvar mNumBytesPerChannel = 0var mDimBatchSize = 0var mDimPixelSize = 0var mDimImgWidth = 0var mDimImgHeight = 0lateinit var mModelConfig:BaseModelConfig//定义标签检测的二维数组1x10val mLabelProbArray = Array(1) {FloatArray(10)}val labels = arrayListOf("T恤","裤子","帽头衫","连衣裙","外套","凉鞋","衬衫","运动鞋","包","靴子")//定义检测结果保持到优先队列中var mSortedLabels = PriorityQueue<Map.Entry<String, Float>>(RESULTS_TO_SHOW) {o1, o2 -> o1?.value!!.compareTo(o2?.value!!)}/*** 配置参数*/private fun initConfig(config: BaseModelConfig) {mModelConfig = configmNumBytesPerChannel = config.numBytesPerChannelmDimBatchSize = config.dimBatchSizemDimPixelSize = config.dimPixelSizemDimImgWidth = config.dimImgWidthmDimImgHeight = config.dimImgHeightmModelPath = config.modelName}constructor(modelConfig: String, activity: Activity) {// 初始化分类器的相关参数initConfig(ModelConfigFactory.getModelConfig(modelConfig)!!)// 使用配置参数初始化翻译器mTFLite = Interpreter(loadModelFile(activity)!!)}/*** 在Assets中的模型文件映射到内存中* */private fun loadModelFile(activity: Activity): MappedByteBuffer? {val fileDescriptor = activity.assets.openFd(mModelPath)val inputStream = FileInputStream(fileDescriptor.fileDescriptor)val fileChannel = inputStream.channelval startOffset = fileDescriptor.startOffsetval declaredLength = fileDescriptor.declaredLengthreturn fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)}/*** 将图片数据写入到ByteBuffer,加载到内存中* */protected fun convertBitmapToByteBuffer(bitmap: Bitmap?): ByteBuffer {val intValues = IntArray(mDimImgWidth * mDimImgHeight)//调整要处理的图片为28x28var tmp = scaleBitmap(bitmap)//将图片二值化tmp = binarized(tmp)//将二值化的图片加载到内存中tmp.getPixels(intValues,0, tmp.width, 0, 0, tmp.width, tmp.height)val imgData = ByteBuffer.allocateDirect(mNumBytesPerChannel * mDimBatchSize * mDimImgWidth * mDimImgHeight * mDimPixelSize)imgData.order(ByteOrder.nativeOrder())imgData.rewind()//将图片转换成像素实数数据var pixel = 0for (i in 0 until mDimImgWidth) {for (j in 0 until mDimImgHeight) {var value = intValues[pixel++]mModelConfig.addImgValue(imgData, value)}}return imgData}/*** 将图片二值化处理* 转换成二值图像* @param bmp* @return*/fun binarized(bmp: Bitmap): Bitmap {val width = bmp.widthval height = bmp.heightval pixels = IntArray(width * height)//将图片的像素加载到数组中bmp.getPixels(pixels, 0, width, 0, 0, width, height)var alpha = 0xFF shl 24for (i in 0 until height) {for (j in 0 until width) {val grey = pixels[width * i + j]// 分离三原色alpha = grey and -0x1000000 shr 24var red = grey and 0x00FF0000 shr 16var green = grey and 0x0000FF00 shr 8var blue = grey and 0x000000FFval tmp = 180red = if (red > tmp) 255 else 0blue = if (blue > tmp) 255 else 0green = if (green > tmp) 255 else 0pixels[width * i + j] = alpha shl 24 or (red shl 16) or (green shl 8) or blueif (pixels[width * i + j] == -1) {pixels[width * i + j] = -1} else {pixels[width * i + j] = -16777216}}}// 新建图片val newBmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)// 设置图片数据newBmp.setPixels(pixels, 0, width, 0, 0, width, height)return newBmp}/*** 将图片调整到规定的大小28x28*/fun scaleBitmap(bmp: Bitmap?): Bitmap {return Bitmap.createScaledBitmap(bmp!!, mDimImgWidth, mDimImgHeight, true)}/*** 分类处理*/fun doClassify(bitmap: Bitmap?): String? {// 将Bitmap图片转换成TFLite翻译器的可读的ByteBufferval imgData = convertBitmapToByteBuffer(bitmap)// do run interpreterval startTime = System.nanoTime()mTFLite.run(imgData, mLabelProbArray)val endTime = System.nanoTime()Log.i(TAG, String.format("运行识别的时间: %f ms",(endTime - startTime).toFloat() / 1000000.0f))// 生成并返回结果return printTopKLabels()}/*** 打印检测排序在前几位的标签,并作为结果显示在UI界面中。*/fun printTopKLabels(): String? {for (i in 0..9) {mSortedLabels.add(AbstractMap.SimpleEntry(labels[i],mLabelProbArray[0][i]))if (mSortedLabels.size > RESULTS_TO_SHOW) {mSortedLabels.poll()}}val textToShow = StringBuffer()val size = mSortedLabels.sizefor (i in 0 until size) {val label = mSortedLabels.poll()textToShow.insert(0, String.format("\n%s   %4.8f", label.key, label.value))}return textToShow.toString()}}

定义主活动MainActivity

在主活动中,主要处理如下操作:
(1)从图库中选择图片
(2)利用图像分类器检测图片中的内容,判断是FashionMnist数据集的哪种标签
(3)将检测的结果在移动终端的GUI界面中显示出来。

class MainActivity : AppCompatActivity() {private lateinit var binding: ActivityMainBindingval RequestCameraCode = 1val TAG = "FashionMNIST"companion object{var mIsFloat = true}private var bitmap: Bitmap? = nulloverride fun onCreate(savedInstanceState: Bundle?) {super.onCreate(savedInstanceState)//生成视图绑定对象binding = ActivityMainBinding.inflate(layoutInflater)//设置视图的根视图setContentView(binding.root)binding.imageView.setOnClickListener {val intent = Intent()intent.type = "image/*"intent.action = Intent.ACTION_GET_CONTENTstartActivityForResult(intent,RequestCameraCode)}val spinnerAdapter = ArrayAdapter<String>(this,android.R.layout.simple_spinner_item,getChoices())binding.typeSpinner.adapter = spinnerAdapterbinding.typeSpinner.onItemSelectedListener = object : OnItemSelectedListener {override fun onItemSelected(parent: AdapterView<*>?,view: View,position: Int,id: Long) {mIsFloat = position == 0}override fun onNothingSelected(parent: AdapterView<*>?) {}}}override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {super.onActivityResult(requestCode, resultCode, data)if(resultCode == RESULT_OK && requestCode == RequestCameraCode){val uri = data?.datatry{//从图库中读取图片var bitmap = BitmapFactory.decodeStream(contentResolver.openInputStream(uri!!))//在图像视图ImageView中显示图片binding.imageView.setImageBitmap(bitmap)//判断模型类型val config = when(mIsFloat){true->ModelConfigFactory.FLOAT_SAVED_MODELelse->ModelConfigFactory.QUANT_SAVED_MODEL}//根据模型类型创建图像识别器val classifier = ImageClassifier(config,this)//检测并判断图像的类别val result = classifier.doClassify(bitmap)binding.labelTxt.text = resultbinding.tipTxt.visibility = View.GONE}catch(e: FileNotFoundException){Log.d(TAG,"没有找到指定的图像文件")}catch(e: IOException){Log.e(TAG,"初始化图像识别器失败")}}}/*** 返回可用模型的名称*/private fun getChoices()= resources.getStringArray(R.array.model_names)}

参考文献

李锡涵等 《简明的Tensorflow 2》人民邮电出版社 北京 P91-P96

面向Android的开发基于Tensorflow Lite框架深度学习的应用(一)相关推荐

  1. ElasticDL:首个基于 TensorFlow 实现弹性深度学习的开源系统

    9 月 11 日,蚂蚁金服开源了 ElasticDL 项目,据悉这是业界首个基于 TensorFlow 实现弹性深度学习的开源系统. Google Brain 成员 Martin Wicke 此前在公 ...

  2. 基于TensorFlow Serving的深度学习在线预估

    一.前言 随着深度学习在图像.语言.广告点击率预估等各个领域不断发展,很多团队开始探索深度学习技术在业务层面的实践与应用.而在广告CTR预估方面,新模型也是层出不穷: Wide and Deep[1] ...

  3. 特斯拉如何开发基于纯视觉的深度学习系统

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨初光 来源丨糖果Autosar 打造全自动驾驶汽车所需的技术堆栈是什么?公司和研究人员对该问题的 ...

  4. 深度学习入门笔记系列 ( 二 )——基于 tensorflow 的一些深度学习基础知识

    本系列将分为 8 篇 .今天是第二篇 .主要讲讲 TensorFlow 框架的特点和此系列笔记中涉及到的入门概念 . 1.Tensor .Flow .Session .Graphs TensorFlo ...

  5. 基于TensorFlow Lite实现的Android花卉识别应用

    介绍 本教程将在Android设备上使用TensorFlow Lite运行图像识别模型,具体包括: 使用TensorFlow Lite Model Maker训练自定义的图像分类器 利用Android ...

  6. 基于TensorFlow Lite的人声识别在端上的实现

    通过TensorFlow Lite,移动终端.IoT设备可以在端上实现声音识别,这可以应用在安防.医疗监护等领域.来自阿里巴巴闲鱼技术互动组仝辉和上叶通过TensorFlow Lite实现了一套完整的 ...

  7. 面向 Android 软件开发套件(SDK)的 x86 Android* 系统映像许可协议

    英特尔公司面向 Android 软件开发套件(SDK)的 x86 Android* 系统映像的内部评估许可协议 此<内部评估许可协议>(以下简称"协议") 的订立双方为 ...

  8. 英特尔公司面向 Android 软件开发套件(SDK)4.3 的 x86 Android* 系统映像的内部评估许可协议...

    此<内部评估许可协议>(以下简称"协议") 的订立双方为英特尔与贵方(作为开发人员个人或法律实体 - 下文认定为"接收方"). 英特尔应根据< ...

  9. 基于OpenGL ES 的深度学习框架编写

    基于OpenGL ES的深度学习框架编写 背景与工程定位 背景 项目组基于深度学习实现了视频风格化和人像抠图的功能,但这是在PC/服务端上跑的,现在需要移植到移动端,因此需要一个移动端的深度学习的计算 ...

最新文章

  1. 软件测试技术---黑盒测试
  2. 怎么把数据存到MySQL_怎样将Arduino数据直接存储到MySQL
  3. 普通高中生水平就能干好的编程到底是不是高科技?
  4. hdu 6153 A Secret kmp + dp
  5. 物联网学习之路——物联网通信技术:NBIoT
  6. View Merge 在安全控制上的变化,是 BUG 还是增强 ?
  7. 实战快速恢复Exchange 2010误删除的邮箱
  8. sql 视图不排序_算法工程师SQL进阶:神奇的自连接与子查询
  9. 用户如何设置浏览器主页的历史记录和管理加载项
  10. 重心解模糊化matlab,谁能给我个用重心法的MATLAB模糊推理程序
  11. Session的活化与钝化
  12. 广州 人才引进,家属随迁(有小孩),自己房产,外省户口,复核所需资料
  13. 定制 Windows 10 安装程序
  14. 什么样的自学Java网站才适合学习者?
  15. 爬虫-大学教务系统选修课抢课
  16. elasticsearch插件一——-head插件安装详解
  17. R12 AR INVOICE 接口表导入
  18. Android 字体库详解
  19. Autodesk 3dsMax2022 安装说明
  20. 想进入IT行业,该从哪里开始学习

热门文章

  1. JAVA//JAVA基本程序设计架构
  2. 后端程序员要会linux吗,后端程序员必备的Linux基础知识
  3. 使用expdp导出数据
  4. Android之获取移动网络ip
  5. CEYE平台使用简介
  6. form表单提交和ajax表单提交
  7. 数据库简单sql语句(CURD)
  8. webgl_图形变换(旋转,平移,缩放)
  9. java utf8 简繁转换 类库,java 中文繁简体转换工具 opencc4j
  10. 再见 xxl-job!更强大的新一代分布式任务调度框架来了