人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练
人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练
MXNet 是一个轻量级、可移植、灵活的分布式深度学习框架,2017 年 1 月 23 日,该项目进入 Apache 基金会,成为 Apache 的孵化器项目。尽管现在已经有很多深度学习框架,包括 TensorFlow, Keras, Torch,以及 Caffe,但 Apache MXNet 因其对多 GPU 的分布式支持而越来越受欢迎。
环境准备
1.安装 Anaconda。Anaconda 是一个用于科学计算的 Python 发行版,提供了包管理与环境管理的功能。Anaconda 利用 conda 来进行 package 和 environment 的管理,并且已经包含了 Python 和相关的配套工具。
Anaconda3-4.4下载地址: https://repo.continuum.io/archive/Anaconda3-4.4.0-Windows-x86_64.exe
2.在 conda 下安装 pip,安装命令为‘conda install pip’
3.安装 OpenCV-python 库。OpenCV-python 是一个很强大的计算机视觉库,在这个项目中可以用于处理图像。使用‘pip install openvc-python’在 Anaconda 环境下安装 OpenCV。也可以从源文件进行编译(注意:conda 安装 opencv3.0 不能运行)。
4.安装 scikit learn,一个开源的 python 机器学习科学计算库,它将用于对数据进行预处理。安装命令为‘conda install scikit-learn’。
5.安装 Jupyter Notebook,安装命令为‘conda install jupyter notebook’。
6.安装 MXNet。安装命令为‘pip install mxnet’。
------------------
数据库
使用的数据库是德国交通标志识别基准,来自论文《德国交通标志识别基准:多类别分类竞赛》( J. Stallkamp, M. Schlipsing, J. Salmen, and C. Igel. "The German Traffic Sign Recognition Benchmark: A multi-class classification competition." ),发表在 IEEE International Joint Conference on Neural Networks,2011。该数据集包含 39209 张训练样例和 12630 张测试样例,有 43 种不同的交通标志——停车标志,限速标志,各种警示标志以及其他标志。
数据库中的每张图像大小为 32×32,均为三通道彩色图。每幅图属于一种交通标志。图像种类标签由 0 到 42 的整数表示。
从一个 NumPy 阵列中下载数据,数据分为训练,验证和测试集。训练集包含 39209 张大小为 32×32,通道数为 3 的图像,所以 NumPy 阵列的维度为 39209×32×32×3。该项目中作者仅使用了训练集和验证集。作者将使用网上的真实图像来测试所构建的模型。X_train 存储图像,维度为 39209×32×32×3。Y_train 存储图像对应的类标,维度为 39209,包含 0-42 的整数,对应每张图的类标。
训练过程
1. 准备数据集
X_train 和 Y_train 组成了训练数据集。可以使用 scikit-learn 对训练数据集进行分割得到验证集,这样可以避免使用出现过的图片测试模型。代码如下:
2. 训练数据预处理
批训练
神经网络训练需要花费大量时间和内存。所以作者将数据分批训练,一批大小为 64. 不仅是为了让数据适应内存,而且它可以让 MXNet 尽量利用 GPU 的计算效率。
归一化
除此之外,图像的像素值也进行了归一化,可以使学习算法更快收敛。下面是对训练数据进行预处理的代码:
3. 构建深度网络
目前,对于图像识别这类处在探索研究热点的问题,学界已经设计了很多效果良好的网络结构。所以最好的方法是实现一个已经发表出来的网络结构,然后对其进行改进。基于 AlexNet 结构,构建了一个简化版的卷积神经网络。AlexNet 是 2012 年发表的一个经典网络,在当年取得了 ImageNet 的最好成绩。
网络共有 8 层,其中前 5 层是卷积层,后边 3 层是全连接层,在每一个卷积层中包含了激励函数 RELU 以及局部响应归一化(LRN)处理,然后再经过池化(max pooling),最后的一个全连接层的输出是具有 1000 个输出的 softmax 层,最后的优化目标是最大化平均的多元逻辑回归。
在此之后也有很多更优秀的网络结构被提出,例如 VGGNet 和 ResNet,大家可以选择更好的网络结构去实现。
由于 MXNet 的符号计算构架,该神经网络的代码十分简洁明了
4. 训练网络
训练 epoch 为 10,训练好的模型存在 JSON 文件中,并且可以通过测量训练和验证准确率来观测网络“学习”的情况。
5. 载入预训练模型
下面给出了加载第 10 个 epoch 模型(最终模型)的代码。由于将在单张图片上进行测试,所以批尺寸由 64 减到 1,数据维度也变成了 1×3×32×32。
测试过程
测试图像(32×32×3)样例:
从结果可以看出可能性最高的种类为停车标志,说明预测准确。如果需要对模型有一个更完整的衡量,还需要用测试数据库进行测试,得到最终的分类准确率。
总结
本文我们介绍了使用 MXNet 进行多目标分类任务的方法。使用 MXNet,在 AlexNet 的结构基础上构建了一个更为简单的卷积神经网络结构。网络由卷积层,激活函数层,池化层和全连接层组成,采用德国交通标志图像训练数据库对该网络进行训练,实验结果证明网络可以将交通标志进行正确的分类。介绍了如何使用 MXNet 对数据进行预处理,构建网络,以及如何加载预训练好的网络模型。可以看出,MXNet 因其在多 GPU 上进行并行训练的能力,以及网络模型构建简单灵活的特性,是一个十分优秀的深度学习框架。
==========================
本人微信公众帐号: 心禅道(xinchandao)
本人微信公众帐号:双色球预测合买(ssqyuce)
人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练相关推荐
- 深度学习之基于Inception_ResNet_V2和CNN实现交通标志识别
这次的结果是没有想到的,利用官方的Inception_ResNet_V2模型识别效果差到爆,应该是博主自己的问题,但是不知道哪儿出错了. 本次实验分别基于自己搭建的Inception_ResNet_V ...
- 深度学习100例 | 第3天:交通标志识别 - PyTorch实现
文章目录 一.导入数据 1. 获取类别名 2. 数据可视化 3. 加载数据文件 4. 划分数据 二.自建模型 三.模型训练 1. 优化器与损失函数 2. 模型的训练 四.结果分析 大家好,我是K同学啊 ...
- Ubuntu为julia安装深度学习框架MXNet(支持CUDA和OPenCV编译)
Ubuntu为julia安装深度学习框架MXNet(支持CUDA和OPenCV编译) 环境介绍与注意事项 下载源文件 安装依赖 编译 环境配置 安装MXNet 测试 后记 环境介绍与注意事项 Ubun ...
- 深度学习 Day 15——利用卷神经网络实现好莱坞明星识别
深度学习 Day 15--利用卷神经网络实现好莱坞明星识别 文章目录 深度学习 Day 15--利用卷神经网络实现好莱坞明星识别 一.前言 二.我的环境 三.前期工作 1.导入依赖项并设置GPU 2. ...
- matlab交通标志神经网络识别,基于神经网络的交通标志识别方法
Municipal & Traffic Construction SCIENCE & TECHNOLOGY FOR DEVELOPMENT 149 基于神经网络的交通标志识别方法 赵丹 ...
- 【交通标志识别】基于BP神经网络实现交通标志识别matlab代码
1 简介 近年来,交通标志识别在车辆视觉导航系统中是一个热门研究课题.为了安全驾驶和高效运输,交通部门在公路道路上设置了各类重要的交通标志,以提醒司机和行人有关道路交通信息,如指示标志.警告标志.禁止 ...
- java深度学习框架Deeplearning4j实战(一)BP网络分类器
1.Deeplearning4j 深度学习,人工智能今天已经成了IT界最流行的词,而tensorflow,phython又是研究深度学习神经网络的热门工具.tensorflow是google的出品,而 ...
- 深度学习框架Deeplearning4j实战:文本智能抽取快速定位
一.Deeplearning4j Deeplearning4j(简称DL4J)是基于java的一个深度学习框架,已经发布了1.0版本的beta版. 与其他深度学习框架相比,DL4J具有以下优点: 与S ...
- 深度学习框架zf_谈谈深度学习框架的数据排布
最近同事碰到了一个不同框架模型互相转换的问题:pytorch模型或caffe模型要转到tensorflow和TFLite上进行移动端的部署.模型从pytorch或caffe转tensorflow通过O ...
最新文章
- dfs-Rank the Languages
- 12306的变态验证码算得了什么?我有Python神器!
- Web前端开发笔记——第二章 HTML语言 第十一节 语义标签
- C++自定义自适应中值滤波
- 前端学习(1050):todolist正在进行个数和已完成个数
- ogg oracle 测试kafka_基于OGG的Oracle与Hadoop集群/kafka准实时同步
- 解密五种AI筛选的“新冠”新药:能靶向病毒细胞侵入的蛋白酶
- vue3.0项目服务器部署
- Tensorflow(r1.4)API--tf.summary.scalar
- is 32-bit instead of 64-bit 亲测可用
- Linux命令整理-Kali
- java jvm参数获取_在java代码中获取JVM参数
- Java毕业设计-企业员工考勤打卡管理系统
- cada0图纸尺寸_a0图纸尺寸
- 最详细的js获取当前url的方法
- Linux系统小说源码网站,Linux系统小说源码网站
- JAVAWEB开发Myeclipse 项目中报“无法解析类型 java.io.ObjectInputStream,从必需的 .class 文件间接引用了它”解决办法
- 关于RxJava2.0你不知道的事
- 职场中哪些职场很重要?
- 新支点国产服务器操作系统与虚拟化平台和云管理平台实现兼容
热门文章
- 【使用递归玩通关汉诺塔游戏】算法01-递归(斐波那契数列、汉罗塔问题)-java实现
- 学习笔记(十七)——redis(CRUD)
- logback无法生成日志文件之谜
- oracle中forall in,oracle10g的forall功能加强
- 微信小程序-canvas绘制文字实现自动换行
- 路面平整度采集和计算方法
- java数组二分查找_java 13-1 数组高级二分查找
- 【自动驾驶】欧拉角和旋转矩阵之间的转换
- 【数学与算法】奇异矩阵、奇异值、奇异值分解、奇异性
- caffe windows 学习第一步:编译和安装(vs2012+win 64)