前言

神经网络是一种很特别的解决问题的方法。本书将用最简单易懂的方式与读者一起从最简单开始,一步一步深入了解神经网络的基础算法。本书将尽量避开让人望而生畏的名词和数学概念,通过构造可以运行的Java程序来实践相关算法。

关注微信号“逻辑编程"来获取本书的更多信息。

这一章节我们将会解决一个真正的问题:手写字体识别。我们将识别像下面图中这样的手写数字。

在开始之前,我们先要准备好相应的测试数据。我们不能像前边那样简单的产生手写字体,毕竟我们自己还不知道如何写出一个产生手写字体的算法。训练要达到一定的精度需要较多的训练数据。还好,前人栽树后人乘凉,先驱们已经收集了宝贵的训练材料。MNIST就是一个广泛使用的数据集。不但可以拿来用,我们还可以从网站上看到别人的识别准确率。这样我们就有了很好的参照。MNIST包含一套训练数据和一套测试数据,分别来自不同的人群的手写。

MNIST网站: http://yann.lecun.com/exdb/mnist/

这个数据集是写在特定的二进制文件中的,并非普通图片格式。每个图片数据由28*28个像素组成。每个像素1个字节表示颜色灰度级。MNIST网站上有具体的介绍。

我们写一个类来完成数据集的读取工作,并提供接口返回指定的训练或者测试数据。具体代码不做分析,仅将代码附在下面,供读者使用。代码执行前要先下载数据文件并保留GZIP格式。代码执行后将随机抽取20个生成PNG图片供读者自己查看和验证数据内容。

下面我们写个测试类来识别手写字体。我们使用MNIST库的60000训练数据来反复训练我们的神经网络。每轮训练后使用MNIST库的10000个测试数据来测试识别率。

下面是代码:

package com.luoxq.ann;

import java.util.Arrays;

import java.util.Random;

public class MnistTest {

public static void main(String... args) {

int[] shape = {28 * 28, 10};

NeuralNetwork nn = new NeuralNetwork(shape);

Mnist mnist = new Mnist();

mnist.load();

mnist.shuffle();

System.out.println("Shape: " + Arrays.toString(shape));

System.out.println("Initial correct rate: " + test(nn, mnist));

int epochs = 1000;

double rate = 0.5;

System.out.println("Learning rate: " + rate);

System.out.println("Epoch,Time,Correctness\n----------------------");

long time = System.currentTimeMillis();

Mnist.Data[] data = mnist.getTrainingSlice(0, 60000);

for (int epoch = 1; epoch <= epochs; epoch++) {

for (int sample = 0; sample < data.length; sample++) {

nn.train(data[sample].input, data[sample].output, rate);

}

long seconds = (System.currentTimeMillis() - time) / 1000;

System.out.println(epoch + ", " + seconds + ", " +

test(nn, mnist));

}

}

private static int test(NeuralNetwork nn, Mnist mnist) {

int correct = 0;

Mnist.Data[] data = mnist.getTestSlice(0, 10000);

for (int sample = 0; sample < data.length; sample++) {

if (max(nn.f(data[sample].input)) == data[sample].label) {

correct++;

}

}

return correct;

}

private static int max(double[] d) {

double max = d[0];

int idx = 0;

for (int i = 1; i < d.length; i++) {

if (max < d[i]) {

max = d[i];

idx = i;

}

}

return idx;

}

}

我们先用一个10个神经元的单层神经网络试试看。结果出乎意外的好。我们很快就获得了超过90%的正确率。单层网络几乎就是对每个数字的像素分布做简单统计。能获得如此高的识别率,还是很神奇的。 在达到90%之后再训练已经效果不大,达到饱和了。我们必须换一种方法来做了。

Shape: [784, 10]

Initial correct rate: 1373

Learning rate: 0.5

Epoch,Time,Correctness

----------------------

1, 4, 6429

2, 8, 7663

3, 13, 8963

4, 17, 9029

5, 22, 9016

6, 27, 9062

7, 31, 9063

8, 36, 9066

9, 41, 9072

10, 45, 9057

11, 50, 9084

12, 55, 9072

13, 61, 9062

14, 66, 9050

15, 70, 9077

16, 75, 9052

17, 79, 9068

18, 84, 9055

19, 88, 9060

20, 93, 9064

那么我们来使用三层神经网络试一试。在试了几个不同的中间层大小和学习率参数之后,我找到了下面这个较好的参数组合:

Shape: [784, 50, 10]

Initial correct rate: 944

Learning rate: 1.0

Epoch,Time,Correctness

----------------------

1, 24, 7459

2, 59, 9232

3, 99, 9313

4, 131, 9379

5, 153, 9412

6, 176, 9443

7, 200, 9412

8, 226, 9447

9, 248, 9462

10, 269, 9461

11, 290, 9465

12, 314, 9493

13, 343, 9477

14, 368, 9499

15, 392, 9502

16, 420, 9509

17, 447, 9482

18, 472, 9508

19, 496, 9491

20, 518, 9536

21, 545, 9523

22, 569, 9549

23, 593, 9527

24, 618, 9527

25, 643, 9520

26, 667, 9513

27, 689, 9507

28, 712, 9527

29, 734, 9501

30, 758, 9521

31, 781, 9508

32, 804, 9534

33, 827, 9534

34, 850, 9550

35, 875, 9569

我们很快达到了95%以上的正确率。可见多层网络相对单层神经网络还是有优势的。虽然这个正确率还达不到产品水平,但是作为初次尝试结果还是很不错的。

下面是MNIST文件读取源代码:

package com.luoxq.ann;

import javax.imageio.ImageIO;

import java.awt.image.BufferedImage;

import java.io.DataInputStream;

import java.io.File;

import java.io.FileInputStream;

import java.util.Random;

import java.util.zip.GZIPInputStream;

/**

* Created by luoxq on 17/4/15.

*/

public class Mnist {

static class Data {

public byte[] data;

public int label;

public double[] input;

public double[] output;

}

public static void main(String... args) throws Exception {

Mnist mnist = new Mnist();

mnist.load();

System.out.println("Data loaded.");

Random rand = new Random(System.nanoTime());

for (int i = 0; i < 20; i++) {

int idx = rand.nextInt(60000);

Data d = mnist.getTrainingData(idx);

BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);

for (int x = 0; x < 28; x++) {

for (int y = 0; y < 28; y++) {

img.setRGB(x, y, toRgb(d.data[y * 28 + x]));

}

}

File output = new File(i + "_" + d.label + ".png");

if (!output.exists()) {

output.createNewFile();

}

ImageIO.write(img, "png", output);

}

}

static int toRgb(byte bb) {

int b = (255 - (0xff & bb));

return (b << 16 | b << 8 | b) & 0xffffff;

}

Data[] trainingSet;

Data[] testSet;

public void shuffle() {

Random rand = new Random();

for (int i = 0; i < trainingSet.length; i++) {

int x = rand.nextInt(trainingSet.length);

Data d = trainingSet[i];

trainingSet[i] = trainingSet[x];

trainingSet[x] = trainingSet[i];

}

}

public Data getTrainingData(int idx) {

return trainingSet[idx];

}

public Data[] getTrainingSlice(int start, int count) {

Data[] ret = new Data[count];

System.arraycopy(trainingSet, start, ret, 0, count);

return ret;

}

public Data getTestData(int idx) {

return testSet[idx];

}

public Data[] getTestSlice(int start, int count) {

Data[] ret = new Data[count];

System.arraycopy(testSet, start, ret, 0, count);

return ret;

}

public void load() {

trainingSet = load("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz");

testSet = load("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz");

if (trainingSet.length != 60000 || testSet.length != 10000) {

throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length);

}

}

private Data[] load(String imgFile, String labelFile) {

byte[][] images = loadImages(imgFile);

byte[] labels = loadLabels(labelFile);

if (images.length != labels.length) {

throw new RuntimeException("Images and label doesn't match: " + imgFile + " " + labelFile);

}

int len = images.length;

Data[] data = new Data[len];

for (int i = 0; i < len; i++) {

data[i] = new Data();

data[i].data = images[i];

data[i].label = 0xff & labels[i];

data[i].input = dataToInput(images[i]);

data[i].output = labelToOutput(labels[i]);

}

return data;

}

private double[] labelToOutput(byte label) {

double[] o = new double[10];

o[label] = 1;

return o;

}

private double[] dataToInput(byte[] b) {

double[] d = new double[b.length];

for (int i = 0; i < b.length; i++) {

d[i] = (b[i] & 0xff) / 255.0;

}

return d;

}

private byte[][] loadImages(String imgFile) {

try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)));) {

int magic = in.readInt();

if (magic != 0x00000803) {

throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));

}

int count = in.readInt();

int rows = in.readInt();

int cols = in.readInt();

if (rows != 28 || cols != 28) {

throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols);

}

byte[][] data = new byte[count][rows * cols];

for (int i = 0; i < count; i++) {

in.readFully(data[i]);

}

return data;

} catch (Exception ex) {

throw new RuntimeException("Failed to read file: " + imgFile, ex);

}

}

private byte[] loadLabels(String labelFile) {

try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(labelFile)));) {

int magic = in.readInt();

if (magic != 0x00000801) {

throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));

}

int count = in.readInt();

byte[] data = new byte[count];

in.readFully(data);

return data;

} catch (Exception ex) {

throw new RuntimeException("Failed to read file: " + labelFile, ex);

}

}

}

欢迎关注订阅号逻辑编程内容。

java识别手写文字_神经网络入门 第6章 识别手写字体相关推荐

  1. python识别手写文字_如何快速使用Python神经网络识别手写字符?(文末福利)

    原标题:如何快速使用Python神经网络识别手写字符?(文末福利) 点击标题下[异步社区]可快速关注 在本文中,我们将进一步探讨一些使用Python神经网络识别手写字符非常有趣的想法.如果只是想了解神 ...

  2. python识别手写文字_使用 python 获取 CASIA 脱机和在线手写汉字库

    在申请书中介绍了数据集的基本情况: > CASIA-HWDB 和 CASIA-OLHWDB 数据库由中科院自动化研究所在 2007-2010 年间收集, 均各自包含 1,020 人书写的脱机(联 ...

  3. 动态背景 图层上写文字_文字效果很难吗?教你如何打造绚丽的浮雕文字!

    今天继续为大家分享优质教程 在平面设计中,我们可以给字体加上各种各样的效果,不同的效果展现出不同的文字的个性.而在众多效果中,浮雕效果最能表现出具有层次感的视觉冲击力.其独特的魅力,能让整个画面增色不 ...

  4. 华为p50pro会不会搭载鸿蒙系统,华为p50pro有没有手写笔_华为p50pro会不会用手写笔...

    华为即将上市的华为p50pro这款手机现在已经获得了非常多的热度,那么这款手机它是不是用手写笔的呢?接下来我们就一起来了解一下华为p50pro它是不是配备了手写笔吧. 1.华为p50pro有没有手写笔 ...

  5. 怎么识别图片中的文字?不妨试试这几个识别工具

    大家平时在完成老师或者领导布置的任务的时候,不免需要上网查找资料,有时候我们看到某段文字很契合自己的内容,就会将这段文字复制下来,转换为自己的语言,那要是文字无法复制该怎么办呢?其实大家也不用过于担心 ...

  6. python车牌识别使用训练集_基于Python 实现的车牌识别项目

    车牌识别在高速公路中有着广泛的应用,比如我们常见的电子收费(ETC)系统和交通违章车辆的检测,除此之外像小区或地下车库门禁也会用到,基本上凡是需要对车辆进行身份检测的地方都会用到. 简介 车牌识别系统 ...

  7. 用c语言写代码_如何避免用动态语言的思维写Go代码

    由于招聘市场上Go工程师的供给量不足,所以在招人的时候我们招了不少愿意转型用Go语言进行开发的PHP工程师,不过虽说换了个语言,在他们代码的时候还是能发现很多PHP的影子.if语句后面非要带括号这种问 ...

  8. JAVA项目代码手写吗_一个老程序员是如何手写Spring MVC的

    见人爱的Spring已然不仅仅只是一个框架了.如今,Spring已然成为了一个生态.但深入了解Spring的却寥寥无几.这里,我带大家一起来看看,我是如何手写Spring的.我将结合对Spring十多 ...

  9. python 卷积神经网络猫狗大战_卷积神经网络入门(1) 识别猫狗

    按照我的理解,CNN的核心其实就是卷积核的作用,只要明白了这个问题,其余的就都是数学坑了(当然,相比较而言之后的数学坑更难). 如果学过数字图像处理,对于卷积核的作用应该不陌生,比如你做一个最简单的方 ...

最新文章

  1. C++多线程的简单程序
  2. php文件下载脚本,PHP文件下载实例代码浅析
  3. C#中ArrayList的简单使用
  4. SQL增删改查,基础
  5. 理解并实施:GLBP(ccna200-120新增考点)
  6. Unity手游之路四3d旋转-四元数,欧拉角和变幻矩阵
  7. android horizontalscrollview 动画,Android HorizontalScrollView左右滑动效果
  8. android 单机斗地主,单机斗地主
  9. 如何确认访客所在的国家
  10. java.lang.IllegalArgumentException: Request header is too large 解决方案
  11. 爆炸性环境设备通用要求标准_防爆电气设备的适用环境及温度要求
  12. Android—Socket服务端与客户端用字符串的方式互相传递图片
  13. Android开发时的多点触控是如何实现的?
  14. Vue2.0进阶组件篇2 解析饿了么(spinner组件)
  15. 幸运抽奖java_java10幸运抽奖
  16. FatMouse believes that the fatter a mouse is, the faster it runs.
  17. Linux删除其中一行的快捷键,Linux 命令快捷键
  18. 巴菲特指标:估值过高
  19. 大数据培训:Hadoop生态系统圈
  20. 拓嘉启远:拼多多购物运输中的商品能拒收吗

热门文章

  1. 爬虫项目:获取movie
  2. 如何彻底关闭系统还原功能和删除系统还原点
  3. c++中的ignore和tie
  4. 如何有效的使用搜索词
  5. ZOJ 3380 Patchouli's Spell Cards(DP,大数)
  6. 【飞鱼科技】2022届春季校园招聘火热进行中
  7. matlab ccd采集,CCD数据采集.doc
  8. 【自然语言处理】gensim的word2vec
  9. GBase 8s HAC高可用方案
  10. KITTI数据集学习笔记