ECCV2018 | Learning Deep Representations with Probabilistic Knowledge Transfer

https://github.com/passalis/probabilistic_kt

1.传统知识蒸馏

最早的知识蒸馏方法专门针对分类任务进行设计,它们不能有效地用于其他特征学习的任务。 在本文中,作者提出了一种通过匹配数据在特征空间中的概率分布进行知识蒸馏(PKL)。该方法除了性能超越现有的蒸馏技术外, 还可以克服它们的一些局限性。包括:(1)可以实现直接转移不同架构/维度层之间的知识。(2)现有的蒸馏技术通常会忽略教师特征空间的几何形状,因为它们仅使学生网络学习教师网络的输出结果。而PKL算法能够有效地将教师模型的特征空间结构映射到学生的特征空间中,从而提高学生模型的准确性。PKL算法示意图如下所示。PKT技术克服了现有蒸馏方法的一些局限性,通过匹配特征空间中数据的概率分布,从而实现知识蒸馏。

2.基于概率的知识蒸馏(PKT)

为了使得学生模型能够有效的学习教师模型的概率分布。作者在训练网络的过程中,对每个batch中的数据样本之间的成对交互进行建模,使得其可以描述相应特征空间的几何形状。利用特征空间中任意两个数据点的联合概率密度,对两个数据点之间的距离进行概率分布建模。通过最小化教师模型与学生模型的联合密度概率估计的差异,实现概率分布学习。
联合概率密度函数公式:

从上述公式可以发现,最小化概率分布并不需要用到标签数据,因此PKT甚至可以用到无监督学习中。利用上述所说的联合概率分布进行知识蒸馏可以避免很多传统蒸馏方法的缺点。但是,由于实际训练中我们每个batch都是所有数据的随机抽样,使用全局数据是不现实的,基于此作者使用样本的条件概率分布代替联合概率密度函数。
条件概率密度函数公式:

计算当前batch中数据两两之间的条件概率密度后,通过最小化教师模型的条件概率分布和学生模型的条件概率分布的KL散度,实现概率知识蒸馏。

3.计算概率分布

如上述所示的条件概率分布函数公式可知,要求数据间的条件概率分布需要定义对应的核函数。常见的核函数有高斯核,具体公式如下所示,但由于高斯核中需要定义一个超参数,且该超参数对最终蒸馏结果会参数极大的影响。因此本文并没有采用这种常见的核函数。

本文尝试通过余弦核函数进行条件概率估计。其公式如下所示,根据余弦函数的定义可以更好的解释本文提出的PKL蒸馏法体现出的架构和维度无关性。

def cosine_similarity_loss(output_net, target_net, eps=0.0000001):# Normalize each vector by its normoutput_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))output_net = output_net / (output_net_norm + eps)output_net[output_net != output_net] = 0target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))target_net = target_net / (target_net_norm + eps)target_net[target_net != target_net] = 0# Calculate the cosine similaritymodel_similarity = torch.mm(output_net, output_net.transpose(0, 1))target_similarity = torch.mm(target_net, target_net.transpose(0, 1))# Scale cosine similarity to 0..1model_similarity = (model_similarity + 1.0) / 2.0target_similarity = (target_similarity + 1.0) / 2.0# Transform them into probabilitiesmodel_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)# Calculate the KL-divergenceloss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))return loss

这段代码就是对上述公式的翻译,代码中的output_net代表了当前数据的学生模型输出特征图,而target_net代表了当前数据的教师模型输出特征图。正常情况下该特征图维度一般都为:NCHW。根据上述代码不论两者的C的维度是多少,有或者HW的维度是多少,最终经过矩阵转置相乘,都会变成一个N*N大小的相似性矩阵。通过相似性矩阵经过一系列计算,最终求得两者的概率分布,并进行概率学习。

4.结果展示

PKT基于概率的知识蒸馏应用到分类和目标检测任务中,从下表的结果可以看出该方法的通用和有效性。

ECCV2018 | PKT_概率知识蒸馏相关推荐

  1. 知识蒸馏(Knowledge Distillation)详细深入透彻理解重点

    知识蒸馏是一种模型压缩方法,是一种基于"教师-学生网络思想"的训练方法,由于其简单,有效,在工业界被广泛应用.这一技术的理论来自于2015年Hinton发表的一篇神作: 论文链接 ...

  2. 收藏 | 一文带你总览知识蒸馏,详解经典论文

    「免费学习 60+ 节公开课:投票页面,点击讲师头像」 作者:凉爽的安迪 来源 | 深度传送门(ID:deep_deliver) [导读]这是一篇关于[知识蒸馏]简述的文章,目的是想对自己对于知识蒸馏 ...

  3. 关于知识蒸馏,这三篇论文详解不可错过

    作者 | 孟让 转载自知乎 导语:继<从Hinton开山之作开始,谈知识蒸馏的最新进展>之后,作者对知识蒸馏相关重要进行了更加全面的总结.在上一篇文章中主要介绍了attention tra ...

  4. 知识蒸馏在推荐系统的应用

    点击上方,选择星标或置顶,每天给你送干货! 作者 | 张俊林 本文转载自知乎 https://zhuanlan.zhihu.com/p/143155437 随着深度学习的快速发展,优秀的模型层出不穷, ...

  5. 知识蒸馏:如何用一个神经网络训练另一个神经网络

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 如果你曾经用神经网络来解决一个复杂的问题,你就会知道它们的尺寸可能 ...

  6. 知识蒸馏综述:蒸馏机制

    作者丨pprp 来源丨GiantPandaCV 编辑丨极市平台 导读 这一篇介绍各个算法的蒸馏机制,根据教师网络是否和学生网络一起更新,可以分为离线蒸馏,在线蒸馏和自蒸馏. 感性上理解三种蒸馏方式: ...

  7. 万字总结,知识蒸馏如何在推荐系统中大显身手?

    来源:AI科技评论 作者 | 张俊杰 编辑 | 丛 末 本文首发于知乎 https://zhuanlan.zhihu.com/p/143155437 随着深度学习的快速发展,优秀的模型层出不穷,比如图 ...

  8. 训练softmax分类器实例_知识蒸馏:如何用一个神经网络训练另一个神经网络

    作者:Tivadar Danka 编译:ronghuaiyang 原文链接 知识蒸馏:如何用一个神经网络训练另一个神经网络​mp.weixin.qq.com 导读 知识蒸馏的简单介绍,让大家了解知识蒸 ...

  9. 浅谈“知识蒸馏”技术在机器学习领域的应用

    什么是知识蒸馏技术? 知识蒸馏技术首次出现是在Hinton几年前的一篇论文<Distilling the Knowledge in a Neural Network>.老大爷这么大岁数了还 ...

最新文章

  1. 分析与设计(AD)简介(3)
  2. 工信部企业信息核查 谋定“互联网+监管”经信研究创新实践
  3. Android7.1 Presentation双屏异显原理分析
  4. Windows Server 2012改造成Windows8的方法(转载)
  5. 最受欢迎的Java环境
  6. jstree 节点拖拽保存数据库
  7. 第13课 智商问题 《小学生C++趣味编程》
  8. mysql创建数据库时使用sql/wordbench使主键(primary key)自增
  9. 编写技术解决方案思路
  10. 【Cubase11】音乐工作站:宿主软件 - 基础入门笔记
  11. NeoKylin(linux)操作系统基本操作(自用)
  12. python实用例子
  13. CleanMyMacX4.11.3最新版mac电脑磁盘清理工具功能
  14. cogs339 维修数列 ……
  15. ES6语法总结(21)--Generator函数的异步应用
  16. db2嵌套查询效率_提高 DB2 查询性能的常用方法
  17. 仿写哔哩哔哩的头部导航部分(HTML+CSS静态)
  18. Boost PFC参数计算——PFC电感
  19. 《web结课作业的源码》中华传统文化题材网页设计主题——基于HTML+CSS+JavaScript精美自适应绿色茶叶公司(12页)
  20. 移植最新4.19.8内核至JZ2440——根文件系统制作

热门文章

  1. 【渝粤教育】国家开放大学2018年秋季 8634-22T (1)Android智能手机编程 参考试题
  2. kuka机器人offset指令_KUKA机器人MADA详解.doc
  3. 设计模式之依赖倒置设计原则
  4. 外贸建站之独立站系统选择
  5. 网络安全——数据链路层安全协议
  6. 通往Android的神奇之旅-刘桂林-专题视频课程
  7. java实现测量到的工程数据
  8. 项目:宅人食堂——点餐系统
  9. 连续系统的动态规划问题
  10. 搜狗搜索事业部总经理:从识图搜索谈未来大势