文章目录

  • 1 手写体基础知识
    • 1.1 探索MINIST数据集
    • 1.2 CNN基本介绍
    • 1.3 基于TensorFlow 的手写体识别
  • 2 在Android实现手写体识别
    • 2.1 加载模型
    • 2.2 自定义写画View
    • 2.3 将bitmap转成网络需要的格式
    • 2.4 识别结果的输出
  • 3 总结

今天有个网友在手把手教你在Android上搭建tensorflow Lite2.0这篇文章下评论

求问如何进行一个图像的输入和数组的输出?

我想这也是很多初学者的痛点,很多入门同学都没有完整从模型建立,训练,到转换成TensorFlowLite,并在Android中实际的用。

于是我就把我之前写的demo给了他,想想还是抽空把这个demo写成文章,希望能够给帮助到更多的入门的同学。

虽然基于TensorFlow 实现手写体的文章,一抓一大把,但是我还是有必要啰嗦下,毕竟它是很好的入门人工智能的实例。

我不关注的手写体识别算法的细节,关注整个从模型到应用的整个过程,想对算法了解的,请自行学习。

有兴趣的同学可以关注下我的系列博客人工智能系列(更新中……),自己也在学习这方面的知识,一起学习和交流。

1 手写体基础知识

1.1 探索MINIST数据集

采用的MNIST数据集,它来自美国国家标准与技术研究所,National Institute of Standards and Technology(NIST)。 训练集 (training set) 由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set)也是同样比例的手写数字数据。

数据集中每张图片是什么样的呢?

就张这样子:

通过下面代码获得:

# Plot ad hoc mnist instances
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt# load (downloaded if needed) the MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# plot 4 images as gray scale
plt.subplot(221)
plt.imshow(X_train[0], cmap=plt.get_cmap("gray"))
plt.subplot(222)
plt.imshow(X_train[1], cmap=plt.get_cmap("gray"))
plt.subplot(223)
plt.imshow(X_train[2], cmap=plt.get_cmap("gray"))
plt.subplot(224)
plt.imshow(X_train[3], cmap=plt.get_cmap("gray"))
# show the plot
plt.show()

但是实际上存储是什么呢?

你可以发现这是一个0字,存储是0这张图片的RGB的值,凡是值为零的地方都是黑色,非零的地方都是不同灰阶。这就是一张图片灰阶RGB矩阵。

1.2 CNN基本介绍

本次采用手写体识别算法就是CNN(卷积神经网络),在计算机视觉中应用比较广泛。

最为经典的CNN手写体识别图,描述了手写体识别的整个过程,具体的细节就不讲了,有机会写一篇这个算法细节的文章,但是本文神经网络模型结构如下:

1.3 基于TensorFlow 的手写体识别

采用TensorFlow 中Keras接口,比较适合新手使用。让你感觉创建神经网络模型就像是搭积木一样。

代码如下,留意注释。

import numpy
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.python.keras.utils import np_utils
import tensorflow as tf
import pathlib# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][channels][width][height]
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')# normalize inputs from 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255print(X_train.shape)
# one hot encode outputs
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
print(X_train[0])num_classes = y_test.shape[1]def baseline_model():# create modelmodel = Sequential()model.add(Conv2D(32, kernel_size=(5, 5),input_shape=(28, 28, 1),//采用单通道的图片activation='relu'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Dropout(0.2))model.add(Flatten())model.add(Dense(128, activation='relu'))model.add(Dense(num_classes, activation='softmax'))# Compile modelmodel.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer='adam',metrics=['accuracy'])return modelmodel = baseline_model()
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2)# Final evaluation of the model
scores = model.evaluate(X_test, y_test, verbose=0)
print("CNN Error: %.2f%%" % (100 - scores[1] * 100))# 上面升级网络训练的过程
# 下面需要将其转换tensorflow Lite模型,便于在Android中使用。
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()tflite_model_file = pathlib.Path('saved_model/model.tflite')
tflite_model_file.write_bytes(tflite_model)

2 在Android实现手写体识别

如果你不知道如何配置Android的环境,请参考手把手教你在Android上搭建tensorflow Lite2.0

2.1 加载模型

将训练好的TensorFlow Lite 文件放在Android的asset文件夹下。

public class TF {private static Context mContext;Interpreter mInterpreter;private static TF instance;public static TF newInstance(Context context) {mContext = context;if (instance == null) {instance = new TF();}return instance;}Interpreter get() {try {if (Objects.isNull(mInterpreter))mInterpreter = new Interpreter(loadModelFile(mContext));} catch (IOException e) {e.printStackTrace();}return mInterpreter;}// 获取文件private MappedByteBuffer loadModelFile(Context context) throws IOException {AssetFileDescriptor fileDescriptor = context.getAssets().openFd("model.tflite");FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());FileChannel fileChannel = inputStream.getChannel();long startOffset = fileDescriptor.getStartOffset();long declaredLength = fileDescriptor.getDeclaredLength();return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);}
}

2.2 自定义写画View

public class HandWriteView extends View {Path mPath = new Path();Paint mPaint;Bitmap mBitmap;Canvas mCanvas;public HandWriteView(Context context) {super(context);init();}public HandWriteView(Context context, AttributeSet attrs) {super(context, attrs);init();}void init() {mPaint = new Paint();mPaint.setColor(Color.WHITE);mPaint.setStyle(Paint.Style.STROKE);mPaint.setStrokeJoin(Paint.Join.ROUND);mPaint.setStrokeCap(Paint.Cap.ROUND);mPaint.setStrokeWidth(30);}@Overrideprotected void onDraw(Canvas canvas) {super.onDraw(canvas);mBitmap = Bitmap.createBitmap(getWidth(), getHeight(), Bitmap.Config.ARGB_8888);mCanvas = new Canvas(mBitmap);mCanvas.drawColor(Color.BLACK);canvas.drawPath(mPath, mPaint);mCanvas.drawPath(mPath, mPaint);}@Overridepublic boolean onTouchEvent(MotionEvent event) {switch (event.getAction()) {case MotionEvent.ACTION_DOWN:mPath.moveTo(event.getX(), event.getY());break;case MotionEvent.ACTION_MOVE:mPath.lineTo(event.getX(), event.getY());break;case MotionEvent.ACTION_UP:case MotionEvent.ACTION_CANCEL:break;}postInvalidate();return true;}Bitmap getBitmap() {mPath.reset();return mBitmap;}
}

2.3 将bitmap转成网络需要的格式

因为数据集中的数据都是28 * 28 * 3的,28为图片的宽和高,3为R,G,B三个通道,所以在输入到网络之前,我们需要将bitmap转成网络需要的格式。

private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {int inputShape[] = TF.newInstance(getApplicationContext()).get().getInputTensor(0).shape();int inputImageWidth = inputShape[1];int inputImageHeight = inputShape[2];Bitmap bs = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true);mImageView.setImageBitmap(bs);ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4 * inputImageHeight * inputImageWidth);byteBuffer.order(ByteOrder.nativeOrder());int[] pixels = new int[inputImageWidth * inputImageHeight];bs.getPixels(pixels, 0, bs.getWidth(), 0, 0, bs.getWidth(), bs.getHeight());for (int pixelValue : pixels) {int r = (pixelValue >> 16 & 0xFF);int g = (pixelValue >> 8 & 0xFF);int b = (pixelValue & 0xFF);// Convert RGB to grayscale and normalize pixel value to [0..1]float normalizedPixelValue = (r + g + b) / 3.0f / 255.0f;byteBuffer.putFloat(normalizedPixelValue);}return byteBuffer;}

2.4 识别结果的输出

识别的结果是根据0-9的概率进行判断,概率最大的就是识别的结果。

float[][] input = new float[1][10];
TF.newInstance(getApplicationContext()).get().run(convertBitmapToByteBuffer(mHandWriteView.getBitmap()), input);
int result = -1;
float value = 0f;
for (int j = 0; j < 10; j++) {if (input[0][j] > value) {value = input[0][j];result = j;}
Log.i("TAG", "result: " + j + " " + input[0][j]);
}
if (input[0][result] < 0.2f) {mTextView.setText("结果为:未识别");
} else {mTextView.setText("结果为:" + result);
}

识别结果:

若有需要,请自行点击demo下载。

3 总结

开发一个人工智能APP的主要流程就这么多,关键还是在于算法,要想得到更为精准的模型,除了要采用更好的模型之外,还需要对数据进行旋转,增强或者白质化,来提高数据的多样性。

欢迎大家一起交流!!!!

用TensorFlow Lite 写个手写体识别 APP相关推荐

  1. 深度学习笔记:Tensorflow手写mnist数字识别

    文章出处:深度学习笔记11:利用numpy搭建一个卷积神经网络 免费视频课程:Hellobi Live | 从数据分析师到机器学习(深度学习)工程师的进阶之路 上一讲笔者和大家一起学习了如何使用 Te ...

  2. 【深度学习框架】|TensorFlow|完成一个手写体识别任务

  3. Android TensorFlow Lite 深度学习识别手写数字mnist demo

    一. TensorFlow Lite TensorFlow Lite介绍.jpeg TensorFlow Lite特性.jpeg TensorFlow Lite使用.jpeg TensorFlow L ...

  4. TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:1~5

    原文:Mobile Deep Learning with TensorFlow Lite, ML Kit and Flutter 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自[Apach ...

  5. flutter 人脸识别_使用flutter和tensorflow lite进行人脸识别认证

    flutter 人脸识别 The growth of processing power in devices and Machine learning allows us to create new ...

  6. TensorFlow Lite(实战系列一):TFLite Android 迁移训练构建自己的图像识别APP

    摘要 人工智能想要从实验室走向大众,一个必不可少的途径就是朝着智能终端.嵌入式产品等边缘设备发展.谷歌基于TFMobile推出了TFLite,我们只需要把训练好的模型按照一定规则转换成拥有.tflit ...

  7. python识别手写文字_Python3实现简单可学习的手写体识别(实例讲解)

    1.前言 版本:Python3.6.1 + PyQt5 + SQL Server 2012 以前一直觉得,机器学习.手写体识别这种程序都是很高大上很难的,直到偶然看到了这个视频,听了老师讲的思路后,瞬 ...

  8. 基于TensorFlow Lite实现的Android花卉识别应用

    介绍 本教程将在Android设备上使用TensorFlow Lite运行图像识别模型,具体包括: 使用TensorFlow Lite Model Maker训练自定义的图像分类器 利用Android ...

  9. 计算机视觉(十)——Tensorflow对Mnist手写体数据集做手写体识别

    博文主要内容 分析Mnist手写体数据集 实现手写体识别的原理和代码实现 分析Mnist数据集中一些歧义数据 实验中遇到的一些问题 分析Mnist手写体数据集 MNIST 数据集来自美国国家标准与技术 ...

  10. 基于TensorFlow卷积神经网络的手写体数字识别

    一.卷积神经网络(CNN) 二.LeNet 三.代码 1.Mnist手写体训练并测试 2.可视化 四.数据集分析 五.结果分析 1.准确率 2.可视化测试 一.卷积神经网络(CNN) 参考:https ...

最新文章

  1. Redis 笔记(04)— list类型(作为消息队列使用、在列表头部添加元素、尾部删除元素、查看列表长度、遍历指定列表区间元素、获取指定区间列表元素、阻塞式获取列表元素)
  2. mybatis动态sql中的trim标签的使用
  3. php.ini权限,php开启与关闭错误提示适用于没有修改php.ini的权限_PHP
  4. ds1302模块 树莓派_(16)给树莓派B+ 安装一个实时时钟芯片DS1302
  5. wxWidgets:wxLogFormatter类用法
  6. openwrt 替换Dropbear by openssh-server
  7. android生成aar无效,android studio生成aar包并在其他工程引用aar包的方法
  8. 你身边有没有白天上班,晚上打零工送外卖、跑滴滴、做代驾的朋友?你怎么看?
  9. 修复十一个重要高危漏洞 苹果致谢滴滴美研
  10. Hive(一)——基础操作
  11. 西华师范大学计算机专业保研资格,西华师范大学计算机学院 计算机应用技术保研条件...
  12. EPLAN学习笔记——常用操作步骤
  13. 化繁为简、敏捷迭代,轻量化小程序时代已然到来
  14. 案例——蚂蚁金服初探,唯一的金融互联网生态...
  15. Android连接WiFi再探索
  16. 亚马逊中国站获取全部商品分类
  17. 大数据主要学习什么?
  18. c语言循环结构程序设计教学,高级C语言循环结构程序设计教学教材演示幻灯片.ppt...
  19. iOS上二维码和一维码识别系列二
  20. 如何查询MOTO手机IMEI码和MSN码

热门文章

  1. 抖音死亡计算机在线测,抖音死亡计算器测试
  2. 编程视频资源教程汇总
  3. 我的个人知识管理工具软件
  4. 2021年T电梯修理考试题及T电梯修理模拟考试题
  5. R/BioC序列处理之四:BSgenome简介
  6. 微信小程序保存图片到相册
  7. 十年游戏建模师给想学次世代游戏建模同学的一些忠告,太受益了
  8. asp.net pdf如何转换成tif_如何将pdf转换成word?它可以解决大多数文档转换问题
  9. 测试用例设计——场景法
  10. 浑水摸「YY」、「侠盗」苹果和辛巴的「麦乳精」|极客一周