DL4J使用之手写数字识别

最近一直在学习深度学习,由于我是Java程序员出身,就选择了一个面向Java的深度学习库—DL4J。为了更加熟练的掌握这个库的使用,我使用该库,以MNIST(http://yann.lecun.com/exdb/mnist/)手写数字数据集作为基础,训练了一个模型,来识别手写字体。下面我们从以下几个方面讲解该项目的实现:

  • DL4J使用之手写数字识别

    • DL4J简介
    • 模型的训练
      • 训练数据集(MNIST)
      • 模型架构
    • 模型性能
    • 模型的保存与加载
    • 结果展示
    • 总结 与展望

DL4J简介

Deeplearning4j是国外创业公司Skymind的产品。目前最新的版本更新到了0.7.2。源码全部公开并托管在github上(https://github.com/deeplearning4j/deeplearning4j)。从这个库的名字上可以看出,它就是转为Java程序员写的Deep Learning库。其实这个库吸引人的地方不仅仅在于它支持Java,更为重要的是它可以支持Spark。由于Deep Learning模型的训练需要大量的内存,而且原始数据的存储有时候也需要很大的外存空间,所以如果可以利用集群来处理便是最好不过了。当然,除了Deeplearning4j以外,还有一些Deep Learning的库可以支持Spark,比如yahoo/CaffeOnSpark,AMPLab/SparkNet以及Intel最近开源的BigDL。这些库我自己都没怎么用过,所以就不多说了,这里重点说说Deeplearning4j的使用。
从项目管理角度,DL4J官方给的例子中,推荐使用Maven构建项目,但是目前在学习阶段,我是直接从官网扣下来了需要的Jar包导入项目,这样有一个好处,在项目迁移到别的计算机上运行的时候不需要等待Maven下载jar包的时间。当然,工作中还是推荐使用Maven。不说了,下面是我提出来的Jar包:

看着还是挺庞大的,其实也难怪,毕竟深度学习需要大量的工作才能形成一个库。这些我已经上传到CSDN可以点击下方链接下载(https://download.csdn.net/download/yushengpeng/10286975)

模型的训练

训练数据集(MNIST)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。它有60000个训练样本集和10000个测试样本集。MNIST算是深度学习入门的一个数据集吧,也是一个比较优秀的手写数字数据集,可以用于半监督学习,并且取得了非常不错的成绩。下面是该数据集的部分截图:

关于如何将该数据集转换成DL4J能识别的格式,请学习DL4J的官方文档。我也上传了Dl4J的官方文档到了CSND,如果你有需求请前去下载(https://download.csdn.net/download/yushengpeng/10287018)。

模型架构

当我们正确读取数据后,我们需要定义具体的神经网络结构,这里我用的是Lenet,该网络是一个5层的神经网络(在深度学习中,我们约定俗成的认为输入层是第0层不参与层数统计),该网络各层情况如下:

第0层: nput layer: 输入数据为原始训练图像
第1层: Conv1:6个5*5的卷积核,步长Stride为1
第2层:Pooling1:卷积核size为2*2,步长Stride为2
第3层:Conv2:12个5*5的卷积核,步长Stride为1
第4层:Pooling2:卷积核size为2*2,步长Stride为2
第5层:Output layer:输出为10维向量

网络层级结构示意图如下:

Deeplearning4j的实现参考了官网(https://github.com/deeplearning4j/dl4j-examples)的例子。具体代码如下:

public class CNN_MNIST {private static Logger log = LoggerFactory.getLogger(CNN_MNIST.class);public static void main(String[] args) throws IOException {int nChannels = 1;int outputNum = 10; // The number of possible outcomesint batchSize = 64; // Test batch sizeint nEpochs = 2; // Number of training epochsint iterations = 1; // Number of training iterationsint seed = 123; //DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations).regularization(true).l2(0.0005).learningRate(.01).weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS).momentum(0.9).list().layer(0, new ConvolutionLayer.Builder(5, 5)// nIn and nOut specify depth. nIn here is the nChannels and// nOut is the number of filters to be applied.nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()).layer(1,new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(2, new ConvolutionLayer.Builder(5, 5)// Note that nIn need not be specified in later layers.stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()).layer(3,new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()).layer(5,new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)) // See note below.backprop(true).pretrain(false).build();MultiLayerNetwork model = new MultiLayerNetwork(conf);model.init();model.setListeners(new ScoreIterationListener(1));for (int i = 0; i < nEpochs; i++) {model.fit(mnistTrain);log.info("*** Completed epoch {} ***", i);log.info("Evaluate model....");Evaluation eval = new Evaluation(outputNum);while (mnistTest.hasNext()) {DataSet ds = mnistTest.next();INDArray output = model.output(ds.getFeatureMatrix(), false);eval.eval(ds.getLabels(), output);}log.info(eval.stats());mnistTest.reset();log.info("****************Example finished********************");log.info("******SAVE TRAINED MODEL******");// Details// Where to save modelFile locationToSave = new File("trained_mnist_model.zip");// boolean save Updaterboolean saveUpdater = false;// ModelSerializer needs modelname, saveUpdater, LocationModelSerializer.writeModel(model, locationToSave, saveUpdater);}}
}

可以发现,神经网络需要定义很多的超参数,学习率、正则化系数、卷积核的大小、激励函数等都是需要人为设定的。不同的超参数,对结果的影响很大,其实后来发现,很多时间都花在数据处理和调参方面。毕竟自己设计网络的能力有限,一般都是参考大牛的论文,然后自己照葫芦画瓢地实现。这里实现的Lenet的结构是:卷积–>下采样–>卷积–>下采样–>全连接。和原论文的结构基本一致。卷积核的大小也是参考的原论文。具体细节可参考之前发的论文链接。这里我们设置了一个Score的监听事件,主要是可以在训练的时候获取每一次权重更新后损失函数的收敛情况,如下面所示:

模型性能

of classes: 10
Accuracy 0.9918
Precision 0.9917
Recall 0.9917
F1 Score 0.9917

模型性能还是不错的,在10000个手写数字测试集上的准确率能达到99.17%。当然,模型的好坏跟神经网络的架构,超参的设置都有关系,关于到底选用什么样的模型架构需要更多的经验,知识。一般具体问题具体分析。

模型的保存与加载

当我们训练好了一个模型的时候,我们需要将训练好的模型持久化到本地磁盘,或者其他存储介质。因为训练模型是一个非常耗时的工作,模型的大小,数据集的大小,训练一个模型需要一天,一周,一个月,甚至是更长的时间。我们不可能每次在实际的项目中,需要的时候再去训练出一个模型。DL4J也为我们实现了模型的持久化功能,具体代码如下:

File locationToSave = new File("trained_mnist_model.zip");//保存路径,存储位置
boolean saveUpdater = false;
ModelSerializer.writeModel(model, locationToSave, saveUpdater);

当然持久化模型是为了再次加载模型,使用模型。DL4J也为我们实现了模型的的加载功能,具体代码如下:

NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
INDArray image = loader.asMatrix(new File("XXX://test.jpg"));//从本地磁盘中加载文件
DataNormalization scaler = new ImagePreProcessingScaler(0,1);
scaler.transform(image);
INDArray output = model.output(image);//对图片进行分类预测

结果展示









总结 与展望

通过这个小项目,我参照官网手册,初步实现了LENET网络。并取得了不错的成果。当然,我也是学习了充足的理论之后,再来学习DL4J这个深度学习框架的。关于这个项目的源码,你可以去我的GutHub上下载:https://github.com/ShengPengYu/writtingRecoginition。
。该项目还有不足之处,比方说可以边测试边学习,我们在发现我们书写的测试数据分类不准确的时候,可以加入到训练数据库,在线对模型实时训练,因为每个用户的书写风格不一样,可能对分类结果有一定的影响。边测试边训练,可以训练出符合用户个人情况的模型。还有一种情况是。当然我也有一定的思考,比方说,如果我对目前模型进一步改进,做一个汉字识别项目,那么最后一层该使用什么架构,中国汉字那么多,如果使用one-hot模式,会不会维度太大,在时间复杂度和空间复杂度上是一个非常严峻的问题。

使用Dl4j训练的一个手写数字识别软件相关推荐

  1. 基于TensorFlow深度学习框架,运用python搭建LeNet-5卷积神经网络模型和mnist手写数字识别数据集,设计一个手写数字识别软件。

    本软件是基于TensorFlow深度学习框架,运用LeNet-5卷积神经网络模型和mnist手写数字识别数据集所设计的手写数字识别软件. 具体实现如下: 1.读入数据:运用TensorFlow深度学习 ...

  2. 在MNIST数据集上训练一个手写数字识别模型

    使用Pytorch在MNIST数据集上训练一个手写数字识别模型, 代码和参数文件 可下载 1.1 数据下载 import torchvision as tvtraining_sets = tv.dat ...

  3. 深度学习笔记:01快速构建一个手写数字识别系统以及张量的概念

    深度学习笔记:01快速构建一个手写数字识别系统 神经网络代码最好运行在GPU中,但是对于初学者来说运行在GPU上成本太高了,所以先运行在CPU中,就是慢一些. 一.安装keras框架 使用管理员模式打 ...

  4. caffe学习(二):利用mnist数据集训练并进行手写数字识别(windows)

    准备数据集 http://yann.lecun.com/exdb/mnist/提供了训练集与测试集数据的下载. 但是caffe并不是直接处理这些数据,而是要转换成lmdb或leveldb格式进行读取. ...

  5. pyqt5手写板+pytorch卷积神经网络,实现手写数字识别软件

    卷积神经网络的结构 #定义网络结构 #不是le-net5的结构 class Net(nn.Module):def __init__(self):super(Net, self).__init__()# ...

  6. [附代码] 如何用HOG+SVM实现手写数字识别

    本文首发于微信公众号[DeepDriving],公众号后台回复关键字[手写数字识别]可获取本文代码链接. 前言 手写数字识别是机器学习和深度学习中一个非常著名的入门级图像识别项目,很多人都是从这个项目 ...

  7. pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

    文章目录 1. MNIST 手写数字识别 2. 聚焦数据集扩充后的模型训练 3. pytorch 手写数字识别基本实现 3.1完整代码及 MNIST 测试集测试结果 3.1.1代码 3.1.2 MNI ...

  8. 机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)

    一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...

  9. pyTorch入门(六)——实战Android Minist OpenCV手写数字识别(附源码地址)

    学更好的别人, 做更好的自己. --<微卡智享> 本文长度为4239字,预计阅读12分钟 前言 前面几篇文章实现了pyTorch训练模型,然后在Windows平台用C++ OpenCV D ...

最新文章

  1. jq发送动态变量_山东体育学院康复生物力学团队发文,探索手机行为双任务对动态稳定性控制的影响...
  2. elasticsearch查询及logstash简介
  3. Java对MySQL数据库进行连接、查询和修改【转载】
  4. 使用idea 打jar包
  5. webkit内核浏览器的CSS写法
  6. nextpolish安装_希望组自主三代组装软件NextDenovo最新版本全球学术开源!
  7. deeplearning 源码收集
  8. http://www.guokr.com/blog/475765/
  9. 机器人学(机构学)笔记
  10. MySQL之字符串拼接
  11. bcnf分解算法_BCNF范式及其分解方法(对一次Lab作业的总结)
  12. mysql replace rpad_mysql中的replace,concat,cast等等字符串函数
  13. 【渝粤题库】广东开放大学 风险投资 形成性考核
  14. html乡愁代码,Homesick 乡愁
  15. OPPO手机怎么找到快应用入口
  16. swagger knife4j 解决接口下载文件响应乱码问题
  17. Extjs 百度地图扩展
  18. Java Date Time 教程-java.sql.Date
  19. win10 开始 开始全屏快捷键
  20. CDH平台DATANODE数据块阀值参数设置

热门文章

  1. 为什么说OKRS-E是适合的OKR框架
  2. web前端入门到实战:CSS颜色、背景和剪切
  3. C++中二维数组的动态创建于处理(zzl)
  4. kerberos中的spn详解
  5. 观点丨如何让劳动价值像资本一样自由流动与交易
  6. python中的英文歌_Python 爬网易音乐云歌曲,MV
  7. VIVO手机测试指令代码大全
  8. matlab 的 符号,MATLAB符号计算
  9. 二维码活码的诞生、技术原理及使用场景
  10. 北京大学肖臻老师《区块链技术与应用》ETH笔记 - 3.0 ETH数据结构篇