来自法国 Capgemini Invent 公司的高级数据科学家 Ahmed BESBES 三个月前参加了一个其公司内部的比赛:使用机器学习方法帮助海洋科学家更好的识别鲸鱼,根据鲸尾页突的外观作为主要特征。

比赛要求对于每一幅测试图像模型要给出最相似的前20幅图像,这不是一个简单的分类任务而是相似检索任务。

最终,Ahmed获得了第三名,他把在此过程中搭建模型的全过程详细分享了出来,并附上了其在各个步骤使用的开源工具,相信对那些想要使用深度学习解决实际问题的朋友肯定有帮助。

01

问题定义和选择正确的损失函数

第一种方法:分类

简单的做法是使用分类分数作为相似性排序的依据,但很不幸的是,这样做难以得到好的相似性排序结果。

这提醒我们,如果要体现个体之间的相似性,必须在训练阶段对样本进行显式学习和排序。

第二种方法:度量学习

学习有效的嵌入特征用于样本相似度比较和排序属于度量学习的范畴。

好在这是一个成熟的技术领域,对于想要入门的朋友可以看看:

https://gombru.github.io/2019/04/03/ranking_loss/

https://omoindrot.github.io/triplet-loss

作者使用了两种损失函数:

  • Triplet loss

  • ArcFace loss

Ⅰ、Triplet loss来自谷歌2015年论文 FaceNet。

使用Triplet loss 时作者又加入了如下训练tricks:

  1. 硬采样,即triplet (a, p, n) 满足不等式 d(a, n) < d(a, p)

  2. PK采样,保证每一个batch来自P个不同的类,每一类K幅图像

  3. 在线生成triplets

了解执行细节可以查看作者代码:

https://github.com/ahmedbesbes/whales-classification/tree/master

学习更多这些技术推荐阅读论文:

https://arxiv.org/pdf/1703.07737.pdf

Ⅱ、ArcFace loss 来自CVPR 2019。

论文中该算法在人脸识别问题中打败了 triplet loss, intra-loss, 和 inter-loss等。

ArcFace loss  相比 triplet loss有更好的特性:

  1. 对于类别很多的问题也表现的很好;

  2. 消除了训练triplet loss中的难样本挖掘的问题;

  3. 提供了漂亮的几何解释;

  4. 能够稳定训练;

  5. 收敛更快;

  6. 更重要的,实验发现仅需要一个ArcFace loss训练的单模型打败了5个triplet loss训练的模型的融合。

02

小心研究数据

功夫再高也怕菜刀!模型再好,数据质量不高也白搭!

作者花费了大量心思在数据上:

  • 去除噪声和已经损坏的图像,比如分辨率很低的,或者鲸鱼的尾巴根本看不见的图像。

  • 去除那些只含有一幅图像的类别,这被证明非常有效。因为度量学习需要类别内部的上下文信息,所以只含有一幅图像的类别是信息不足的。

  • 检测并提取鲸尾页突的图像,以去除大海、水花等的干扰,这一步扮演了注意力机制的角色。作者标注了大约300幅鲸尾页突的图像,训练了一个YOLOv3检测器,使用的标注工具来自:https://www.makesense.ai/ ,YOLOv3的训练代码来自:https://github.com/ultralytics/yolov3

学习要点:我们应该在合适且干净的数据上赢得更多的精度提升,而不是努力做花里胡哨的模型。

03

不要忽视迁移学习

作者最开始使用ImageNet数据集上的预训练模型(renset34, densenet121等)作为骨干网,后来发现了一个类似比赛 Kaggle Humpback Whale Identification ,尽管两个比赛中鲸鱼属于不同种,但将ImageNet预训练模型在Humpback数据集上微调后获得了巨大的精度提升!使作者的方案一下跳到了前几名,而且收敛更快(减少了30%的epochs)。

学习要点:

  • 迁移学习往往是有效的,且很少会带来伤害,如果可以请把ImageNet预训练模型在与你的问题相似的数据集上进行微调是很有必要的;

  • 迁移学习是增加训练样本的间接方式。

04

输入图像分辨率影响很大

在该问题中因为拍摄的设备很专业,很多图像很大,可以达到3000x1200像素或者更大。

作者首先使用224 x 224分辨率的图像,后来改成大一些分辨率的图像,得到了明显的精度提升,最后发现 480x480是对这个问题最好的输入大小。

学习要点:

  • 如果你在处理高分辨率图像,尝试较大的分辨率是不错的选择,大分辨率能够让模型学到特别小的细节特征,有助于样本个体间的区分;

  • 并不是越大越好,过大会使得训练收敛更慢,且如果原来数据中图像较小,resize到过大分辨率会带来精度下降,因为原有信号被破坏了。

05

网络结构的选择

业界有很多流行的深度学习架构如 VGG 或 ResNet,还有很多新出的复杂架构 ResNet-Inception-V4 、 NASNet等。到底如何选择呢?

作者经过三个月实验,作者总结出如下结论:

学习要点:

  • 大而深的SOTA骨干网并不总是最优选择,如果你的数据量不大,这些模型会快速过拟合,而且如果你计算资源有限,你也不能训练它们。

  • 一种比较好的做法是,最开始选择简单的网络,在验证集上监控性能变化,逐步增加模型复杂度。

  • 如果你计划把你的方案部署到网页端,则必须要考虑模型大小、内存消耗、推断时间等。

06

设计一个鲁棒的流程

作者模型训练的5个步骤如下:

特别值得一提的是,作者使用了种类繁多的增广方法:高斯噪声和模糊、运动模糊、模拟随机下雨、颜色偏移、随机颜色(亮度、饱和度、色度)改变、锐化、随机透视变换、伸缩变换、随机旋转、仿射变换、随机遮挡。

具体细节可以在代码中查看:

https://github.com/ahmedbesbes/whales-classification

07

通用训练技巧

  • 固定种子保证模型可重复性

在PyTorch中可以使用如下代码:

import random
import numpy as np
import torch
random.seed(seed)
torch.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
  • Adam是安全的优化器,但不要忘记把weight decay 设置为非0值。作者使用了1e-3。

  • 大量的数据增广确实改进了模型精度,强烈推荐使用:albumentations
    (https://github.com/albumentations-team/albumentations)

  • 选择合适的学习率调度方案,避免模型陷入局部极小值。作者使用了warmup 调度机制后面加cosine退火的方法。

  • Loss 和其他度量结果的监控,作者使用了 Tensorbaord。

  • 使用伪标签往往可以带来精度提升。这在kaggle竞赛中经常被使用,它使用训练好的模型预测测试数据的类别,将置信度很高(比如大于0.9概率)预测结果的样本加入训练样本,重新训练模型。

  • 最好有强大的硬件支持,训练速度快,可以快速验证改进策略。

  • 保存训练好的模型,跟踪并记录模型的表现。

08

最终方案:从嵌入到元嵌入

作者训练了两个模型,如下:

作者将此二模型联合起来,使用元嵌入的方法,这是一种在NLP中经常使用的方法。

学习要点:

  • 元嵌入特征连接的方式在模型差异比较大的时候能够提供有意义的结果,比如:骨干网不同(resnet34 vs densenet121), 图像输入大小不同 (480 vs 620), 正则化模式不同 (dropout vs no dropout)的模型;

  • 每一个不同的基模型看到不同的事物,把它们联合能够产生新的增强混合模型(这个解释太朴素了——CV君)。

   传送门

原文链接:

A Hacker’s Guide to Efficiently Train Deep Learning Models

https://towardsdatascience.com/a-hackers-guide-to-efficiently-train-deep-learning-models-b2cccbd1bc0a

开源地址:

https://github.com/ahmedbesbes/whales-classification

更多阅读:

文中作者使用了ArcFace Loss,今年一种新的更好的Loss出现了:

旷视提出Circle Loss,革新深度特征学习范式 |CVPR 2020 Oral

如备注:目标检测

细分方向交流群

专业包括目标检测、目标跟踪、图像增强、强化学习、模型压缩、视频理解、人脸技术、三维视觉、SLAM、GAN、GNN等,

若已为CV君其他账号好友请直接私信。

我爱计算机视觉

微信号:aicvml

QQ群:805388940

微博知乎:@我爱计算机视觉

投稿:amos@52cv.net

网站:www.52cv.net

在看,让更多人看到  

步步为营!高手教你如何有效使用深度学习解决实际问题相关推荐

  1. 教你如何挑选深度学习GPU

    教你如何挑选深度学习GPU 即将进入 2018 年,随着硬件的更新换代,越来越多的机器学习从业者又开始面临选择 GPU 的难题.正如我们所知,机器学习的成功与否很大程度上取决于硬件的承载能力.在今年 ...

  2. 人工神经网络理论、设计及应用_TensorFlow深度学习应用实践:教你如何掌握深度学习模型及应用...

    前言 通过TensorFlow图像处理,全面掌握深度学习模型及应用. 全面深入讲解反馈神经网络和卷积神经网络理论体系. 结合深度学习实际案例的实现,掌握TensorFlow程序设计方法和技巧. 着重深 ...

  3. 谷歌开放语音命令数据集,助力初学者利用深度学习解决音频识别问题

    语音命令数据集地址:http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz 音频识别教程地址:https://www.tens ...

  4. 用深度学习解决大规模文本分类问题

     用深度学习解决大规模文本分类问题 人工智能头条 2017-03-27 22:14:22 淘宝 阅读(228) 评论(0) 声明:本文由入驻搜狐公众平台的作者撰写,除搜狐官方账号外,观点仅代表作者 ...

  5. 深度学习解决NLP问题:语义相似度计算

    深度学习解决NLP问题:语义相似度计算 参考文章: (1)深度学习解决NLP问题:语义相似度计算 (2)https://www.cnblogs.com/qniguoym/p/7772561.html ...

  6. 深度学习解决机器阅读理解任务的研究进展

    /*版权声明:可以任意转载,转载时请标明文章原始出处和作者信息.*/ author: 张俊林 关于阅读理解,相信大家都不陌生,我们接受的传统语文教育中阅读理解是非常常规的考试内容,一般形式就是给你一篇 ...

  7. 用深度学习解决旅行推销员问题,研究者走到哪一步了?

    来源:机器之心 本文约2600字,建议阅读9分钟 本文分析了深度学习在路由问题方面的最新进展,并提供了新的方向来启发今后的研究. 最近,针对旅行推销员等组合优化问题开发神经网络驱动的求解器引起了学术界 ...

  8. AI 三大教父齐聚深度学习峰会,讨论尖端研究进展

    来源:36氪 概要:近日,深度学习峰会正在加拿大蒙特利尔举行,有史以来第一次3位AI教父:Yoshua Bengio.Yann LeCun以及 Geoffrey Hinton聚在了一起出席RE•WOR ...

  9. 看不懂花书?博士教你如何深入深度学习,从编程基础到完整的项目实战

    转眼2020年已过去三分之一,大家都知道今年就业形势不乐观,不过即便如此,现在依然是AI招聘的热门季.疫情过后,AI行业注定会迎来一波大爆发. 近几年,各大企业也开始越来越重视人工智能方向的发展,比如 ...

最新文章

  1. 使用GPG校验sign签名
  2. 苹果菠萝笔html5游戏在线玩,苹果菠萝笔游戏
  3. 二叉树的蛇形遍历 leetcode 103
  4. css两张图片怎么合在一起_web前端入门到实战:纯CSS实现两个球相交的粘粘效果...
  5. java连接Oracle和PostGreSQL
  6. r如何查询mysql中的数据类型_MySQL-mysql中的数据类型
  7. Windows说明Linux分区和挂载点
  8. EtherCAT总线运动控制器中简单易用的直线插补
  9. 进程间通信IPC(一)pipe fifo mmap
  10. 爬虫 裁判文书网爬取part2
  11. VIVADO安装问题
  12. 如何查看本地服务器名称
  13. ESP32+阿里云+vscode_Pio
  14. [前端系列]vue3修改模板变量间隔符
  15. c#餐饮系统打印机_C#打印机操作类
  16. 采油工技能鉴定高级工计算机6,采油工技师、高级技师技能鉴定题库(宝典).doc...
  17. 客如云第二届开放平台大会 餐饮零售业新升级再赋能
  18. linux用cat命令创建一个文件,用cat在命令行创建文件
  19. unity android服务器端,【深圳Unity3D培训】 Android客户端与PC服务器实现Socket通信
  20. 服务器端渲染(SSR)和客户端渲染

热门文章

  1. 分解质因数(优中再优化)
  2. 激光器安规详细解读 - 一级 - 并以940波长为例
  3. 【环境搭建000】详细图解ubuntu 上安装配置eclips
  4. python跳过ssl验证_Python SSL证书验证问题解决方案
  5. 用python打印九九乘法表while_利用Python循环(包括whilefor)各种打印九九乘法表的实例...
  6. 50个linux初学者必须掌握的命令
  7. python中字母用什么表示_python中字母与ascii码的相互转换
  8. java泛型dao,泛型DAO模式在JavaWeb开发中的应用_孟晨.pdf
  9. 计算机桌面图标的排列,如何进行桌面图标排列 让你的桌面一秒变酷炫【图文教程】...
  10. 十进制转化为二进制_使用Windows 10内置计算器,将十进制数快速转换为二进制数,试试...