在传统的机器学习中,为了获得最先进的(SOTA)性能,我们经常训练一系列整合模型来克服单个模型的弱点。 但是,要获得SOTA性能,通常需要使用具有数百万个参数的大型模型进行大量计算。 SOTA模型(例如VGG16 / 19,ResNet50)分别具有138+百万和23+百万个参数。 在边缘设备部署这些模型是不可行的。

智能手机和IoT传感器等边缘设备是资源受限的设备,无法在不影响设备性能的情况下进行训练或实时推断。 因此,研究集中在将大型模型压缩为小型紧凑的模型,将其部署在边缘设备时性能损失最小至零。

以下是一些可用的模型压缩技术,尽管它不限于以下内容:

· 修剪和量化

· 低阶分解

· 神经网络架构搜索(NAS)

· 知识蒸馏

在这篇文章中,重点将放在[1]提出的知识蒸馏上,参考链接[2]提供了上面列出的模型压缩技术列表的详尽概述。

知识蒸馏

知识蒸馏是利用从一个大型模型或模型集合中提取的知识来训练一个紧凑的神经网络。利用这些知识,我们可以在不严重影响紧凑模型性能的情况下,有效地训练小型紧凑模型。

大、小模型

我们称大模型或模型集合为繁琐模型或教师网络,而称小而紧凑的模型为学生网络。

一个简单的类比是,一个大脑小巧紧凑的学生为了考试而学习,他试图从老师那里吸收尽可能多的信息。然而老师只是教所有的东西,学生不知道在考试中会出哪些问题,尽力吸收所有的东西。

在这里,压缩是通过将知识从教师中提取到学生中而进行的。

在提取知识之前,繁琐的模型或教师网络应达到SOTA性能,此模型由于其存储数据的能力而通常过拟合。 尽管过拟合,但繁琐的模型也应该很好地推广到新数据。 繁琐模型的目的是使正确类别的平均对数概率最大化。 较可能正确的类别将被分配较高的概率得分,而错误的类别将被赋予较低的概率。

下面的示例显示了在给定的"鹿"图像上进行推理时以及softmax之后的结果。下图。要获得预测,我们采用最大类概率评分的argmax,这将使我们有60%的机会是正确的。

然而,鉴于上面的图2。(为了说明的目的),我们知道与"船"相比,"马"与"鹿"非常相似。因此,在推断过程中,我们有60%是正确的,39%是错误的。由于"鹿"与"马"之间存在一定的空间相似性,因此网络预测"马"的准确性是不容置疑的。如果在网络中提供"我认为这幅图60%是鹿,39%是马"的信息,如[deer: 0.6, horse: 0.39, ship: 0.01],那么网络就会提供更多的信息(高熵)。使用类概率作为目标类比仅仅使用原始目标提供了更多的信息。

蒸馏

教师将预测类别概率的知识提取给学生作为"软目标"。这些数据集又称为"转移集",其目标为教师提供的类别概率,如上图所示。蒸馏过程是通过在softmax函数中引入一个超参数T(温度)来进行的,这样教师模型就可以为学生模型生成一个适当的传递集目标的软目标集合。

软目标擅长帮助模型泛化,并且可以充当正则化函数来防止模型过于自信。

训练教师和学生模型

首先,我们训练繁琐/教师模型,因为我们要求繁琐的模型很好地归纳为新数据。 在蒸馏过程中,学生模型目标函数是两个不同目标函数Loss1和Loss2的加权平均值。

loss1 软目标的交叉熵损失

温度T > 1乘以权重参数alpha的教师q和学生p的两个温度softmax的交叉熵损失(CE)。

loss2 硬目标的交叉熵损失

正确标签和T = 1的学生硬目标的交叉熵(CE)损失。Loss2很少注意(1- alpha)学生模型为匹配软目标而制定的硬目标(student_pred) q来自教师模型。

学生模型的目标是蒸馏损失,它是Loss1和Loss2之和。

然后在训练学生模型时,以最大程度地减少其蒸馏损失。

结果

MNIST实验

下表1是论文[1]的结果,该论文显示了使用在MNIST数据集上训练了60,000个训练案例的教师、学生和提炼模型的性能。 所有模型都是两层神经网络,分别具有1200、800和800个神经元,分别用于教师,学生和提炼模型。 当使用精简模型与学生模型进行比较时,温度设置为20时,教师和精简模型之间的测试误差相当。但是,仅使用具有硬目标的学生模型时,其推广性就变的很差。

语音识别实验

下表2是论文[1]的另一个结果。 教师模型是由85M参数组成的语音模型,该参数是根据2000个小时的英语口语数据进行训练的,其中包含大约700M的训练示例。 表2中的第一行是在100%的训练示例上训练的基线模型,其准确性为58.9%。 第二行仅使用3%的训练示例进行训练,这会导致严重的过度拟合。最后,第三行是用3%的训练样本用同样的3%的软目标训练得到的同样的语音模型,只用3%的训练数据就可以达到57%的准确率。

结论

知识蒸馏是一种用于将计算带到边缘设备的模型压缩技术。 目标是拥有一个紧凑的小型模型来模仿繁琐模型的性能。 这是通过使用软目标来实现的,这些目标充当正则化器,以允许小型紧凑的学生模型泛化并从教师模型中恢复几乎所有信息。

根据Statista[3]的数据,到2025年,联网设备的安装总数预计将达到215亿。随着大量的边缘设备的出现,为边缘设备带来计算是使边缘设备更智能的一个日益增长的挑战。知识蒸馏允许我们执行模型压缩而不影响性能的边缘设备。

引用

[1] Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. "Distilling the knowledge in a neural network." arXiv preprint arXiv:1503.02531 (2015).

[2] An overview of model compression techniques for deep learning in space

[3] IoT number of connected devices worldwide

作者:Kelvin

deephub翻译组

睡眠 应该用 a加权 c加权_在神经网络中提取知识:学习用较小的模型学得更好...相关推荐

  1. 神经网络 mse一直不变_卷积神经网络中十大拍案叫绝的操作

    公众号关注 "DL-CVer" 设为 "星标",DLCV消息即可送达! 来自 | 知乎作者丨Justin ho来源丨https://zhuanlan.zhihu ...

  2. python 加载动图_在浏览器中使用TensorFlow.js和Python构建机器学习模型(附代码)...

    大数据文摘授权转载自数据派THU 作者:MOHD SANAD ZAKI RIZVI 本文主要介绍了: TensorFlow.js (deeplearn.js)使我们能够在浏览器中构建机器学习和深度学习 ...

  3. 提取图像感兴趣区域_从图像中提取感兴趣区域

    提取图像感兴趣区域 Welcome to the second post in this series where we talk about extracting regions of intere ...

  4. python中的连续比较是什么_在python中提取连续行之间的差异

    你的例子表明你想要在一对线之间进行比较.这与将其定义为line(n-1)-line(n)不同,后者将给出5个结果,而不是3个.在 结果也取决于你认为的差异.它是位置性的,还是仅仅基于奇数行中缺失的字母 ...

  5. java 克隆的作用_关于java中克隆的学习(一)

    java中的克隆,就是要复制对象,但为什么要用克隆呢?我们直接把对象赋值给其它同类型的实例不就行了吗?这就要从java的值传递和引用传递说起了. package dcr.study.test.poin ...

  6. .net 数字转汉字_[原创工具] 小熊汉字笔顺学习软件,查笔顺、学拼音、制作汉字英文数字字贴...

    点击右上角"设为星标"每日精彩内容,第一时间送达! 前言 今天带来的是原创软件.家里有上一二年级的小朋友有福了!家里有打印机的可以把设置好的字帖打印出来,小朋友即可临摹.赶紧下载使 ...

  7. 机器学习算法_机器学习算法中分类知识总结!

    ↑↑↑关注后"星标"Datawhale每日干货 & 每月组队学习,不错过Datawhale干货 译者:张峰,Datawhale成员 本文将介绍机器学习算法中非常重要的知识- ...

  8. access查询出生日期格式转换_从身份证中提取出生日期的3个方法和计算年龄和星座的方法...

    在我们日常的工作当中,经常会遇到通过身份证来获取出生年月日的需求,今天就给大家介绍三种可以从身份证中提取出生年月日的方法. 我们都知道身份证不同的区域是有不同的含义的,代表出生年月日的数字是第7位到第 ...

  9. java中的io复用_从 Java 中的零拷贝到五种IO模型

    在之前的文章中,我们聊过了 Java 中的零拷贝,零拷贝就是指数据不会在内核空间和用户空间之间相互拷贝.这样就减少了内核态与用户态的切换,自然就很高效. 拷贝文件只是 IO 操作中一个特殊的情况,大多 ...

最新文章

  1. Linux 黑话解释:什么是定时任务
  2. 关于sendmail报错“did not issue MAIL/EXPN/VRFY/ETRN during connection to
  3. Neural Networks神经网络编程入门
  4. 四十七、第二份国外的Python考试(上篇)
  5. 使用Uploadify实现上传图片生成缩略图例子,实时显示进度条
  6. GITHUB来获得UE4源代码
  7. 生成Ipa安装包的plist文件后生成下载链接
  8. SharePoint 2010问题集锦 (2011.1)
  9. 在MT4上使用双线MACD指标源码
  10. 教大家一个快速批量去水印下载快手视频、图集的方法技巧
  11. 安川e1000中文说明书_安川E1000变频器维修故障代码说明书
  12. 用Tampermonkey真正屏蔽B站自己不感兴趣的视频
  13. 2018蓝桥杯B组国赛第四题 调手表(bfs)
  14. hibernate中的检索策略
  15. Linux下基础命令(二)
  16. 1124 Raffle for Weibo Followers(map)
  17. django之 报错(1146, “Table ‘demo2.web‘ doesn‘t exist“)
  18. cinder云硬盘type创建
  19. 【Xshell免费版,不用去找破解(ftp也一样)】
  20. html中hover的作用,hover在css中的作用是什么

热门文章

  1. nio 读取目录所有文件_在NIO.2中使用文件和目录
  2. maven项目 ant_将旧项目从Ant迁移到Maven的4个简单步骤
  3. 功能Java示例 第6部分–用作参数
  4. 在Jersey测试中模拟SecurityContext
  5. 无参数泛型方法反模式
  6. 将旧版本从Java EE 5减少到7
  7. 使用Apache Cassandra设置一个SpringData项目
  8. 性能,可伸缩性和活力
  9. orm框架选型问题_ORM问题
  10. 针对WildFly和EAP运行Java Mission Control和Flight Recorder