©PaperWeekly 原创 · 作者|燕皖

单位|渊亭科技

研究方向|计算机视觉、CNN

在监督学习中,模型都是需要有一个大量的有标签的数据集进行拟合,通常数据成本、人力成本都很高。而现实生活中,无标签的样本的收集相对就很容易很多。因此,近年来,利用大量的无标签样本和少量的有标签样本的半监督学习备受关注。

本文主要介绍一种半监督的方法——Self-training,其主要思路是:先利用有标签数据训练得到模型,然后对无标签数据进行预测,置信度高的数据可以用于加入训练集,继续训练,直到模型符合要求。首先介绍了两种经典的 Self-training 方法,然后介绍了 Self-training 在 Kaggle 比赛上的实践。

Pseudo-label

论文标题:The Simple and EfficientSemi-Supervised Learning Method for Deep Neural Networks

论文来源:ICML 2013

论文链接:http://deeplearning.net/wp-content/uploads/2013/03/pseudo_label_final.pdf

代码链接:https://github.com/iBelieveCJM/pseudo_label-pytorch

1.1 训练策略

Pseudo-label 是 2013 年提出的一个非常简单有效的Semi-Supervised Learning 方法,其主要思想是在一批有标签和无标签的图像上,同时训练一个模型。训练流程如下:

Step 1:首先,同时使用有标记和未标记的 data,以有监督的方式训练 pretrained model。总损失是有标记和无标记损失项的加权和,前面是有标签数据的损失部分,后面的无标签数据的损失部分,如下:

其中,y 代表已标记数据的标签,y′ 代表了未标记数据的伪标签。

通常,为了确保模型已经从标记的数据中学习了足够多的信息,alpha_t 在最初的 N epoch 中,设置为 0,然后逐渐增加到 M epoch 后保持不变。如下式:

Step 2:然后,用训练好的 model 对一批未标记图像进行预测,用最大置信度作为 Pseudo-label ;

Step 3: 最后将有标签和伪标签的数据一起进行 finetune,直到最终得到最优 model。

1.2 实验结果

文章指出用 600 个标记数据对神经网络进行训练,和增加 60000 个未标记的数据和伪标签。从下图可以明显看到,通过使用未标记数据和伪标签训练的模型具有更好地泛化能力。

Noisy Student

论文标题:Self-training with Noisy Studentimproves ImageNet classification

论文来源:CVPR 2020

论文链接:https://arxiv.org/abs/1911.04252

代码链接:https://github.com/google-research/noisystudent

Google AI 年提出了一种受 Knowledge Distillation 启发的半监督方法“Noisy Student”。

2.1 Introduction

这篇文章主要的方法简单说就是使用更大的未标记图像的数据集,其中大部分图像不属于 ImageNet 训练集分布,来提高 SOTA-ImageNet 的精度。

其核心思想是 train 两种不同的模型,即“Teacher”和“Student”。教师模型首先对标签图像进行训练,然后对未标记图像进行伪标签推断。这些伪标签可以是 soft-label,也可以通过使用 most confident 转换的 hard-label。

然后,将有标记和未标记的图像组合在一起,并根据这些组合的数据训练学生模型。利用 RandAugment 作为输入噪声的一种形式对图像进行增强,最后训练得到最优 model。

2.2 训练策略

对于一些有标签数据集 data1 和一些无标注数据集 data2

第一步:在有标签数据集上训练一个模型,称为 teacher;

第二步:利用第一步得到的模型,在未标注数据集上进行预测,softmax 输出结果是概率分布,一般称为称为 soft label,其只给出每个类别的 score,而非指定为具体某个类别,而 hard label 就是 one-hot 形式的取 max 后的结果,并且实验证明软标签更好一些;

第三步:将有标注数据集和伪标签数据集合并,然后利用 augmentation、droupout 等策略,基于这个大数据集进行训练一个新的 student 模型;

第四步:将学到的 student 当做 teacher 重新对无标注数据集进行打标签,回到第二步中,迭代直到得到最优 mdoel 为止。

2.3 实验

对于标准数据集,仍使用 ImageNet 2012 基准数据集;

未标注数据集来自于 JFT 数据集,它实际含有大约 3 亿张图片,尽管这些图片实际有真实标签,但我们此处不需要,只当做无标记图片数据集即可。

为了实现无标签图片类别的平衡,作者拿在 ImageNet 上训练的 EfficientNet-B0 对 JFT 数据集打标签,并剔除了标签信任度低于 0.3 的图片,对于每个类别,挑选具有最高信任度的 13 万张图片,对于不足 13 万张的类别,随机再复制一些。

最终结果如下,可见 Noisy Student 方法在这一数据集上将 SOTA 性能提高了一个点。

Global Wheat Detection上的实践

接下来,将从目前正在参加的 kaggle 比赛(Global Wheat Detection)全球小麦头检测来分析Semi-Supervised Learning在目标检测中的作用。

比赛链接:

https://www.kaggle.com/c/global-wheat-detection

在本竞赛中,将从室外的小麦植株图像(包括来自全球的小麦数据集)中检测出小麦植株的头部,训练数据集涵盖了多个区域,是来自欧洲(法国,英国,瑞士)和北美(加拿大)的 3,000 多张图像,测试数据包括来自澳大利亚,日本和中国的约 1,000 张图像。

下面是一些识别的小麦头图片,可以看到比赛困难点不仅仅是数据少,小麦头经常重叠、小麦头具有多种尺寸、小麦的外观颜色由于成熟度不同而各不相同,

3.1 训练策略

由于在 kaggle 图像检测的比赛当中对于测试集的图片我们是无法查看的,只有在提交后代码运行才能调用测试集,因此我们在 kaggle 比赛使用需要对 Pseudo-label 的方法做些修改。

Step 1:将有标签部分数据分为两份:训练集和测试集,并训练出最优的 model1

Step 2:用训练好的 model 1 对一批未标记图像(测试集)进行预测,制作伪标签的过程中可以使用 Noisy Student 的方法,即通过图像翻折、旋转、缩放等对图像进行扩增,以此提升我们制作的伪标签的准确度,然后对预测的标签进行筛选选择大于预测阈值的标签作为伪标签。

Step 3:最后将有标签的数据(训练集)和伪标签的数据(测试集)一起进行 finetune model 1,通过验证集选取 best model。

3.2 阈值选取

在目标检测任务中使用 Pseudo-label 方法的关键在于如何设置好预测阈值,由于一张图片当中具有多个目标,如果只是选择预测概率较高的结果作为标签,那么一张图中就会有许多目标就没有被标记出来被当作负样本。

这样子制作的标签假负例(FN)过多,但是阈值也不能偏低太低的话会引入一些错误的假正例(FP)所以目标检测任务中的预测概率阈值成为伪标签制作的一个关键,不能太高但同时也不能太低(太低的话会引入一些错误的标签)。

在比赛我得到的关于阈值选取的经验是,当图像中目标较多的情况下选取的阈值应该要小一些这样可以避免较多的假负例,反之在目标少的情况选择的阈值应大一些,还有一个比较有效的方法是利用在训练集上训练好的模型通过滑动阈值(自动逐个尝试)先搜索出模型在验证集上取得较好效果的预测阈值,再通过微调这个阈值测试出最适合制作伪标签的阈值。

在使用 Semi-Supervised Learning 成绩为: 0.7720 ,没使用是 0.7522,增加了 0.0198,效果可以说是相当的明显了,排名提升了一百多名。

结论

可以看到,不论是小数据集,还是大数据集,Self training 都是一种有效的涨点方法,尤其是,在像 Kaggle 这样的比赛中,相信这项技术是很有用的,因为通常即使是轻微的分数提高也能让你在排行榜上得到提升。

更多阅读

#投 稿 通 道#

 让你的论文被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。

???? 来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志

???? 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site

• 所有文章配图,请单独在附件中发送

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

Self-training在目标检测任务上的实践相关推荐

  1. MobileFormer-在目标检测任务上怒涨8.6 AP,微软新作MobileFormer

    关注公众号,发现CV技术之美 0 写在前面 在本文中,作者提出了一个并行设计的双向连接MobileNet和Transformer的结构Mobile-Former.这种结构利用了MobileNet在局部 ...

  2. 激光雷达目标检测 (上)

    激光雷达目标检测 (上) **---- 转载自美团无人专送团队** 简介 安全性是自动驾驶中人们最关注的问题之一. 在算法层面,无人车对周围环境的准确感知是保证安全的基础,因此感知算法的精度十分重要. ...

  3. 【转载智车科技公众号(微信)】目标检测综述上(上)

    近年来,深度学习模型逐渐取代传统机器视觉方法而成为目标检测领域的主流算法,本系列文章将回顾早期的经典工作,并对较新的趋势做一个全景式的介绍,帮助读者对这一领域建立基本的认识. " 作者 | ...

  4. 图像语义分割和目标检测(上)

    语义分割是对图像在像素级别上进行分类的方法,在一张图像中,属于同一类的像素点都要被预测为相同的类,因此语义分割是从像素级别来理解图像.但是需要正确区分语义分割和实例分割,虽然他们在名称上很相似,但是他 ...

  5. 图像检测:目标检测(上)

    目标检测 检测图片中所有物体的类别标签位置 与其他任务的区别 区域卷积神经网络R-CNN 模型结构 : 按分类问题对待.模块一:提取物体区域(Region Proposal),不同位置,不同尺寸,数量 ...

  6. PaddleDetection研究报告——百度目标检测PP-YOLOE论文解读+实践应用

    最新发布 PP-YOLOE+,最高精度提升2.4% mAP,达到54.9% mAP,模型训练收敛速度提升3.75倍,端到端预测速度最高提升2.3倍:多个下游任务泛化性提升. PicoDet-NPU模型 ...

  7. 国外AI工程师讲述:深度学习与目标检测,理论和实践果然两码事

    背景故事 2018 年,当时我在工厂实习,我开始研究目标检测技术,因为我需要解决视觉检测问题. 这个问题需要在来自工业相机的图像流中检测许多不同的物体目标. 为了应对这一挑战,我首先尝试将分类与滑窗法 ...

  8. 【论文阅读】【三维目标检测】在Range view上做3D目标检测

    文章目录 BEV or Range View RangeDet: In Defense of Range View for LiDAR-based 3D Object Detection Range ...

  9. 学界 | 斯坦福提出高速视频目标检测系统NoScope:速度超现有CNN上千倍

    卷积神经网络在目标检测任务上已经取得了优良的表现,但它们的计算成本比较高.速度比较慢,不适用于大规模的实时视频处理.为了解决这个问题,斯坦福大学的几位研究者提出了一个名叫 NoScope 的系统,将目 ...

最新文章

  1. github上fork了别人的项目后,再同步更新别人的提交
  2. labview 随笔记录
  3. 浅谈ASP.NET内部机制(五)
  4. git常用命令之stash
  5. WSL端口映射到win
  6. 提升对前端的认知,不得不了解Web API的DOM和BOM
  7. Dart基础第8篇:函数、箭头函数 匿名函数 闭包等
  8. python之路——作业:Select FTP(仅供参考)
  9. XCodeGhost
  10. java ojdbc14.jar_ojdbc14_g.jar
  11. python工资条教程_批量发工资怎么操作_利用python轻松解决用邮箱批量发工资条...
  12. java语言c语言表情包_c语言表情包 - c语言微信表情包 - c语言QQ表情包 - 发表情 fabiaoqing.com...
  13. 中国到美国最安全的飞机航线
  14. OpenCV图像轮廓提取
  15. Android OTG U盘相关
  16. MAYA oceanShader/海洋(纹理)
  17. 7.消费者的确认机制
  18. java生成随机数的方法_Java获取随机数的3种方法
  19. 游戏乱码解决软件 NTLEA
  20. [开启C语言秃头之旅]扫雷游戏

热门文章

  1. ca证书 linux 导入_Linux CA证书服务器搭建
  2. 浅谈网页中的字体的设置
  3. css样式命名规则(仅供参考)
  4. 启动Tomcat的小细节--MyEclipse
  5. Python连接SQL Server数据库 - pymssql使用基础
  6. UE3 ExampleGame Android版无法运行解决方案
  7. 走在程序世界道路上的我___大一篇
  8. 第十节 字符串指针变量与字符数组的区别(十一)
  9. GetCurrentDirectory和SetCurrentDirectory函数
  10. html樱花飘落代码_爱心飘落特效