Distilling the Knowledge in a Neural Network

这篇介绍一下Hinton大神在15年做的一个黑科技技术,Hinton在一些报告中称之为Dark Knowledge,技术上一般叫做知识蒸馏(Knowledge Distillation)。核心思想是通过迁移知识,从而通过训练好的大模型得到更加适合推理的小模型。这个概念最早在06年的Paper: Model Compression中, Caruana提出一种将大模型学习到的函数压缩进更小更快的模型,而获得可以匹敌大模型结果的方法。

重点idea就是提出用soft target来辅助hard target一起训练,而soft target来自于大模型的预测输出。这里有人会问,明明true label(hard target)是完全正确的,为什么还要soft target呢?

hard target 包含的信息量(信息熵)很低,soft target包含的信息量大,拥有不同类之间关系的信息(比如同时分类驴和马的时候,尽管某张图片是马,但是soft target就不会像hard target 那样只有马的index处的值为1,其余为0,而是在驴的部分也会有概率。)[5]

这样的好处是,这个图像可能更像驴,而不会去像汽车或者狗之类的,而这样的soft信息存在于概率中,以及label之间的高低相似性都存在于soft target中。但是如果soft targe是像这样的信息[0.98 0.01 0.01],就意义不大了,所以需要在softmax中增加温度参数T(这个设置在最终训练完之后的推理中是不需要的)

qi=exp(zi/T)Σjexp(zj/T)q_i=\frac{exp(z_i/T)}{\Sigma_jexp(z_j/T)}qi​=Σj​exp(zj​/T)exp(zi​/T)​

Loss是两者的结合,Hinton认为,最好的训练目标函数就是这样,并且第一个目标函数的权重要大一些。
L=αL(soft)+(1−α)L(hard)L = \alpha L^{(soft)}+(1-\alpha)L^{(hard)}L=αL(soft)+(1−α)L(hard)

算法示意图如下[5]:

1、训练大模型:先用hard target,也就是正常的label训练大模型。
2、计算soft target:利用训练好的大模型来计算soft target。也就是大模型“软化后”再经过softmax的output。
3、训练小模型,在小模型的基础上再加一个额外的soft target的loss function,通过lambda来调节两个loss functions的比重。
4、预测时,将训练好的小模型按常规方式(右图)使用。

在线蒸馏 codistillation

在分布式训练任务下,提出了一种替代标准SGD训练NN模型的方法codistillation,是一种非同步的算法,事实上有很多个Wieght的副本在独立训练,他可以有效“解决”机器增加但线性度不增加的问题,实验中还有一些数据表面可以比标准的SGD收敛更快。

也是distill的思想,但是因为是重头训练,所以什么是teacher model呢?作者提出用所有模型的预测平均作为teacher model,然后作为soft target来训练每一个模型。

在这篇论文中,使用codistillation来指代执行的distillation:

  • 所有模型使用相同的架构;
  • 使用相同的数据集来训练所有模型;
  • 任何模型完全收敛之前使用训练期间的distillation loss。

算法原理如下:


distillation loss作者提到了可以是平方距离,或者是KL-divergence,但是作者采用的是cross entropy。

分为两个阶段,第一个是独立的SGD更新阶段,这个阶段是不需要去同步的,因此非常高效。第二阶段是codistill阶段,就是用每一个model的平均预测结果来作为soft target训练每一个独立的model,有趣的是作者说这样会训练出来一堆不同的model(只要求他们表现接近,并不能强求他们的Weight一样),但是这些model在loss上没什么区别。、

作者也表明,实际上可以和标准的同步SGD来结合,也就是分组——组内用同步SGD训练model副本,然后组间用codistill来训练(只交换预测结果,非常小)。另外,作者表示虽然他们在仿真实现上是传输了所有的模型副本到所有node上,为了得到soft预测结果,但实际上可以只传输预测结果即可(我猜可能是框架支持不方便?)

实验结果:

重点看下Imagenet,16K Batchsize用两路训练比用一路训练收敛快。

参考资料

[1] Hinton胶囊网络后最新研究:用“在线蒸馏”训练大规模分布式神经网络
[2] G.Hinton Dark Knowledge
[3] 2006, Model Compression
[4] 如何让你的深度神经网络跑得更快
[5] 如何理解soft target这一做法?

深度学习方法(十五):知识蒸馏(Distilling the Knowledge in a Neural Network),在线蒸馏相关推荐

  1. Paper:《Distilling the Knowledge in a Neural Network神经网络中的知识蒸馏》翻译与解读

    Paper:<Distilling the Knowledge in a Neural Network神经网络中的知识蒸馏>翻译与解读 目录 <Distilling the Know ...

  2. Distilling the Knowledge in a Neural Network 论文笔记蒸馏

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/bryant_meng/article/ ...

  3. 【Distilling】《Distilling the Knowledge in a Neural Network》

    arXiv-2015 In NIPS Deep Learning Workshop, 2014 文章目录 1 Background and Motivation 2 Conceptual block ...

  4. 《Distilling the Knowledge in a Neural Network》 论文阅读笔记

    原文链接:https://arxiv.org/abs/1503.02531   第一次接触这篇文章是在做网络结构的时候,对于神经网络加速,知识蒸馏也算是一种方法,当时连同剪纸等都是网络压缩的内容,觉得 ...

  5. Distilling the Knowledge in a Neural Network阅读笔记

    文章目录 Abstract Introduction Distillation Preliminary experiments on MNIST Experiments on speech recog ...

  6. 【论文翻译_知识蒸馏】Distilling Holistic Knowledge with Graph Neural Networks

    (以下的"提取"都可以替换为"蒸馏"),收录于ICCV2021 摘要 知识提炼(KD)旨在将知识从一个更大的优化教师网络转移到一个更小的可学习学生网络.现有的知 ...

  7. 推荐系统遇上深度学习(十五)--强化学习在京东推荐中的探索

    强化学习在各个公司的推荐系统中已经有过探索,包括阿里.京东等.之前在美团做过的一个引导语推荐项目,背后也是基于强化学习算法.本文,我们先来看一下强化学习是如何在京东推荐中进行探索的. 本文来自于pap ...

  8. 深度解析(十五)哈夫曼树

    哈夫曼树(一)之 C语言详解 本章介绍哈夫曼树.和以往一样,本文会先对哈夫曼树的理论知识进行简单介绍,然后给出C语言的实现.后续再分别给出C++和Java版本的实现:实现的语言虽不同,但是原理如出一辙 ...

  9. 深度学习方法(五):卷积神经网络CNN经典模型整理Lenet,Alexnet,Googlenet,VGG,Deep Residual Learning

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld.  技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 关于卷积神经网络CNN,网络和文献 ...

  10. 深度学习方法(五):卷积神经网络CNN经典模型整理Lenet,Alexnet,Googlenet,VGG,Deep Residual Learning...

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 关于卷积神经网络CNN,网络和文献中 ...

最新文章

  1. Linux学习-Xshell断开连接程序依然运行
  2. Forward+ Shading架构
  3. python爬虫采集网站数据
  4. 编译器构造概述(详细)
  5. 【遥感数字图像处理】基础知识:第二章 遥感知识回顾、遥感数字图像处理基础知识
  6. 学习easyui疑问(二)
  7. 使用wwise音效引擎的好处
  8. toStringequals方法
  9. MySQLl数据量不一样,导致走不同的索引
  10. python编程语言一览_编程语言大汇总(Part Ⅰ)
  11. 用电机进行简单的PID参数整定
  12. 服务器硬盘容量为0,硬盘容量不一样 raid0 扩容也可以很自如?
  13. 简单说说路由器和交换机的区别
  14. 无线桥连后不能访问服务器,路由器设置无线桥接后不能登录副路由器怎么办?...
  15. 哥尼斯堡的“七桥问题”(C++)
  16. oracle asm omf,Oracle Managed Files,OMF
  17. Tensorflow小技巧整理:修改张量特定元素的值
  18. CESM模式及其各个分量模式介绍
  19. 彻底征服 React.js + Flux + Redux【讲师辅导】-曾亮-专题视频课程
  20. PTA 使我精神焕发

热门文章

  1. ODL(C版本)安装过程
  2. python画彩虹圈_javascript – 如何使用HTML5画布生成彩虹圈?
  3. 让笔记本的无线网卡指示灯不再狂闪的方法
  4. 常用GIS(高清卫星影像、DEM)数据下载
  5. EXEL表格读取 按键精灵
  6. JavaScript(JS) date.getDate()
  7. NPDP第七章:产品生命周期管理
  8. 人类学家胡家奇谈科技发展:让它回归理性
  9. URAL 1389 Roadworks 贪心
  10. html版贪吃蛇的项目计划书,自动贪吃蛇.html