来自 | 知乎

地址 | https://www.zhihu.com/question/50519680

编辑 | 机器学习算法与自然语言处理公众号

本文仅作学术分享,若侵权,请联系后台删文处理

如何理解soft target这一做法?

题主最近在研究如何压缩一个ensemble模型。读了hinton的distill dark knowledge的文章,微软出的通过dnn提取rnn的文章,发现通过soft target来利用小模型提取大ensemble里的信息似乎是个蛮靠谱的想法。但是有点想法没弄清。主要问题是:

1. 小模型可以得到原ensemble的泛化性能。但题主觉得因为原ensemble可以预测特征空间中任意点的target,所以进行学习的小模型可以获得更多的训练样本,进而得到和ensemble一样的泛化性能。但hinton文章中说似乎通过soft target训练的时候,哪怕给后来的模型一小部分数据(3%),也可以获得一样的泛化性能。同时Li的文章用了大量未标记的样本,通过ensemble预测logit然后训练,效果反而一般般。所以似乎是我的想法走偏了,但为什么呢。。

2. 为什么softtarget效果好,还有没有别的解释?文章看下来总感觉并没有那么convincing。不知其所以然。。
谢谢大家了~

----------------题主的理解是------------

训练时只是令大小网络在一小部分数据点上输出的函数值相匹配。那么可能的理解是小模型自带防过拟合功能,拟合出来的函数比较平滑所以也自带泛化性能。如果这是对的,那么用和原先一样的大模型来distill会出现过拟合。这就是为什么不会用大模型来transfer knowledge的原因么?

最近看起来soft target和label smoothing有着千丝万缕的关系。但一个是sample wise,一个是label wise。


周博磊

机器学习、深度学习(Deep Learning)、人工智能 话题的优秀回答者

https://www.zhihu.com/question/50519680/answer/136359743

我觉得可以从大模型对原始的标定空间进行data augmentation来理解hinton的dark knowledge。

图片样本的标定一般都是离散的,就是一张图只有一个类别的标定。但是很多时候类别之间有很大的相似性,这个相似性信息并没有被离散的标定体现出来。比如说一个图片分类数据库有猫,狗,自行车,卡车,汽车等类别,图片的标定只给了这张图A是狗,这张图B是卡车,而并没有给出图A里面的狗会更像猫,图B里面的卡车会更像汽车。

大模型通过传统的cross entropy loss,训练过程中其实是可以把训练数据里面类别之间的关联性信息学习出来。比如说在testing的时候,给张狗的图片,模型的output probability可能是狗0.9,猫0.09,卡车0.005,汽车0.005,图片更容易被错分到猫那一类。knowledge distillation的大致做法是把所有训练样本都再feeforward一遍训练好的大模型,那么原本图片样本的离散标定,就变为一个大模型预测出来的类别probability,这个probability更好地表征了类与类之间的相似性。其实这可以看成是利用训练好的大模型对原始的标定空间进行了一次data augmentation(不过这个augmentation是发生在label space,而不是图片本身)。这里论文里还提了对probability降低temperature等trick,其实是想使得类与类之间的关联信息更明显。

小模型因为有了大模型帮忙提取出的标定空间的更多关联信息,所以能更好的进行学习。可以预见,小模型的performance并不会超过大模型,但是会比原来在离散标定空间上面训练的效果好。

题外话:我必须得再感慨一下geoffrey hinton的超人洞察力。dark knowledge,dropout等简单有效的trick都是神来之笔。


Naiyan Wang

深度学习(Deep Learning)、机器学习、人工智能 话题的优秀回答者

https://www.zhihu.com/question/50519680/answer/136363665

这个问题真的很有意思,我也曾经花了很多时间思考这个问题,可以有一些思路可以和大家分享。

一句话简言之,Knowledge Distill是一种简单弥补分类问题监督信号不足的办法。

传统的分类问题,模型的目标是将输入的特征映射到输出空间的一个点上,例如在著名的Imagenet比赛中,就是要将所有可能的输入图片映射到输出空间的1000个点上。这么做的话这1000个点中的每一个点是一个one hot编码的类别信息。这样一个label能提供的监督信息只有log(class)这么多bit。然而在KD中,我们可以使用teacher model对于每个样本输出一个连续的label分布,这样可以利用的监督信息就远比one hot的多了。

另外一个角度的理解,大家可以想象如果只有label这样的一个目标的话,那么这个模型的目标就是把训练样本中每一类的样本强制映射到同一个点上,这样其实对于训练很有帮助的类内variance和类间distance就损失掉了。然而使用teacher model的输出可以恢复出这方面的信息。具体的举例就像是paper中讲的, 猫和狗的距离比猫和桌子要近,同时如果一个动物确实长得像猫又像狗,那么它是可以给两类都提供监督。

综上所述,KD的核心思想在于"打散"原来压缩到了一个点的监督信息,让student模型的输出尽量match teacher模型的输出分布。其实要达到这个目标其实不一定使用teacher model,在数据标注或者采集的时候本身保留的不确定信息也可以帮助模型的训练。

当然KD本身还有很多局限,比如当类别少的时候效果就不太显著,对于非分类问题也不适用。我们目前有一些尝试试图突破这些局限性,提出一种通用的Knowledge Transfer的方案。希望一切顺利的话,可以早日成文和大家交流。:-)


YJango

机器学习、深度学习(Deep Learning) 话题的优秀回答者

https://www.zhihu.com/question/50519680/answer/136406661

给没用过该方法的朋友简单介绍:
这是一种深层神经网络的训练方法。

一、什么是distillation (或者用Hinton的话说,dark knowledge)(论文Distilling the Knowledge in a Neural Network)

1、训练大模型:先用hard target,也就是正常的label训练大模型。
2、计算soft target:利用训练好的大模型来计算soft target。也就是大模型“软化后”再经过softmax的output。
3、训练小模型,在小模型的基础上再加一个额外的soft target的loss function,通过lambda来调节两个loss functions的比重。
4、预测时,将训练好的小模型按常规方式(右图)使用。

二、什么是generalized distillation(privileged information和distillation的结合)(论文:Unifying distillation and privileged information)

区别仅在于大模型不再由同小模型一样的vector训练,而拥有额外信息的features。这时大模型叫做老师模型,小模型叫做学生模型,两个模型的大小不再是重点。重点是训练模型的input vector所含有的信息量。

三、为何要软化,为何好用的初步解释
由于我发过基于generalized distillation的application论文(Articulatory and spectrum features integration using generalized distillation framework),这里先给一个我用过的ppt的一页。(ppt其中提到的论文是Recurrent Neural Network Training with Dark Knowledge Transfer)

信息量:
hard target 包含的信息量(信息熵)很低,
soft target包含的信息量大,拥有不同类之间关系的信息(比如同时分类驴和马的时候,尽管某张图片是马,但是soft target就不会像hard target 那样只有马的index处的值为1,其余为0,而是在驴的部分也会有概率。)
软化:
问题是像左图的红色0.001这部分,在cross entropy的loss function中对于权重的更新贡献微乎其微,这样就起不到作用。把soft target软化(整体除以一个数值后再softmax),就可以达到右侧绿色的0.1这个数值。

这张ppt是一般的解释。如果想了解更深刻的话,请看深层神经网络 · 超智能体关于深层为什么比浅层有效的解释,从中找到为什么要one hot vector编码,然而接着往下看。

四、进一步解释
我假设你已经看完了深层神经网络 · 超智能体关于深层为什么比浅层有效的解释,并找到了为什么要one hot vector编码

那你应该看到了很多关于学习的关键。拿其中两句来说:学习是将变体拆成因素附带关系的过程;变体(variation)越少,拟合难度越低,熵越低。
在神经网络的隐藏层中可以通过加入更多隐藏层节点来拟合变体,降低熵和拟合难度。

然而有两个瓶颈,就是输入层和输出层是由数据限制的。输入层的瓶颈也可以来靠稀疏表示来弥补一些(是神经网络之外的额外一层)。而输出层的瓶颈就可以靠soft target来弥补。

原因在于即便是输出层也是有变体的。如何消除这一层的变体?
举个例子,如果要分类10个不同物体的图片,完全可以只输出一个节点,用一维输出来表示。得到1就是第一个物体,得到9就是第9个物体。但是这样做的话,输出层的这一个节点就有10个变体,会使建模时的熵增加。但是若用10个节点来分别表示,那么每个节点就只有一个可能的取值(也就是没有变体),熵就可以降到最低。可是问题又来了,如何确保这10个类就是所有变体?该问题是和caffe 每个样本对应多个label?的问题是类似的,并没有把所有变体都摊开。

如何解决这种分类问题:
五分类问题
样本1: 2,3(样本1属于第二类和第三类)
样本2: 1,3,4(样本2属于第一类,第三类和第四类)

既属于第二类又属于第三类的样本就和既像驴又像马的样本图片一样。
一种解决方式是重新编码,比如把属于2,3类的数据对应一个新的label(该label表示既是2又是3)。

把所有变体全部摊开,比如有n个类,输出层的节点个数会变成2^n-1(-1是排除哪一类都不是的情况)。
soft target也有类似的功效,就是Naiyan Wang的答案中提到的“一句话简言之,Knowledge Distill是一种简单弥补分类问题监督信号不足的办法”。在上述的问题中仅仅是“是 or 不是”,所以编码的基底为2,形成2^n。soft target的基底就不是2了,而是0-1的实数个,还要保证所有类的总和为1(我数学底子差,不知道该怎么计算这个个数)。

soft target是加入了关于想要拟合的mapping的prior knowledge,降低了神经网络的搜索空间,从而获得了泛化(generalization)能力。
加入prior knowledge也是设计神经网络的核心

五、实验现象
soft target的作用在于generalization。同dropout、L2 regularization、pre-train有相同作用。
这其实也是我要在知乎Live深层学习入门误区想要分享内容的一部分。
简单说
dropout是阻碍神经网络学习过多训练集pattern的方法
L2 regularization是强制让神经网络的所有节点均摊变体的方法。
pretrain和soft target的方式比较接近,是加入prior knowledge,降低搜索空间的方法。

证据:
1,Recurrent Neural Network Training with Dark Knowledge Transfer中还测试了用pretrain soft target的方式来加入prior knowledge,效果还比distillation的方式好一点点(毕竟distillation会过分限制网络权重的更新走向)

2,我才投的期刊论文有做过两者的比较。当没有pretrain的时候用distillation,效果会提升。但是当已经应用了pretrain后还用distillation,效果基本不变(有稍微稍微的变差)

3,Distilling the Knowledge in a Neural Network我记得好像并没有提到是否有用dropout和L2也来训练small model。我的实验结果是:
如果用了dropout和L2,同样feature vector的soft target来训练的distillation不会给小模型什么提升(往往不如不用。有额外feature vector的soft target的会提升)。

如果不用dropout和L2,而单用distillation,那么就会达到dropout和L2类似的泛化能力,些许不如dropout。
也就是说dropout和L2还有pretrain还有distillation其实对网络都有相同功效。是有提升上限的。Unifying distillation and privileged information论文是没有用dropout和L2,我看过他的实验代码。如果用了dropout和L2,提升不会像图表显示的那么大。

4,你注意Hinton的实验结果所显示提升也是十分有限。压缩网络的前提是神经网络中的节点是由冗余的,而关于压缩神经网络的研究,我建议是从每层中权重W的分析入手。有很多关于这方面的论文。

如何理解soft target这一做法?相关推荐

  1. label smooth标签平滑的理解

    今天我们来聊一聊label smooth这个tricks,标签平滑已经成为众所周知的机器学习或者说深度学习的正则化技巧.标签平滑--label smooth regularization作为一种简单的 ...

  2. 省内读大学与省外读大学的区别?看完扎心了…

    大一上学期过半,你在省内读书还是走出省了呢?省内党羡慕省外读书的自由,省外党则羡慕省内读书的同学可以常回家看看? 那么究竟是省内读大学好,还是省外读大学好呢?先来听听看省外党怎么说: @李哲 从此之后 ...

  3. 基于matlab的fisher线性判别及感知器判别_Deep Domain Adaptation论文集(一):基于label迁移知识...

    本系列简单梳理一下<Deep Visual Domain Adaptation: A Survey>这篇综述文章的内容,囊括了现在用深度网络做领域自适应DA(Domain Adaptati ...

  4. 论文笔记:Distilling the Knowledge

    原文:Distilling the Knowledge in a Neural Network Distilling the Knowledge 1.四个问题 要解决什么问题? 神经网络压缩. 我们都 ...

  5. Face Model Compression by Distilling Knowledge from Neurons 论文理解

    引入 一. 背景 为保证人脸识别技术的精度要求,需要大而复杂的单个或者组合的深度神经网络实现. 该技术欲迁移至移动终端与嵌入式设备中. 二. 解决方法 运用模型压缩技术,用小的网络去拟合大量数据.大型 ...

  6. 理解注意力机制的好文二

    注意力模型最近几年在深度学习各个领域被广泛使用,无论是图像处理.语音识别还是自然语言处理的各种不同类型的任务中,都很容易遇到注意力模型的身影.所以,了解注意力机制的工作原理对于关注深度学习技术发展的技 ...

  7. Soft Labels for Ordinal Regression阅读笔记

    Soft Labels for Ordinal Regression CVPR-2019 Abstract 提出了简单有效的方法约束类别之间的关系(其实就是在输入的label中考虑到类别之间的顺序关系 ...

  8. soft attention and self attention

    注意力模型最近几年在深度学习各个领域被广泛使用,无论是图像处理.语音识别还是自然语言处理的各种不同类型的任务中,都很容易遇到注意力模型的身影.所以,了解注意力机制的工作原理对于关注深度学习技术发展的技 ...

  9. 深入理解attention机制

    深入理解attention机制 1. 前言 2. attention机制的产生 3. attention机制的发展 4. attention机制的原理 5. attention的应用 参考文献 1. ...

最新文章

  1. java date 格式化_Date类日期格式化
  2. java实现插入排序算法 附单元测试源码
  3. Oracle 表的创建 及相关參数
  4. 华为鸿蒙osbeta发布会,华为鸿蒙 OS Beta 3
  5. linux寻找依赖文件
  6. 96309245通讯异常工行_工商银行信息代码 96309245 是什么意思
  7. sata和sas硬盘Linux,SAS硬盘和SATA硬盘最大的区别是什么?
  8. 「hdu6638」Snowy Smile【稀疏矩阵最大子矩阵和】
  9. python中iter是什么意思_Python __iter__ 深入理解
  10. 终于找到了PyMuPDF不能提取文字的原因……它只是个包装
  11. Windows10搭建turn服务器
  12. 2021-09-10 LeetCode1894-找到需要补充粉笔的学生编号(每日一题)
  13. Uboot启动logo修改
  14. Nodejs+socket.io 搭建个人的网页聊天室
  15. Java后端开发面试7大核心总结,为你保驾护航金九银十!
  16. 校园导游咨询系统(数据结构课程设计)
  17. 联想m415节能产品认证证书_节能认证
  18. vscode 单击跳转_vscode中ctrl+鼠标左键不能跳转
  19. 制作每日疫情通报省份地图
  20. 怎么使用计算机播放音乐,怎么用Apple Watch控制电脑播放音乐?

热门文章

  1. android监听器在哪里创建,[转载]android开发中创建按钮事件监听器的几种方法
  2. 对比四种爬虫定位元素方法,你更爱哪个?
  3. RISC-V 正在成为芯片世界中的 Linux
  4. Python中的元编程:一个关于修饰器和元类的简单教程
  5. 从事JAVA 20年最终却败给了Python,哭了!
  6. 谷歌、阿里们的杀手锏:三大领域,十大深度学习CTR模型演化图谱
  7. Python超越Java,Rust持续称王!Stack Overflow 2019开发者报告
  8. 2018最后一个月的Python热文Top10!赶紧学起来~
  9. 数据科学家必须要掌握的5种聚类算法
  10. 监管AI?吴恩达跟马斯克想到一块去了