10.2 近似训练

回忆上一节的内容。跳字模型的核心在于使用softmax运算得到给定中心词wcw_cwc​来生成背景词wow_owo​的条件概率

P(wo∣wc)=exp(uo⊤vc)∑i∈Vexp(ui⊤vc).P(w_o \mid w_c) = \frac{\text{exp}(\boldsymbol{u}_o^\top \boldsymbol{v}_c)}{ \sum_{i \in \mathcal{V}} \text{exp}(\boldsymbol{u}_i^\top \boldsymbol{v}_c)}.P(wo​∣wc​)=∑i∈V​exp(ui⊤​vc​)exp(uo⊤​vc​)​.

该条件概率相应的对数损失

−log⁡P(wo∣wc)=−uo⊤vc+log⁡(∑i∈Vexp(ui⊤vc)).-\log P(w_o \mid w_c) = -\boldsymbol{u}_o^\top \boldsymbol{v}_c + \log\left(\sum_{i \in \mathcal{V}} \text{exp}(\boldsymbol{u}_i^\top \boldsymbol{v}_c)\right).−logP(wo​∣wc​)=−uo⊤​vc​+log(i∈V∑​exp(ui⊤​vc​)).

由于softmax运算考虑了背景词可能是词典V\mathcal{V}V中的任一词,以上损失包含了词典大小数目的项的累加。在上一节中我们看到,不论是跳字模型还是连续词袋模型,由于条件概率使用了softmax运算,每一步的梯度计算都包含词典大小数目的项的累加。对于含几十万或上百万词的较大词典,每次的梯度计算开销可能过大。为了降低该计算复杂度,本节将介绍两种近似训练方法,即负采样(negative sampling)或层序softmax(hierarchical softmax)。由于跳字模型和连续词袋模型类似,本节仅以跳字模型为例介绍这两种方法。

10.2.1 负采样

负采样修改了原来的目标函数。给定中心词wcw_cwc​的一个背景窗口,我们把背景词wow_owo​出现在该背景窗口看作一个事件,并将该事件的概率计算为

P(D=1∣wc,wo)=σ(uo⊤vc),P(D=1\mid w_c, w_o) = \sigma(\boldsymbol{u}_o^\top \boldsymbol{v}_c),P(D=1∣wc​,wo​)=σ(uo⊤​vc​),

其中的σ\sigmaσ函数与sigmoid激活函数的定义相同:

σ(x)=11+exp⁡(−x).\sigma(x) = \frac{1}{1+\exp(-x)}.σ(x)=1+exp(−x)1​.

我们先考虑最大化文本序列中所有该事件的联合概率来训练词向量。具体来说,给定一个长度为TTT的文本序列,设时间步ttt的词为w(t)w^{(t)}w(t)且背景窗口大小为mmm,考虑最大化联合概率

∏t=1T∏−m≤j≤m,j≠0P(D=1∣w(t),w(t+j)).\prod_{t=1}^{T} \prod_{-m \leq j \leq m,\ j \neq 0} P(D=1\mid w^{(t)}, w^{(t+j)}).t=1∏T​−m≤j≤m, j​=0∏​P(D=1∣w(t),w(t+j)).

然而,以上模型中包含的事件仅考虑了正类样本。这导致当所有词向量相等且值为无穷大时,以上的联合概率才被最大化为1。很明显,这样的词向量毫无意义。负采样通过采样并添加负类样本使目标函数更有意义。设背景词wow_owo​出现在中心词wcw_cwc​的一个背景窗口为事件PPP,我们根据分布P(w)P(w)P(w)采样KKK个未出现在该背景窗口中的词,即噪声词。设噪声词wkw_kwk​(k=1,…,Kk=1, \ldots, Kk=1,…,K)不出现在中心词wcw_cwc​的该背景窗口为事件NkN_kNk​。假设同时含有正类样本和负类样本的事件P,N1,…,NKP, N_1, \ldots, N_KP,N1​,…,NK​相互独立,负采样将以上需要最大化的仅考虑正类样本的联合概率改写为

∏t=1T∏−m≤j≤m,j≠0P(w(t+j)∣w(t)),\prod_{t=1}^{T} \prod_{-m \leq j \leq m,\ j \neq 0} P(w^{(t+j)} \mid w^{(t)}),t=1∏T​−m≤j≤m, j​=0∏​P(w(t+j)∣w(t)),

其中条件概率被近似表示为
P(w(t+j)∣w(t))=P(D=1∣w(t),w(t+j))∏k=1,wk∼P(w)KP(D=0∣w(t),wk).P(w^{(t+j)} \mid w^{(t)}) =P(D=1\mid w^{(t)}, w^{(t+j)})\prod_{k=1,\ w_k \sim P(w)}^K P(D=0\mid w^{(t)}, w_k).P(w(t+j)∣w(t))=P(D=1∣w(t),w(t+j))k=1, wk​∼P(w)∏K​P(D=0∣w(t),wk​).

设文本序列中时间步ttt的词w(t)w^{(t)}w(t)在词典中的索引为iti_tit​,噪声词wkw_kwk​在词典中的索引为hkh_khk​。有关以上条件概率的对数损失为

−log⁡P(w(t+j)∣w(t))=−log⁡P(D=1∣w(t),w(t+j))−∑k=1,wk∼P(w)Klog⁡P(D=0∣w(t),wk)=−log⁡σ(uit+j⊤vit)−∑k=1,wk∼P(w)Klog⁡(1−σ(uhk⊤vit))=−log⁡σ(uit+j⊤vit)−∑k=1,wk∼P(w)Klog⁡σ(−uhk⊤vit).\begin{aligned} -\log P(w^{(t+j)} \mid w^{(t)}) =& -\log P(D=1\mid w^{(t)}, w^{(t+j)}) - \sum_{k=1,\ w_k \sim P(w)}^K \log P(D=0\mid w^{(t)}, w_k)\\ =&- \log\, \sigma\left(\boldsymbol{u}_{i_{t+j}}^\top \boldsymbol{v}_{i_t}\right) - \sum_{k=1,\ w_k \sim P(w)}^K \log\left(1-\sigma\left(\boldsymbol{u}_{h_k}^\top \boldsymbol{v}_{i_t}\right)\right)\\ =&- \log\, \sigma\left(\boldsymbol{u}_{i_{t+j}}^\top \boldsymbol{v}_{i_t}\right) - \sum_{k=1,\ w_k \sim P(w)}^K \log\sigma\left(-\boldsymbol{u}_{h_k}^\top \boldsymbol{v}_{i_t}\right). \end{aligned} −logP(w(t+j)∣w(t))===​−logP(D=1∣w(t),w(t+j))−k=1, wk​∼P(w)∑K​logP(D=0∣w(t),wk​)−logσ(uit+j​⊤​vit​​)−k=1, wk​∼P(w)∑K​log(1−σ(uhk​⊤​vit​​))−logσ(uit+j​⊤​vit​​)−k=1, wk​∼P(w)∑K​logσ(−uhk​⊤​vit​​).​

现在,训练中每一步的梯度计算开销不再与词典大小相关,而与KKK线性相关。当KKK取较小的常数时,负采样在每一步的梯度计算开销较小。

10.2.2 层序softmax

层序softmax是另一种近似训练法。它使用了二叉树这一数据结构,树的每个叶结点代表词典V\mathcal{V}V中的每个词。

图10.3 层序softmax。二叉树的每个叶结点代表着词典的每个词

假设L(w)L(w)L(w)为从二叉树的根结点到词www的叶结点的路径(包括根结点和叶结点)上的结点数。设n(w,j)n(w,j)n(w,j)为该路径上第jjj个结点,并设该结点的背景词向量为un(w,j)\boldsymbol{u}_{n(w,j)}un(w,j)​。以图10.3为例,L(w3)=4L(w_3) = 4L(w3​)=4。层序softmax将跳字模型中的条件概率近似表示为

P(wo∣wc)=∏j=1L(wo)−1σ([⁣[n(wo,j+1)=leftChild(n(wo,j))]⁣]⋅un(wo,j)⊤vc),P(w_o \mid w_c) = \prod_{j=1}^{L(w_o)-1} \sigma\left( [\![ n(w_o, j+1) = \text{leftChild}(n(w_o,j)) ]\!] \cdot \boldsymbol{u}_{n(w_o,j)}^\top \boldsymbol{v}_c\right),P(wo​∣wc​)=j=1∏L(wo​)−1​σ([[n(wo​,j+1)=leftChild(n(wo​,j))]]⋅un(wo​,j)⊤​vc​),

其中σ\sigmaσ函数与3.8节(多层感知机)中sigmoid激活函数的定义相同,leftChild(n)\text{leftChild}(n)leftChild(n)是结点nnn的左子结点:如果判断xxx为真,[⁣[x]⁣]=1[\![x]\!] = 1[[x]]=1;反之[⁣[x]⁣]=−1[\![x]\!] = -1[[x]]=−1。
让我们计算图10.3中给定词wcw_cwc​生成词w3w_3w3​的条件概率。我们需要将wcw_cwc​的词向量vc\boldsymbol{v}_cvc​和根结点到w3w_3w3​路径上的非叶结点向量一一求内积。由于在二叉树中由根结点到叶结点w3w_3w3​的路径上需要向左、向右再向左地遍历(图10.3中加粗的路径),我们得到

P(w3∣wc)=σ(un(w3,1)⊤vc)⋅σ(−un(w3,2)⊤vc)⋅σ(un(w3,3)⊤vc).P(w_3 \mid w_c) = \sigma(\boldsymbol{u}_{n(w_3,1)}^\top \boldsymbol{v}_c) \cdot \sigma(-\boldsymbol{u}_{n(w_3,2)}^\top \boldsymbol{v}_c) \cdot \sigma(\boldsymbol{u}_{n(w_3,3)}^\top \boldsymbol{v}_c).P(w3​∣wc​)=σ(un(w3​,1)⊤​vc​)⋅σ(−un(w3​,2)⊤​vc​)⋅σ(un(w3​,3)⊤​vc​).

由于σ(x)+σ(−x)=1\sigma(x)+\sigma(-x) = 1σ(x)+σ(−x)=1,给定中心词wcw_cwc​生成词典V\mathcal{V}V中任一词的条件概率之和为1这一条件也将满足:

∑w∈VP(w∣wc)=1.\sum_{w \in \mathcal{V}} P(w \mid w_c) = 1.w∈V∑​P(w∣wc​)=1.

此外,由于L(wo)−1L(w_o)-1L(wo​)−1的数量级为O(log2∣V∣)\mathcal{O}(\text{log}_2|\mathcal{V}|)O(log2​∣V∣),当词典V\mathcal{V}V很大时,层序softmax在训练中每一步的梯度计算开销相较未使用近似训练时大幅降低。

小结

  • 负采样通过考虑同时含有正类样本和负类样本的相互独立事件来构造损失函数。其训练中每一步的梯度计算开销与采样的噪声词的个数线性相关。
  • 层序softmax使用了二叉树,并根据根结点到叶结点的路径来构造损失函数。其训练中每一步的梯度计算开销与词典大小的对数相关。

注:本节与原书完全相同,原书传送门

10.2_approx-training相关推荐

  1. Google TensorFlow课程 编程笔记(10)———使用神经网络对手写数字进行分类

    使用神经网络对手写数字进行分类 学习目标: 训练线性模型和神经网络,以对传统 MNIST 数据集中的手写数字进行分类 比较线性分类模型和神经网络分类模型的效果 可视化神经网络隐藏层的权重 我们的目标是 ...

  2. CVPR 2023 接收结果出炉!再创历史新高!录用2360篇!(附10篇最新论文)

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 点击进入->[计算机视觉]微信技术交流群 2023 年 2 月 28 日凌晨,CVPR 2023 顶会 ...

  3. Deformable 可变形的DETR

    Deformable 可变形的DETR This repository is an official implementation of the paper Deformable DETR: Defo ...

  4. 为什么要进行数据归一化

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 原文:https://medium.com/@urvashillu ...

  5. 7Papers|斯坦福学者造出机器鸽;港科大等提出学生情绪分析新系统

    机器之心&ArXiv Weekly Radiostation 参与:杜伟,楚航,罗若天 本周既有港科大.哈工程等机构提出的观察课堂学生情绪变化.注意力集中程度的 EmotionCues 系统, ...

  6. Spark MLlib 机器学习

    本章导读 机器学习(machine learning, ML)是一门涉及概率论.统计学.逼近论.凸分析.算法复杂度理论等多领域的交叉学科.ML专注于研究计算机模拟或实现人类的学习行为,以获取新知识.新 ...

  7. python sklearn.neural_network.MLPClassifier() 神经网络改变模型复杂度的四种方法

    MLPClassifier() 改变模型复杂度的四种方法 调整神经网络每一个隐藏层上的节点数 调节神经网络隐藏层的层数 调节activation的方式 通过调整alpha值来改变模型正则化的程度(增大 ...

  8. 个人阅读的Deep Learning方向的paper整理

    http://hi.baidu.com/chb_seaok/item/6307c0d0363170e73cc2cb65 个人阅读的Deep Learning方向的paper整理,分了几部分吧,但有些部 ...

  9. 机器学习实战:GBDT Xgboost LightGBM对比

    Mnist数据集识别 使用Sklearn的GBDT GradientBoostingClassifier GradientBoostingRegressor import gzip import pi ...

  10. Spark MLlib回归算法------线性回归、逻辑回归、SVM和ALS

    Spark MLlib回归算法------线性回归.逻辑回归.SVM和ALS 1.线性回归: (1)模型的建立: 回归正则化方法(Lasso,Ridge和ElasticNet)在高维和数据集变量之间多 ...

最新文章

  1. 在ubuntu12.04上使用华为et127 3g上网卡
  2. mongodb 安装pymongo 驱动
  3. 2003白金一代NBA选秀
  4. [云炬创业管理笔记]第三章打造优秀创业团队讨论3
  5. spring声明事务与编程事务概述
  6. sphinx4 FrontEnd流程分析
  7. activiti页面展示流程图乱码_activiti 5.17 流程图中文乱码问题
  8. django使用mysql_设置Django以使用MySQL
  9. 当年叱咤风云的框架Struts2,你可知Struts2内功如何修炼
  10. 那些在一个公司死磕了5-10年的测试员,最后都怎么样了?
  11. 详细解析Photoshop10个必学的抠图技巧
  12. 【深入浅出imx8企业级开发实战 | 01】imx8qxp yocto工程构建指南
  13. CPU测评程序、指标、工具
  14. 利用模版元编程将传统冒泡排序性能提升两倍以上
  15. 【mean teacher】RuntimeError: Integer division of tensors using div or / is no longer suppor的解决
  16. 马蹄疾 | 2019年,是时候认真学一波 Grid 布局了
  17. Spark多行合并一行collect_list使用
  18. 谈谈登录注册的如何实现
  19. 大数据和机器学习,对我们商业和生活的影响
  20. ios 模拟器沙盒_查看iOS模拟器应用的沙箱文件

热门文章

  1. Windows XP操作系统自带工具应用
  2. 目标跟踪 KCF(High-Speed Tracking with Kernelized Correlation Filters)
  3. 吴恩达.深度学习系列-C1神经网络与深度学习-W1介绍
  4. 华为、中兴对比 [zz]
  5. Junit —— 单元测试工具基本使用
  6. A. AD 2020
  7. jdk eclipse SDK下载安装及配置教程
  8. 计算机出国转专业吗,美国留学可以转专业申请计算机专业吗?
  9. Python基于OpenCV的固定位置半透明水印去除方案
  10. 育润又添产品线,以差异化优势进军羊奶粉市场