每个人似乎都对胶囊网络(CapsNet)这种新的神经网络架构的出现很兴奋,我也不例外,忍不住用胶囊网络来建立一个路侧交通标志的识别系统,这篇文章就是对这一过程的介绍,当然,也包括胶囊网络的一些基本概念阐述。

项目使用TensorFlow开发,是基于Sara Sabour,Nicholas Frosst和Geoffrey E. Hinton的论文《 胶囊间动态路由 》,代码保存在github。如果你迫不及待想要试试Tensorflow等机器学习框架,可以访问汇智网的Python机器学习在线环境。

卷积神经网络有什么问题?

卷积神经网络(CNN)的问题部分源于其对图像感知的泛化能力,例如一个训练好的图像识别网络可能会对同一图像的旋转版本识别错误,这就是为什么在训练时经常使用数据增强和平均/最大池(Average / Max Pooling)。 池化通过随机选择下一层的神经元子集建立一个新的层。 这可以有效降低上层的计算需求,而且也使得网络减小对特征出现的原始位置的依赖性。 这一简化的依据在于:我们假设特征出现的确切位置对目标识别而言影响不大。

和CNN一样,上层的胶囊可以覆盖更大的图像区域,但是与最大化池不同,我们不会丢弃该区域内目标物体的准确位置信息。

这使得模型对图像中的细微变化可以保持不变的输出。 另一方面,模型有可能忽视图像发生的位移变化。不变意味着无论检测到的字符的顺序和位置是否改变,网络的输出总是相同的。 因此该模型能够理解图像中的特征的旋转和位移,并产生适当的输出。 这对于使用池化来说是不可能的。 这就是启发我们发明这个新架构的原因。

胶囊网络

胶囊网络赋予了模型理解图像中所发生变化的能力,从而可以更好地概括所感知的内容。 要了解这个架构如何运作,重要的是掌握胶囊的概念。

胶囊是一组神经元,其激活向量表示某种特定类型的实体(如对象或对象部分)的实例化参数。

我们习惯从深度角度来谈论深度学习,而胶囊网络则引入了嵌套的概念,嵌套为深度引入了一个新的维度。 不是采用添加层的方法来增加网络的深度,相反,胶囊网络是在另一个层中添加(多个)新的图层。 这有点抽象, 但是当你仔细观察时,就会发现情况并不是那么复杂。 在论文中,这一方法的核心分为两部分: 基础胶囊和数字胶囊 。 在我们的案例中,后一部分将被重命名为交通标志胶囊 。

胶囊和基础胶囊

这一层基于经典的卷积计算,创建一个新的由N * C滤波器组成的卷积层。 N表示滤波器的数量,C表示每个胶囊的尺寸。 因此会创建出具有(T,T)大小的N * C个新图像。 在上图中,每个胶囊的值在新创建的图像中以红色显示。 Tensorflow代码如下:

conv = tf.contrib.layers.conv2d(input_layer, N * C, kernel, stride, padding="VALID")
# Shape: (?, T, T, N * C)

现在创建好了卷积操作,我们可以重排这些卷积以便创建胶囊操作:

capsules = tf.reshape(conv, shape=(-1, T*T*N, C, 1))
# Shape: (?, T*T*N, C, 1)# conv[0][0][0][:C] <=> capsules[0]

然后我们得到T * T * N个大小为C的胶囊(在这个项目中是1152个胶囊)。 需要指出的是,卷积的第一个C值(见代码中的注释)等于第一个胶囊的值,正如上面代码所示。 最后,原论文中给出了一个新的非线性函数,可以单独应用于每个胶囊。 这个新函数被称为挤压 (Squashing),看起来像这样:

因此,我们使用非线性Squashing函数来确保将短矢量的长度压缩到接近零,而长矢量的长度压缩到略低于1

交通标志胶囊

在本项目中,这个层由43个胶囊组成,每个胶囊代表一种特定的交通标志。 为了确定模型的预测结果,我们可以选择具有最大长度的胶囊。 但在此之前,需要在前一层的1152个胶囊之间进行转换。 这将通过路由的方法完成。 该方法的作用是选中前一层的哪些胶囊与输出层胶囊进行关联。 换句话说,对于每个胶囊,会有一个新的神经网络进行判断:“嘿,这个胶囊对这个类的判别有价值吗?”

使用迭代的路由处理过程,每个活动胶囊将在上面的层中选择一个胶囊,作为它在树中的父节点。

在路由中,对特征的选择不再是像池化那样随意。 在这篇文章中,我不会详细介绍路由所使用的确切公式,论文中有这些公式的描述。 本项目的实现代码在我的github 。 我还在继续改进以使算法更具可扩展性。 对于交通标志胶囊和路由,我在实现中尽量遵循了论文中的数学公式。

图像重建

这种方法有助于引导网络将胶囊向量视为实际的物体,允许在重建之前对每个图像进行编码。 这在正则化方面也得到了很好的结果。

我们使用额外的重建损失来鼓励数字胶囊对输入数字的实例化参数进行编码。

这部分实现代码也包含在项目的github中,代码中的图像重建实现,使用了卷积和最近邻算法来放大图像。 事实上,我不能只是创建一堆简单的层,因为要重建的图像包含3个输出通道。 尽管在MNIST数据中这个实现表现得相当好,但我对其在大规模解决方案中的有效性还存有一些怀疑,不过这只是我的个人观点。

因此,模型最终的损失是基于两种可选的损失:

  • 边际损失:基于模型的实际预测。 这是最高标准的胶囊。
  • 重建损失:基于图像之间平方差的解码器损失的平均值。

模型架构

由于我处理的数据集与原论文不同,所以模型架构也做了一些调整。

第一个卷积使用256个滤波器、大小为9的核(VALID填充)、RELU激活,dropout取值0.7。

基础胶囊层包含16个滤波器、大小为5的核、16个胶囊。最终获得256个(10,10)大小的滤波器。 即1600个16值胶囊。 最后一层(交通标志胶囊)由大小为32的43个胶囊(43个类)组成。

上述结构的构建代码如下:

def _build_main_network(self, images, conv_2_dropout):"""This method is used to create the two convolutions and the CapsNet on the top**input:*images: Image PLaceholder*conv_2_dropout: Dropout value placeholder**return: ***Caps1: Output of first Capsule layer*Caps2: Output of second Capsule layer"""# First BLock:# Layer 1: Convolution.shape = (self.h.conv_1_size, self.h.conv_1_size, 3, self.h.conv_1_nb)conv1 = self._create_conv(self.tf_images, shape, relu=True, max_pooling=False, padding='VALID')# Layer 2: Convolution.shape = (self.h.conv_2_size, self.h.conv_2_size, self.h.conv_1_nb, self.h.conv_2_nb)conv2 = self._create_conv(conv1, shape, relu=True, max_pooling=False, padding='VALID')conv2 = tf.nn.dropout(conv2, keep_prob=conv_2_dropout)# Create the first capsules layercaps1 = conv_caps_layer(input_layer=conv2,capsules_size=self.h.caps_1_vec_len,nb_filters=self.h.caps_1_nb_filter,kernel=self.h.caps_1_size)# Create the second capsules layer used to predict the outputcaps2 = fully_connected_caps_layer(input_layer=caps1,capsules_size=self.h.caps_2_vec_len,nb_capsules=self.NB_LABELS,iterations=self.h.routing_steps)return caps1, caps2

训练

训练时我使用了Keras的ImageDataGenerator以便进行数据增强。

结果(准确度):

  • 训练:99%
  • 验证:98%
  • 测试:97%

这个结果没能达到经典的卷积神经网络的最佳效果。 但是,考虑到我大部分时间都是在实现胶囊网络,而不是花在超参数调整和图像处理方面,因此对我来说,97%算是初次尝试的好成绩。 我现在还在努力提高这个指标。

分类示例

如果你喜欢这篇文章,请关注我的头条号:新缸中之脑!

原文:Understand and apply CapsNet on Traffic sign classification

机器学习实战:用胶囊网络识别交通标志相关推荐

  1. 人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练

    人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练 MXNet 是一个轻量级.可移植.灵活的分布式深度学习框架,2017 年 1 月 23 日,该项目进入 Apache 基金会,成为 ...

  2. 使用卷积神经网络识别交通标志

    什么是卷积神经网络 以下解释来源于ujjwalkarn的技术博客: 卷积神经网络(ConvNets 或者 CNNs)属于神经网络的范畴,在图像识别和分类领域具有高效的能力.卷积神经网络可以成功识别人脸 ...

  3. 使用TensorFlow识别交通标志

    作者:chen_h 微信号 & QQ:862251340 微信公众号:coderpai 这篇博客是翻译Waleed Abdulla写的使用TensorFlow识别交通标志,作者已经授权翻译,这 ...

  4. 机器学习实验大作业:Yolov5s交通标志检测

    设计内容与要求 1.自行搭建Yolov5s网络完成交通标志检测和识别任务,训练和验证数据集采用课程QQ群里发布的"交通标志检测数据集CCTSDB",测试集采用"交通标志检 ...

  5. 实例39:用胶囊网络识别黑白图中的服装图片

    一.熟悉样本:了解Fashion-MNIST数据集 FashionMNIST数据集的单个样本为28pixel*28pixel的灰度图片.训练集有60000张图片,测试集有10000张图片.样本内容为上 ...

  6. ☀️机器学习实战☀️基于 YOLO网络 的人脸识别 |(文末送机器学习书籍~)

  7. 机器学习实战:用网络摄像头预测年龄和性别

    你有没有猜过一个人的年龄? 下面这个简单的神经网络模型可以帮你做这件事. 本文的演示将从网络摄像头中获取实时视频流,并自动标注其中出现人脸的年龄和性别. 在家门口放一个这样的摄像头就可以了解访客的年龄 ...

  8. 交通标志识别论文综述

    交通标志识别是计算机视觉领域的一个研究热点.主要研究方向是使用机器学习和图像处理技术来识别交通标志. 近年来,随着深度学习技术的发展,交通标志识别的研究取得了显著进展.许多研究人员提出了基于卷积神经网 ...

  9. 基于BP 网络分类器的交通标志识别

    基于BP 网络分类器的交通标志识别 摘要:针对中国全部 3 大类 116 个交通标志,即禁令标志.指示标志.警告标志,用 BP 网络实现分类功能. 实验中使用了 3 种测试集,即加高斯噪声.水平扭曲和 ...

  10. Keras深度学习实战——交通标志识别

    Keras深度学习实战--交通标志识别 0. 前言 1. 数据集与模型分析 1.1 数据集介绍 1.2 模型分析 2. 交通标志识别 2.1 数据集加载与预处理 2.2 模型构建与训练 相关链接 0. ...

最新文章

  1. ubuntu 关闭qq打不开的终极方法
  2. java1.8 类库_Commons Configuration 1.8发布 配置管理Java类库
  3. MySQL的if,case语句使用总结
  4. scikit-learn 入门
  5. oracle exp 二进制,Oracle备份之exp自动逻辑备份(二)
  6. c++直角坐标系与极坐标系的转换_平面向量的奇技淫巧——斜坐标系的一系列低级研究...
  7. 看我如何利用教科书级别的释放后使用漏洞(CVE-2020-6449)
  8. Java Web前后端分离架构
  9. CodeSmith激活教程
  10. dev、test和staging、prod是什么意思?
  11. 关于双模键盘的模式转化
  12. 新版标准日本语中级_第八课
  13. 重新学javaweb---cookiesession
  14. Scala+HuffmanCoding实现无损压缩
  15. python储物柜难题_7个储物柜收纳小技巧,轻松解决你的收纳难题。
  16. 简述扁平式管理的技术手段
  17. 高德地图定位失败_常见问题
  18. word文档合并,书签丢失
  19. 林业工程抗旱造林技术
  20. 上拉电阻 下拉电阻 拉电流 灌电流

热门文章

  1. ubuntu16.04 安装 NVIDIA 显卡驱动 +cuda9.0+cudnn +tensorflow AND问题若干
  2. 安装telnet macOS High Sierra 10.13
  3. 关于如何处理MyEclipse中struts2与Hiber 3中antlr-2.7.2.jar与antlr-2.7.6包冲突的问题
  4. 基于eclipse和hiber的pojo、数据库表与mapping的相互转换
  5. java毕业设计汽车客运站票务管理系统源码+lw文档+mybatis+系统+mysql数据库+调试
  6. 离散题目9(判断是否为单射函数)
  7. 人工智能在广告行业的应用
  8. rabbitmq 连接报错 An unexpected connection driver error occured
  9. 计算机网络常见面试题目
  10. 实现 EC20 4G模块PPP拨号上网