fastText的原理剖析

1. fastText的模型架构

fastText的架构非常简单,有三层:输入层、隐含层、输出层(Hierarchical Softmax)

输入层:是对文档embedding之后的向量,包含有N-garm特征

隐藏层:是对输入数据的求和平均

输出层:是文档对应标签

如下图所示:

1.1 N-garm的理解

1.1.1 bag of word

bag of word 又称为bow,称为词袋。是一种只统计词频的手段。   【通过词频代替原来词语,没有考虑词语的位置】

例如:在机器学习的课程中通过朴素贝叶斯来预测文本的类别,我们学习的countVectorizer和TfidfVectorizer都可以理解为一种bow模型。

1.1.2 N-gram模型     【使用连续的n个词作为一个特征】

但是在很多情况下,词袋模型是不满足我们的需求的。

例如:我爱她她爱我在词袋模型下面,概率完全相同,但是其含义确实差别非常大。

为了解决这个问题,就有了N-gram模型,它不仅考虑词频,还会考虑当前词前面的词语,比如我爱她爱

N-gram模型的描述是:第n个词出现与前n-1个词相关,而与其他任何词不相关。(当然在很多场景下和前n-1个词也会相关,但是为了简化问题,经常会这样去计算)

例如:I love deep learning这个句子,在n=2的情况下,可以表示为{i love},{love deep},{deep learning},n=3的情况下,可以表示为{I love deep},{love deep learning}

在n=2的情况下,这个模型被称为Bi-garm(二元n-garm模型)

在n=3 的情况下,这个模型被称为Tri-garm(三元n-garm模型)

具体可以参考 ed3book chapter3

所以在fasttext的输入层,不仅有分词之后的词语,还有包含有N-gram的组合词语一起作为输入

2. fastText中的层次化的softmax-对传统softmax的优化方法1

为了提高效率,在fastText中计算分类标签的概率的时候,不再是使用传统的softmax来进行多分类的计算,而是使用的哈夫曼树(Huffman,也成为霍夫曼树),使用层次化的softmax(Hierarchial softmax)来进行概率的计算。

2.1 哈夫曼树和哈夫曼编码

2.1.1 哈夫曼树的定义

哈夫曼树概念:给定n个权值作为n个叶子结点,构造一棵二叉树,若该树的带权路径长度达到最小,称这样的二叉树为最优二叉树,也称为哈夫曼树(Huffman Tree)。

哈夫曼树是带权路径长度最短的树,权值较大的结点离根较近。

2.1.2 哈夫曼树的相关概念

二叉树:每个节点最多有2个子树的有序树,两个子树分别称为左子树、右子树。有序的意思是:树有左右之分,不能颠倒

叶子节点:一棵树当中没有子结点的结点称为叶子结点,简称“叶子”

路径和路径长度:在一棵树中,从一个结点往下可以达到的孩子或孙子结点之间的通路,称为路径。通路中分支的数目称为路径长度。若规定根结点的层数为1,则从根结点到第L层结点的路径长度为L-1。

结点的权及带权路径长度:若将树中结点赋给一个有着某种含义的数值,则这个数值称为该结点的权。结点的带权路径长度为:从根结点到该结点之间的路径长度与该结点的权的乘积

树的带权路径长度:树的带权路径长度规定为所有叶子结点的带权路径长度之和

树的高度:树中结点的最大层次。包含n个结点的二叉树的高度至少为log2 (n+1)

2.1.3 哈夫曼树的构造算法

  1. W_1,W_2,W_3 \dots W_n看成n棵树的森林

  2. 在森林中选择两个根节点权值最小的树进行合并,作为一颗新树的左右子树,新树的根节点权值为左右子树的和

  3. 删除之前选择出的子树,把新树加入森林

  4. 重复2-3步骤,直到森林只有一棵树为止,概树就是所求的哈夫曼树

例如:圆圈中的表示每个词语出现的次数,以这些词语为叶子节点构造的哈夫曼树过程如下:

可见:

  1. 权重越大,距离根节点越近

  2. 叶子的个数为n,构造哈夫曼树中新增的节点的个数为n-1

2.2.1 哈夫曼编码

在数据通信中,需要将传送的文字转换成二进制的字符串,用0,1码的不同排列来表示字符。

例如,需传送的报文为AFTER DATA EAR ARE ART AREA,这里用到的字符集为A,E,R,T,F,D,各字母出现的次数为{8,4,5,3,1,1}。现要求为这些字母设计编码。要区别6个字母,最简单的二进制编码方式是等长编码,固定采用3位二进制,可分别用000、001、010、011、100、101A,E,R,T,F,D进行编码发送

但是很明显,上述的编码的方式并不是最优的,即整理传送的字节数量并不是最少的。

为了提高数据传送的效率,同时为了保证任一字符的编码都不是另一个字符编码的前缀,这种编码称为前缀编码[前缀编码],可以使用哈夫曼树生成哈夫曼编码解决问题

可用字符集中的每个字符作为叶子结点生成一棵编码二叉树,为了获得传送报文的最短长度,可将每个字符的出现频率作为字符结点的权值赋予该结点上,显然字使用频率越小权值越小,权值越小叶子就越靠下,于是频率小编码长,频率高编码短,这样就保证了此树的最小带权路径长度效果上就是传送报文的最短长度

因此,求传送报文的最短长度问题转化为求由字符集中的所有字符作为叶子结点,由字符出现频率作为其权值所产生的哈夫曼树的问题。利用哈夫曼树来设计二进制的前缀编码,既满足前缀编码的条件,又保证报文编码总长最短。

下图中label1 .... label6分别表示A,E,R,T,F,D

2.3 梯度计算

上图中,红色为哈夫曼编码,即label5的哈夫曼编码为1001,那么此时如何定义条件概率P(Label5|contex)呢?

以Label5为例,从根节点到Label5中间经历了4次分支,每次分支都可以认为是进行了一次2分类,根据哈夫曼编码,可以把路径中的每个非叶子节点0认为是负类,1认为是正类(也可以把0认为是正类)

由机器学习课程中逻辑回归使用sigmoid函数进行2分类的过程中,一个节点被分为正类的概率是\delta(X^{T}\theta) = \frac{1}{1+e^{-X^T\theta}},被分类负类的概率是:1-\delta(X^T\theta),其中\theta就是图中非叶子节点对应的参数\theta

对于从根节点出发,到达Label5一共经历4次2分类,将每次分类结果的概率写出来就是:

  1. 第一次:P(1|X,\theta_1) = \delta(X^T\theta_1) ,即从根节点到2 3节点的概率是在知道X和\theta_1的情况下取值为1的概率

  2. 第二次:P(0|X,\theta_2) =1- \delta(X^T\theta_2)

  3. 第三次:P(0 |X,\theta_3) =1- \delta(X^T\theta_4)

  4. 第四次:P(1|X,\theta_4) = \delta(X^T\theta_4)

但是我们需要求的是P(Label|contex), 他等于前4词的概率的乘积,公式如下(d_j^w​是第j个节点的哈夫曼编码)

P(Label|context) = \prod_{j=2}^5P(d_j|X,\theta_{j-1})

其中: P(d_j|X,\theta_{j-1}) = \left\{ \begin{aligned} &\delta(X^T\theta_{j-1}), & d_j=1;\\ &1-\delta(X^T\theta_{j-1}) & d_j=0; \end{aligned} \right.

有了损失函数之后,接下来就是对其中的X,\theta进行求导,并更新,最终还需要更新最开始的每个词语词向量

层次化softmax的好处:传统的softmax的时间复杂度为L(Labels的数量),但是使用层次化softmax之后时间复杂度的log(L) (二叉树高度和宽度的近似),从而在多分类的场景提高了效率

3. fastText中的negative sampling(负采样)-对传统softmax的优化方法2

negative sampling,即每次从除当前label外的其他label中,随机的选择几个作为负样本。具体的采样方法:

如果所有的label为V​,那么我们就将一段长度为1的线段分成V​份,每份对应所有label中的一类label。当然每个词对应的线段长度是不一样的,高频label对应的线段长,低频label对应的线段短。每个label的线段长度由下式决定:

在采样前,我们将这段长度为1的线段划分成$M​$等份,这里$M>>V​$,这样可以保证每个label对应的线段都会划分成对应的小块。而M份中的每一份都会落在某一个label对应的线段上。在采样的时候,我们只需要从$M​$个位置中采样出neg个位置就行,此时采样到的每一个位置对应到的线段所属的词就是我们的负例。

简单的理解就是,从原来所有的样本中,等比例的选择neg个负样本作(遇到自己则跳过),作为训练样本,添加到训练数据中,和正例样本一起来进行训练。

Negative Sampling也是采用了二元逻辑回归来求解模型参数,通过负采样,我们得到了neg个负例,将正例定义为$label_0​$,负例定义为$label_i,i=1,2,3...neg​$

定义正例的概率为P\left( label{0}|\text {context}\right)=\sigma\left(x{\mathrm{k}}^{T} \theta\right), y_{i}=1

则负例的概率为:P\left( label{i}|\text {context}\right)=1-\sigma\left(x{\mathrm{k}}^{T} \theta\right), y_{i}=0,i=1,2,3..neg

此时对应的对数似然函数为:

之后会使用梯度上升的方法进行梯度计算和参数更新,仅仅每次只用一波样本(一个正例和neg个反例)更新梯度,来进行迭代更新

具体的更新伪代码如下:

其中内部大括号部分为w相关参数的梯度计算过程,e为w的梯度和学习率的乘积,具体参考:https://blog.csdn.net/itplus/article/details/37998797

好处:

  1. 提高训练速度,选择了部分数据进行计算损失,同时整个对每一个label而言都是一个二分类,损失计算更加简单,只需要让当前label的值的概率尽可能大,其他label的都为反例,概率会尽可能小

  2. 改进效果,增加部分负样本,能够模拟真实场景下的噪声情况,能够让模型的稳健性更强

fastText的原理剖析相关推荐

  1. socket之send和recv原理剖析

    socket之send和recv原理剖析 1. 认识TCP socket的发送和接收缓冲区 当创建一个TCP socket对象的时候会有一个发送缓冲区和一个接收缓冲区,这个发送和接收缓冲区指的就是内存 ...

  2. lua游戏脚本实例源码_Lua与其他宿主语言交互原理剖析

    Lua与其他宿主语言交互原理剖析 题外话:今天周末,刚好在家有时间就把我这次项目组内部分享的文章贴出来,分享给大家,同时也方便以后自己翻阅. 一. Lua简介 目标:Lua语言本身是用C语言来编写开发 ...

  3. Go语言底层原理剖析

    作者:郑建勋 出版社:电子工业出版社 品牌:博文视点 出版时间:2021-08-01 Go语言底层原理剖析

  4. 彻底搞透视觉三维重建:原理剖析、代码讲解、及优化改进

    视觉三维重建 = 定位定姿 + 稠密重建 + surface reconstruction +纹理贴图.三维重建技术是计算机视觉的重要技术之一,基于视觉的三维重建技术通过深度数据获取.预处理.点云配准 ...

  5. Elasticsearch分布式一致性原理剖析(一)-节点篇

    2019独角兽企业重金招聘Python工程师标准>>> 摘要: ES目前是最流行的开源分布式搜索引擎系统,其使用Lucene作为单机存储引擎并提供强大的搜索查询能力.学习其搜索原理, ...

  6. java 反序列化 ysoserial exploit/JRMPListener 原理剖析

    目录 0 前言 1 payloads/JRMPClient 1.1 Externalizable 1.2 生成payload 1.3 gadget链分析 2 exploit/JRMPListener ...

  7. 统计学习方法|支持向量机(SVM)原理剖析及实现

    欢迎直接到我的博客查看最近文章:www.pkudodo.com.更新会比较快,评论回复我也能比较快看见,排版也会更好一点. 原始blog链接: http://www.pkudodo.com/2018/ ...

  8. 统计学习方法|逻辑斯蒂原理剖析及实现

    欢迎直接到我的博客查看最近文章:www.pkudodo.com.更新会比较快,评论回复我也能比较快看见,排版也会更好一点. 原始blog链接: http://www.pkudodo.com/2018/ ...

  9. 统计学习方法|朴素贝叶斯原理剖析及实现

    欢迎直接到我的博客查看最近文章:www.pkudodo.com.更新会比较快,评论回复我也能比较快看见,排版也会更好一点. 原始blog链接: http://www.pkudodo.com/2018/ ...

最新文章

  1. 怎么这一个c语言的dll文件,如何在C中获取DLL文件的版本信息
  2. 71 Zabbix自定义触发器
  3. Windows保护模式学习笔记(一)—— 段寄存器GDT表
  4. ABAP数据库操作系列之操作语句讲解Select
  5. 前端学习(2450):页面布局制作
  6. mysql中locat函数,MySQL中的LOCATE和POSITION函数使用方法 | 很文博客
  7. php代码练习,PHP模拟测试练习
  8. xcode 插件安装路径
  9. linux mp4v2编译,Android 编译mp4 v2 2.0.0生成动态库
  10. 目前最快的 Java 框架居然是它?真的最快,秒射~
  11. sql string转换成int型 sql截取字符串
  12. SSRS 2012 高级图表类型 -- 圆饼图
  13. 红色警戒2:尤里的复仇 中文绿色版
  14. EGO1—实现计数器74HC163
  15. linux分配设备编号
  16. 【ubuntu】解决 Certificate verification failed: The certificate is NOT trusted
  17. vim-python怎么用_技术|如何在使用 Vim 时访问/查看 Python 帮助
  18. 7-36 大炮打蚊子(15 分)
  19. Android Studio 美化
  20. D咚买菜抢购autojs核心代码分享

热门文章

  1. 分析6千万条GitHub帖子,发现你的工作状态与表情符号强相关
  2. 科研费4年翻3倍,全球科研队伍突破8000人,滴滴致力打造出行领域核心技术
  3. 被“钱”困住的开源开发者们!
  4. Python 炫技操作:合并字典的七种方法
  5. 三年、四大顶会,深度推荐系统18篇论文只有7个可以复现
  6. 免费GPU哪家强?谷歌Kaggle vs. Colab
  7. 报名 | 美团是怎样给你推荐外卖的?美团大脑知识图谱详解
  8. 吊打 ThreadLocal,谈谈FastThreadLocal为啥能这么快?
  9. 从零搭建 Spring Cloud 服务(超级详细)
  10. SpringBoot 项目瘦身指南,瘦到不可思议!