一、概述

上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断。思路很简单,就是写一个自定义的数据处理通道,输入为文件名,输出为float数字,里面保存的是像素信息。

样本包括6万张训练图片和1万张测试图片,图片为灰度图片,分辨率为20*20 。train_tags.tsv文件对每个图片的数值进行了标记,如下:

二、源码

全部代码:

namespace MulticlassClassification_Mnist
{class Program{//Assets files download from:https://gitee.com/seabluescn/ML_Assetsstatic readonly string AssetsFolder = @"D:\StepByStep\Blogs\ML_Assets\MNIST";static readonly string TrainTagsPath = Path.Combine(AssetsFolder, "train_tags.tsv");static readonly string TrainDataFolder = Path.Combine(AssetsFolder, "train");static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "SDCA-Model.zip");static void Main(string[] args){MLContext mlContext = new MLContext(seed: 1);TrainAndSaveModel(mlContext);TestSomePredictions(mlContext);Console.WriteLine("Hit any key to finish the app");Console.ReadKey();}public static void TrainAndSaveModel(MLContext mlContext){// STEP 1: 准备数据var fulldata = mlContext.Data.LoadFromTextFile<InputData>(path: TrainTagsPath, separatorChar: '\t', hasHeader: false);var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.1);var trainData = trainTestData.TrainSet;var testData = trainTestData.TestSet;// STEP 2: 配置数据处理管道        var dataProcessPipeline = mlContext.Transforms.CustomMapping(new LoadImageConversion().GetMapping(), contractName: "LoadImageConversionAction").Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue)).Append(mlContext.Transforms.NormalizeMeanVariance( outputColumnName: "FeaturesNormalizedByMeanVar", inputColumnName: "ImagePixels"));// STEP 3: 配置训练算法 (using a maximum entropy classification model trained with the L-BFGS method)var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "Label", featureColumnName: "FeaturesNormalizedByMeanVar");var trainingPipeline = dataProcessPipeline.Append(trainer).Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictNumber", "Label"));// STEP 4: 训练模型使其与数据集拟合           ITransformer trainedModel = trainingPipeline.Fit(trainData);          // STEP 5:评估模型的准确性           var predictions = trainedModel.Transform(testData);var metrics = mlContext.MulticlassClassification.Evaluate(data: predictions, labelColumnName: "Label", scoreColumnName: "Score");PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);// STEP 6:保存模型
            mlContext.Model.Save(trainedModel, trainData.Schema, ModelPath);           }private static void TestSomePredictions(MLContext mlContext){// Load Model           ITransformer trainedModel = mlContext.Model.Load(ModelPath, out var modelInputSchema);// Create prediction engine var predEngine = mlContext.Model.CreatePredictionEngine<InputData, OutPutData>(trainedModel);DirectoryInfo TestFolder = new DirectoryInfo(Path.Combine(AssetsFolder, "test"));           foreach(var image in TestFolder.GetFiles()){count++;InputData img = new InputData(){FileName = image.Name};var result = predEngine.Predict(img);Console.WriteLine($"Current Source={img.FileName},PredictResult={result.GetPredictResult()}");                }}       }class InputData{[LoadColumn(0)]public string FileName;[LoadColumn(1)]public string Number;[LoadColumn(1)]public float Serial;       }class OutPutData : InputData{public float[] Score;public int GetPredictResult(){float max = 0;int index = 0;for (int i = 0; i < Score.Length; i++){if (Score[i] > max){max = Score[i];index = i;}}return index;}       }
}

View Code

三、分析

整个处理流程和上一篇文章基本一致,这里解释两个不一样的地方。

1、自定义的图片读取处理通道

namespace MulticlassClassification_Mnist
{public class LoadImageConversionInput{public string  FileName { get; set; }}public class LoadImageConversionOutput{[VectorType(400)]public float[] ImagePixels { get; set; }public string ImagePath;}[CustomMappingFactoryAttribute("LoadImageConversionAction")]public class LoadImageConversion : CustomMappingFactory<LoadImageConversionInput, LoadImageConversionOutput>{       static readonly string TrainDataFolder = @"D:\StepByStep\Blogs\ML_Assets\MNIST\train";public void CustomAction(LoadImageConversionInput input, LoadImageConversionOutput output){  string ImagePath = Path.Combine(TrainDataFolder, input.FileName);output.ImagePath = ImagePath;Bitmap bmp = Image.FromFile(ImagePath) as Bitmap;           output.ImagePixels = new float[400];for (int x = 0; x < 20; x++)for (int y = 0; y < 20; y++){var pixel = bmp.GetPixel(x, y);var gray = (pixel.R + pixel.G + pixel.B) / 3 / 16;output.ImagePixels[x + y * 20] = gray;}           bmp.Dispose();                     }public override Action<LoadImageConversionInput, LoadImageConversionOutput> GetMapping()=> CustomAction;}
}

这里可以看出,我们自定义的数据处理通道,输入为文件名称,输出是一个float数组,这里数组必须要指定宽度,由于图片分辨率为20*20,所以数组宽度指定为400,输出ImagePath为文件详细地址,用来调试使用,没有实际用途。处理思路非常简单,遍历每个Pixel,计算其灰度值,为了减少工作量我们把灰度值进行缩小,除以了16 ,由于后面数据会做归一化,所以这里影响不是太明显。

2、模型测试

            DirectoryInfo TestFolder = new DirectoryInfo(Path.Combine(AssetsFolder, "test"));int count = 0;int success = 0;foreach(var image in TestFolder.GetFiles()){count++;InputData img = new InputData(){FileName = image.Name};var result = predEngine.Predict(img);if(int.Parse(image.Name.Substring(0,1))==result.GetPredictResult()){success++;}                }

我们把测试目录里的全面图片读出遍历了一遍,将其测试结果和实际结果做了一次验证,实际上是把评估(Evaluate)的事情又重复做了一次,两次测试的成功率基本接近。

四、关于图片特征提取

我们是采用图片所有像素的灰度值来作为特征值的,但必须要强调的是:像素值矩阵不是图片的典型特征。虽然有时候对于较规则的图片,通过像素提取方式进行计算,也可以取得很好的效果,但在处理稍微复杂一点的图片的时候,就不管用了,原因很明显,我们人类在分析图片内容时看到的特征更多是线条等信息,绝对不是像素值,看下图:

我们人类很容易就判断出这两个图片表达的是同一件事情,但其像素值特征却相差甚远。

传统的图片特征提取方式很多,比如:SIFT、HOG、LBP、Haar等。 现在采用TensorFlow的模型进行特征提取效果非常好。下一篇文章介绍图片分类时再进行详细介绍。

五、资源获取

源码下载地址:https://github.com/seabluescn/Study_ML.NET

工程名称:MulticlassClassification_Mnist_Useful

MNIST资源获取:https://gitee.com/seabluescn/ML_Assets

点击查看机器学习框架ML.NET学习笔记系列文章目录

转载于:https://www.cnblogs.com/seabluescn/p/10942116.html

机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)相关推荐

  1. 深度学习笔记:07神经网络之手写数字识别的经典实现

    神经网络之手写数字识别的经典实现 上一节完成了简单神经网络代码的实现,下面我们将进行最终的实现:输入一张手写图片后,网络输出该图片对应的数字.由于网络需要用0-9一共十个数字中挑选出一个,所以我们的网 ...

  2. 机器学习框架ML.NET学习笔记【1】基本概念与系列文章目录

    一.序言 微软的机器学习框架于2018年5月出了0.1版本,2019年5月发布1.0版本.期间各版本之间差异(包括命名空间.方法等)还是比较大的,随着1.0版发布,应该是趋于稳定了.之前在园子里也看到 ...

  3. 深度学习面试题12:LeNet(手写数字识别)

    目录 神经网络的卷积.池化.拉伸 LeNet网络结构 LeNet在MNIST数据集上应用 参考资料 LeNet是卷积神经网络的祖师爷LeCun在1998年提出,用于解决手写数字识别的视觉任务.自那时起 ...

  4. Python学习记录 搭建BP神经网络实现手写数字识别

    搭建BP神经网络实现手写数字识别 通过之前的文章我们知道了,构建一个简单的神经网络需要以下步骤 准备数据 初始化假设 输入神经网络进行计算 输出运行结果 这次,我们来通过sklearn的手写数字数据集 ...

  5. 深度学习框架Caffe学习笔记(6)-测试自己的手写数字图片

    在之前的实验中我们使用过 $ ./build/tools/caffe.bin test \ -model examples/mnist/lenet_train_test.prototxt \ -wei ...

  6. 深度学习--TensorFlow(项目)Keras手写数字识别

    目录 效果展示 基础理论 1.softmax激活函数 2.神经网络 3.隐藏层及神经元最佳数量 一.数据准备 1.载入数据集 2.数据处理 2-1.归一化 2-2.独热编码 二.神经网络拟合 1.搭建 ...

  7. 北京大学曹健——Tensorflow笔记 05 MNIST数据集输出手写数字识别准确率

              # 前向传播:描述了网络结构 minist_forward.py # 反向传播:描述了模型参数的优化方法 mnist_backward.py # 测试输出准确率minist_tes ...

  8. 深度学习第一周 tensorflow实现mnist手写数字识别

  9. 2.7mnist手写数字识别之训练调试与优化精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列)

    2.7mnist手写数字识别之训练调试与优化精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列) 目录 2.7mnist手写数字识别之训练调试与优化精讲(百度架构师手把手带你零基础实践深度学习原 ...

最新文章

  1. Python 开发者节省时间的 10 个小技巧
  2. 802.11 区分广播 多播 单播帧
  3. 首届“开悟AI+游戏高校大赛”启动
  4. import cv2 失败 ImportError:DLL load fail:找不到指定模块
  5. linux 火狐浏览器插件,Linux系统Firefox(火狐浏览器)插件挂机
  6. 2016二级c语言成绩查询,2016年12月计算机二级C语言测试及答案
  7. centos下yum安装wget失败
  8. 《摩尔神话》:硅基经济的扫地僧戈登·摩尔
  9. Android应用方法数查看,查看size是否超过65k
  10. android 修复工具,牛学长安卓手机修复工具(安卓手机修复助手)V2.4.0.11 免费版
  11. android图片添加文字,android图片上添加文字
  12. 南信大学生怎样看知网,看外文文献
  13. 爬取noi官网所有题目分析
  14. linux普通用户密码到期修改为原密码方法
  15. Primeng CascadeSelect UI显示BUG解决方案
  16. 11 编程指南_流数据
  17. 当前时间显示器(代码屏显)
  18. Linux中xxd的简单应用
  19. XXXXXXXX学校“新教师、新风采”展示课活动方案
  20. Pre-Trained_Models_Past_Present_and_Future

热门文章

  1. cba篮球暂停次数和时间_为什么足球赛的观赏性比篮球更强?这三点是主要原因...
  2. 用spss做多组两两相关性分析_卡方检验的事后两两比较
  3. html 父元素右下角,html – 如何在父元素和父元素的兄弟元素上显示子元素?
  4. 飞狐的日线 java_JAVA 版 ATX-Client
  5. Python出入库简洁系统
  6. button上传替换file上传按钮,并显示图片缩略图,纯jsp操作
  7. Linux:常用shell快捷键
  8. 为什么QQ浏览器不是默认浏览器但是在打开网页的时候还是默认启动?
  9. node 安装express提示不是内部或外部命令
  10. Python:使用threading模块实现多线程编程三[threading.Thread类的重要函数]