人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练

MXNet 是一个轻量级、可移植、灵活的分布式深度学习框架,2017 年 1 月 23 日,该项目进入 Apache 基金会,成为 Apache 的孵化器项目。尽管现在已经有很多深度学习框架,包括 TensorFlow, Keras, Torch,以及 Caffe,但 Apache MXNet 因其对多 GPU 的分布式支持而越来越受欢迎。

环境准备
1.安装 Anaconda。Anaconda 是一个用于科学计算的 Python 发行版,提供了包管理与环境管理的功能。Anaconda 利用 conda 来进行 package 和 environment 的管理,并且已经包含了 Python 和相关的配套工具。

Anaconda3-4.4下载地址: https://repo.continuum.io/archive/Anaconda3-4.4.0-Windows-x86_64.exe

2.在 conda 下安装 pip,安装命令为‘conda install pip’

3.安装 OpenCV-python 库。OpenCV-python 是一个很强大的计算机视觉库,在这个项目中可以用于处理图像。使用‘pip install openvc-python’在 Anaconda 环境下安装 OpenCV。也可以从源文件进行编译(注意:conda 安装 opencv3.0 不能运行)。

4.安装 scikit learn,一个开源的 python 机器学习科学计算库,它将用于对数据进行预处理。安装命令为‘conda install scikit-learn’。

5.安装 Jupyter Notebook,安装命令为‘conda install jupyter notebook’。

6.安装 MXNet。安装命令为‘pip install mxnet’。

------------------
数据库

使用的数据库是德国交通标志识别基准,来自论文《德国交通标志识别基准:多类别分类竞赛》( J. Stallkamp, M. Schlipsing, J. Salmen, and C. Igel. "The German Traffic Sign Recognition Benchmark: A multi-class classification competition." ),发表在 IEEE International Joint Conference on Neural Networks,2011。该数据集包含 39209 张训练样例和 12630 张测试样例,有 43 种不同的交通标志——停车标志,限速标志,各种警示标志以及其他标志。
数据库中的每张图像大小为 32×32,均为三通道彩色图。每幅图属于一种交通标志。图像种类标签由 0 到 42 的整数表示。

从一个 NumPy 阵列中下载数据,数据分为训练,验证和测试集。训练集包含 39209 张大小为 32×32,通道数为 3 的图像,所以 NumPy 阵列的维度为 39209×32×32×3。该项目中作者仅使用了训练集和验证集。作者将使用网上的真实图像来测试所构建的模型。X_train 存储图像,维度为 39209×32×32×3。Y_train 存储图像对应的类标,维度为 39209,包含 0-42 的整数,对应每张图的类标。

训练过程

1. 准备数据集
X_train 和 Y_train 组成了训练数据集。可以使用 scikit-learn 对训练数据集进行分割得到验证集,这样可以避免使用出现过的图片测试模型。代码如下:

2. 训练数据预处理
批训练
神经网络训练需要花费大量时间和内存。所以作者将数据分批训练,一批大小为 64. 不仅是为了让数据适应内存,而且它可以让 MXNet 尽量利用 GPU 的计算效率。
归一化
除此之外,图像的像素值也进行了归一化,可以使学习算法更快收敛。下面是对训练数据进行预处理的代码:

3. 构建深度网络
目前,对于图像识别这类处在探索研究热点的问题,学界已经设计了很多效果良好的网络结构。所以最好的方法是实现一个已经发表出来的网络结构,然后对其进行改进。基于 AlexNet 结构,构建了一个简化版的卷积神经网络。AlexNet 是 2012 年发表的一个经典网络,在当年取得了 ImageNet 的最好成绩。

网络共有 8 层,其中前 5 层是卷积层,后边 3 层是全连接层,在每一个卷积层中包含了激励函数 RELU 以及局部响应归一化(LRN)处理,然后再经过池化(max pooling),最后的一个全连接层的输出是具有 1000 个输出的 softmax 层,最后的优化目标是最大化平均的多元逻辑回归。
在此之后也有很多更优秀的网络结构被提出,例如 VGGNet 和 ResNet,大家可以选择更好的网络结构去实现。
由于 MXNet 的符号计算构架,该神经网络的代码十分简洁明了

4. 训练网络
训练 epoch 为 10,训练好的模型存在 JSON 文件中,并且可以通过测量训练和验证准确率来观测网络“学习”的情况。

5. 载入预训练模型
下面给出了加载第 10 个 epoch 模型(最终模型)的代码。由于将在单张图片上进行测试,所以批尺寸由 64 减到 1,数据维度也变成了 1×3×32×32。

测试过程
测试图像(32×32×3)样例:

从结果可以看出可能性最高的种类为停车标志,说明预测准确。如果需要对模型有一个更完整的衡量,还需要用测试数据库进行测试,得到最终的分类准确率。

总结
本文我们介绍了使用 MXNet 进行多目标分类任务的方法。使用 MXNet,在 AlexNet 的结构基础上构建了一个更为简单的卷积神经网络结构。网络由卷积层,激活函数层,池化层和全连接层组成,采用德国交通标志图像训练数据库对该网络进行训练,实验结果证明网络可以将交通标志进行正确的分类。介绍了如何使用 MXNet 对数据进行预处理,构建网络,以及如何加载预训练好的网络模型。可以看出,MXNet 因其在多 GPU 上进行并行训练的能力,以及网络模型构建简单灵活的特性,是一个十分优秀的深度学习框架。

==========================

本人微信公众帐号: 心禅道(xinchandao)

本人微信公众帐号:双色球预测合买(ssqyuce)

人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练相关推荐

  1. 深度学习之基于Inception_ResNet_V2和CNN实现交通标志识别

    这次的结果是没有想到的,利用官方的Inception_ResNet_V2模型识别效果差到爆,应该是博主自己的问题,但是不知道哪儿出错了. 本次实验分别基于自己搭建的Inception_ResNet_V ...

  2. 深度学习100例 | 第3天:交通标志识别 - PyTorch实现

    文章目录 一.导入数据 1. 获取类别名 2. 数据可视化 3. 加载数据文件 4. 划分数据 二.自建模型 三.模型训练 1. 优化器与损失函数 2. 模型的训练 四.结果分析 大家好,我是K同学啊 ...

  3. Ubuntu为julia安装深度学习框架MXNet(支持CUDA和OPenCV编译)

    Ubuntu为julia安装深度学习框架MXNet(支持CUDA和OPenCV编译) 环境介绍与注意事项 下载源文件 安装依赖 编译 环境配置 安装MXNet 测试 后记 环境介绍与注意事项 Ubun ...

  4. 深度学习 Day 15——利用卷神经网络实现好莱坞明星识别

    深度学习 Day 15--利用卷神经网络实现好莱坞明星识别 文章目录 深度学习 Day 15--利用卷神经网络实现好莱坞明星识别 一.前言 二.我的环境 三.前期工作 1.导入依赖项并设置GPU 2. ...

  5. matlab交通标志神经网络识别,基于神经网络的交通标志识别方法

    Municipal & Traffic Construction SCIENCE & TECHNOLOGY FOR DEVELOPMENT 149 基于神经网络的交通标志识别方法 赵丹 ...

  6. ​【交通标志识别】基于BP神经网络实现交通标志识别matlab代码

    1 简介 近年来,交通标志识别在车辆视觉导航系统中是一个热门研究课题.为了安全驾驶和高效运输,交通部门在公路道路上设置了各类重要的交通标志,以提醒司机和行人有关道路交通信息,如指示标志.警告标志.禁止 ...

  7. java深度学习框架Deeplearning4j实战(一)BP网络分类器

    1.Deeplearning4j 深度学习,人工智能今天已经成了IT界最流行的词,而tensorflow,phython又是研究深度学习神经网络的热门工具.tensorflow是google的出品,而 ...

  8. 深度学习框架Deeplearning4j实战:文本智能抽取快速定位

    一.Deeplearning4j Deeplearning4j(简称DL4J)是基于java的一个深度学习框架,已经发布了1.0版本的beta版. 与其他深度学习框架相比,DL4J具有以下优点: 与S ...

  9. 深度学习框架zf_谈谈深度学习框架的数据排布

    最近同事碰到了一个不同框架模型互相转换的问题:pytorch模型或caffe模型要转到tensorflow和TFLite上进行移动端的部署.模型从pytorch或caffe转tensorflow通过O ...

最新文章

  1. dfs-Rank the Languages
  2. 12306的变态验证码算得了什么?我有Python神器!
  3. Web前端开发笔记——第二章 HTML语言 第十一节 语义标签
  4. C++自定义自适应中值滤波
  5. 前端学习(1050):todolist正在进行个数和已完成个数
  6. ogg oracle 测试kafka_基于OGG的Oracle与Hadoop集群/kafka准实时同步
  7. 解密五种AI筛选的“新冠”新药:能靶向病毒细胞侵入的蛋白酶
  8. vue3.0项目服务器部署
  9. Tensorflow(r1.4)API--tf.summary.scalar
  10. is 32-bit instead of 64-bit 亲测可用
  11. Linux命令整理-Kali
  12. java jvm参数获取_在java代码中获取JVM参数
  13. Java毕业设计-企业员工考勤打卡管理系统
  14. cada0图纸尺寸_a0图纸尺寸
  15. 最详细的js获取当前url的方法
  16. Linux系统小说源码网站,Linux系统小说源码网站
  17. JAVAWEB开发Myeclipse 项目中报“无法解析类型 java.io.ObjectInputStream,从必需的 .class 文件间接引用了它”解决办法
  18. 关于RxJava2.0你不知道的事
  19. 职场中哪些职场很重要?
  20. 新支点国产服务器操作系统与虚拟化平台和云管理平台实现兼容

热门文章

  1. 【使用递归玩通关汉诺塔游戏】算法01-递归(斐波那契数列、汉罗塔问题)-java实现
  2. 学习笔记(十七)——redis(CRUD)
  3. logback无法生成日志文件之谜
  4. oracle中forall in,oracle10g的forall功能加强
  5. 微信小程序-canvas绘制文字实现自动换行
  6. 路面平整度采集和计算方法
  7. java数组二分查找_java 13-1 数组高级二分查找
  8. 【自动驾驶】欧拉角和旋转矩阵之间的转换
  9. 【数学与算法】奇异矩阵、奇异值、奇异值分解、奇异性
  10. caffe windows 学习第一步:编译和安装(vs2012+win 64)