点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

在本文中,我们描述了我们如何使用卷积神经网络 (CNN) 来估计花卉图像中关键点的位置,并且在 3D 模型上渲染这些图像上茎和花的位置等关键点。

为了能够与真实花束的照片对比,所创建的图像必须尽可能逼真。这是通过使用从多个角度拍摄的真实花朵照片并将它们渲染在 3D 模型上来实现的。对于每一朵新花,他们都会从 7 个不同的角度拍摄照片。在照相亭中,花朵由电机自动旋转。

相比之下,图片的后期处理还没有完全自动化。目前数据库中有数千种鲜花,每天都会添加新的鲜花。将此乘以角度数,将获得大量要手动处理的图片。后处理步骤之一是定位 3D 模型所需的图像上的几个关键点,最重要的是茎位和花顶位置。

数据集

在数据集中,成千上万的图像已经手动标注了关键点,所以我们有大量的训练数据可以使用。

以上是训练数据集中的一些带注释的花,它从几个不同的角度展示了同一朵花。茎位置为蓝色,花顶部位置为绿色。在一些图片中,茎的起源被花本身隐藏了。在这种情况下,我们需要“有根据的猜测”最有可能在哪里。

网络模型

因为模型必须输出一个数字而不是一个类,所以我们实际上是在做回归。CNN 以分类任务而闻名,但在回归方面也表现良好。例如,DensePose使用基于 CNN 的方法进行人体姿势估计。

网络从几个标准卷积块开始。这些块由3个卷积层组成,然后是最大池、批量标准化层和退出层。

  • 所述卷积层含有多个滤波器。每个过滤器就像一个模式识别器。下一个卷积块有更多的过滤器,所以它可以在模式中找到模式。

  • 最大池化会降低图像的分辨率。这限制了模型中的参数数量。通常,对于图像分类,我们对某个对象在图像中的位置不感兴趣,只要它在那里即可。在我们的例子中,我们对位置感兴趣。尽管如此,拥有几个最大池化层并不会影响性能。

  • 批量标准化层有助于模型更快地训练(收敛)。在一些深度网络中,没有它们,训练完全失败。

  • 退出层将随机禁用节点,这将防止过度拟合模型。

在卷积块之后,我们将张量展平,使其与密集层兼容。全局最大池化或平均最大池化也将实现平坦张量,但会丢失所有空间信息。扁平化在我们的实验中效果更好,尽管它的(计算)成本是拥有更多模型参数导致更长的训练时间。

在两个带有Relu激活的密集隐藏层之后是输出层,我们想要预测2 个关键点的x和y坐标,所以我们需要在输出层有 4 个节点。图像可以有不同的分辨率,因此我们将坐标缩放到 0 到 1 之间,并在使用前将它们放大。输出层没有激活函数。即使目标变量在 0 和 1 之间,这对我们来说也比使用sigmoid效果更好。作为参考,以下是我们使用的 Python 深度学习库Keras的完整模型摘要:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 126, 126, 64)      2368
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 124, 124, 64)      36928
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 122, 122, 64)      36928
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 61, 61, 64)        0
_________________________________________________________________
batch_normalization_1 (Batch (None, 61, 61, 64)        256
_________________________________________________________________
dropout_1 (Dropout)          (None, 61, 61, 64)        0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 59, 59, 128)       73856
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 57, 57, 128)       147584
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 55, 55, 128)       147584
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 27, 27, 128)       0
_________________________________________________________________
batch_normalization_2 (Batch (None, 27, 27, 128)       512
_________________________________________________________________
dropout_2 (Dropout)          (None, 27, 27, 128)       0
_________________________________________________________________
flatten_1 (Flatten)          (None, 93312)             0
_________________________________________________________________
dense_1 (Dense)              (None, 256)               23888128
_________________________________________________________________
batch_normalization_3 (Batch (None, 256)               1024
_________________________________________________________________
dropout_3 (Dropout)          (None, 256)               0
_________________________________________________________________
dense_2 (Dense)              (None, 256)               65792
_________________________________________________________________
batch_normalization_4 (Batch (None, 256)               1024
_________________________________________________________________
dropout_4 (Dropout)          (None, 256)               0
_________________________________________________________________
dense_3 (Dense)              (None, 4)                 1028
=================================================================
Total params: 24,403,012
Trainable params: 24,401,604
Non-trainable params: 1,408
_________________________________________________________________

你们可能会问:为什么是 3 个卷积层?或者为什么是 2 个卷积块?我们在超参数搜索中将这些数字作为超参数包括在内。连同诸如密集层数、退出层、批量标准化和卷积滤波器数量之类的参数,我们进行了随机搜索以找到超参数的最佳组合。

对于训练,我们使用学习率为的Adam 优化器0.005。当验证损失在几个时期内没有改善时,学习率会自动降低。作为损失函数,我们使用均方误差 (MSE)。因此,大错误比小错误受到的惩罚相对更多。

训练和效果

这些是训练 50 个时期后的损失(误差)图:

大约 8 个 epoch 后,验证损失变得高于训练损失。直到训练结束,验证损失仍然减少,因此我们没有看到模型严重过度拟合的迹象。测试集上的最终损失 (MSE) 为0.0064. MSE 的解释可能非常不直观。

MAE 是——这意味着预测平均降低 1.7%

白色圆圈包含目标关键点,实心圆圈包含我们的预测。在大多数情况下,它们非常接近(重叠)。

改进

我们有一些改进的想法,但我们还没有时间实施:

  1. 目前,单个模型正在估计两个关键点。为每个关键点训练一个特定的模型可能会更好。这还有一个额外的好处,可以稍后添加新的关键点,而无需重新训练完整的模型。

  2. 另一个想法是考虑照片的角度。例如,将其添加为密集层的输入,可能会争辩说,角度会改变任务的性质,因此提供此信息可能有助于网络。按照这种思路,为每个角度训练一个单独的网络也可能是有益的。

结论

通过这项研究,我们证明了使用 CNN 检测花卉图像中的关键点的可行性。所使用的方法也可能适用于其他领域的后处理任务,例如产品摄影。

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

基于深度学习的花卉图像关键点检测相关推荐

  1. 基于深度学习的2D图像目标检测

    参见第一部分网址1,第二部分网址2 目前学术和工业界出现的目标检测算法分成3类:(参见一文读懂目标检测:R-CNN.Fast R-CNN.Faster R-CNN.YOLO.SSD) 1. 传统的目标 ...

  2. 目标检测YOLO实战应用案例100讲-基于深度学习的航拍图像YOLOv5目标检测研究及应用(论文篇)

    目录 基于深度学习的航拍图像目标检测研究 航拍图像目标检测 评价指标与数据集

  3. 基于深度学习的花卉检测与识别系统(YOLOv5清新界面版,Python代码)

    摘要:基于深度学习的花卉检测与识别系统用于常见花卉识别计数,智能检测花卉种类并记录和保存结果,对各种花卉检测结果可视化,更加方便准确辨认花卉.本文详细介绍花卉检测与识别系统,在介绍算法原理的同时,给出 ...

  4. 基于深度学习的口罩识别与检测PyTorch实现

    基于深度学习的口罩识别与检测PyTorch实现 1. 设计思路 1.1 两阶段检测器:先检测人脸,然后将人脸进行分类,戴口罩与不戴口罩. 1.2 一阶段检测器:直接训练口罩检测器,训练样本为人脸的标注 ...

  5. 深度学习在遥感图像目标检测中的应用综述

    深度学习在遥感图像目标检测中的应用综述 1 人工智能发展 1.1 发展历程 1.2 深度学习的应用 2 深度学习 2.1 机器学习概述 2.2 神经网络模型 2.3 深度学习 2.4 深度学习主要模型 ...

  6. 基于深度学习的日志数据异常检测

    基于深度学习的日志数据异常检测 数据对象 智能运维(AIOps)是通过机器学习等算法分析来自于多种运维工具和设备的大规模数据.智能运维的分析数据对象多源运维数据包括系统运行时数据和历史记录数据,历史记 ...

  7. 基于深度学习的高精度家禽猪检测识别系统(PyTorch+Pyside6+YOLOv5模型)

    摘要:基于深度学习的高精度家禽猪检测识别系统可用于日常生活中或野外来检测与定位家禽猪目标,利用深度学习算法可实现图片.视频.摄像头等方式的家禽猪目标检测识别,另外支持结果可视化与图片或视频检测结果的导 ...

  8. 第四篇:基于深度学习的人脸特征点检测 - 数据预处理

    在上一篇博文中,我们整理了300-W.LFPW.HELEN.AFW.IBUG和300-VW这6个数据集,使用Python将特征点绘制在对应的图片上,人工验证了数据集的正确性,最终获得了223034个人 ...

  9. 基于深度学习的高精度牙齿健康检测识别系统(PyTorch+Pyside6+YOLOv5模型)

    摘要:基于深度学习的高精度牙齿健康检测识别系统可用于日常生活中检测牙齿健康状况,利用深度学习算法可实现图片.视频.摄像头等方式的牙齿目标检测识别,另外支持结果可视化与图片或视频检测结果的导出.本系统采 ...

最新文章

  1. 搬箱轮滑再炫技!一个被波士顿动力耽误的机器人
  2. STARTTLS在电子邮件环境中的安全性分析
  3. 头条Android面试题,史上最全的Android面试题集锦(五)
  4. 利用Arduino Nano 对于另外的Arduino控制板下载Bootloader
  5. BZOJ3173:[TJOI2013]最长上升子序列(Splay)
  6. “谈谈我对技术发展的一点感悟”阅读小记
  7. VTK:可视化之TextureMapPlane
  8. leetcode 220. Contains Duplicate III | 220. 存在重复元素 III (Treeset解法+分桶解法)
  9. 使用application log 分析navigation target解析错误
  10. 【DS】时间复杂度排序
  11. 服务器微信了早上好,每天早上好的问候语 微信早安问候语合集66句
  12. DPDK Qos之报文处理流水线
  13. BZOJ4448:[SCO2015]情报传递
  14. 泵车砼活塞故障预警-冠军方案
  15. kaggle:Costa Rican Household Poverty Level Prediction(1)DEA
  16. python 抓取微博评论破亿_利用python实现爬取微博评论的方法
  17. 有道词典【输入式翻页】
  18. 友盟统计 H5 vue 隐藏友盟图标
  19. 大数据日志分析Hadoop项目实战
  20. Mac精品应用推荐:专业的后期特效制作软件

热门文章

  1. 华为全球最快AI训练集群Atlas 900诞生
  2. 性能提升3倍的树莓派4,被爆设计缺陷!
  3. 上海居民被垃圾分类逼疯!这款垃圾自动分类器也许能帮上忙
  4. 不止临床应用,AI还要帮不懂编程的医生搞科研
  5. 机器学习开源项目Top10
  6. 李彦宏说自动驾驶比人更安全,还认为中国用户更愿意放弃隐私
  7. 一款零注解侵入的 API 文档生成工具,你用过吗?
  8. 架构设计的本质:系统与子系统、模块与组件、框架与架构
  9. 用了 3 年 Apollo,最后我选择了 Nacos,原因不多说了
  10. IntelliJ IDEA 更新后,电脑卡成球,该如何优化?