文章目录

  • 前言
  • 为什么要用projection layer
  • 传统LSTM
  • LSTM projection layer 结构(LSTMP)
  • 参考网址

其实在实习之前对于一些知识点的理解还是欠缺的,很多时候感觉没什么用的基础知识为什么还会在面试的时候被问到,比如,“你画一下GRU的基本结构”,你在做工程的时候回修改GRU的结构吗?居然会问这种问题哎,或许你现在也是这么想的~
但是当真正的实习的时候你会发现公司里面的数据集是需要你自己挖掘的,大公司里面的数据时很廉价的,举个例子,我之前处理的用户数据都是好几个T的,xxx亿的数据做处理,最后留下来的数据就能到5亿,这是我在实验室从来没有接触过的,几百万的数据应该就算奢侈了,但是在公司里数据可能相当的廉价。有长必有短,平时处理模型的时候你可能很少关注性能,但是公司里就不一样了,要考虑集群吃不吃得消,要考虑在多少ns ms返回结果,这样你就不得不考虑模型的大小了,另外,至少我接触到的某搜索引擎的模型都是相对比较简单的,跟我在比赛时候用的模型比起来简直是“low”到爆,为了寻求速度和效率上的提升,在基础的LSTM或者GRU中的cell上寻求一点trick也就情理之中了~
所以下次面试让你画GRU的时候,你别有什么负面情绪了,这写真的用!得!到!~
扯远了,进入今天的整体LSTMP~ show time~~~

前言

首先最简单的LSTM结构可以详见我之前的帖子(GRU和LSTM总结),但是在之后撸代码的时候我发现一个经典的LSTM结构是这样的
i t = δ ( W i x x t + W i m m t − 1 + W i c c t − 1 + b i ) i_t=\delta(W_{ix}x_t+W_{im}m_{t-1}+W_{ic}c_{t-1}+b_i) it​=δ(Wix​xt​+Wim​mt−1​+Wic​ct−1​+bi​)

f t = δ ( W f x x t + W f m m t − 1 + W f c c t − 1 + b i ) f_t=\delta(W_{fx}x_t+W_{fm}m_{t-1}+W_{fc}c_{t-1}+b_i) ft​=δ(Wfx​xt​+Wfm​mt−1​+Wfc​ct−1​+bi​)

c t = f t ⊙ c t − 1 + i t ⊙ g ( W c x x t + W c m m t − 1 + b c ) c_t=f_t\odot c_{t-1}+i_t\odot g(W_{cx}x_t+W_{cm}m_{t-1}+b_c) ct​=ft​⊙ct−1​+it​⊙g(Wcx​xt​+Wcm​mt−1​+bc​)

o t = δ ( W o x x t + W o m m t − 1 + W o c c t + b o ) o_t=\delta(W_{ox}x_t+W_{om}m_{t-1}+W_{oc}c_{t}+b_o) ot​=δ(Wox​xt​+Wom​mt−1​+Woc​ct​+bo​)

m t = o t ⊙ h ( c t ) m_t=o_t\odot h(c_t) mt​=ot​⊙h(ct​)

y t = ϕ ( W y m m t + b y ) y_t=\phi (W_{ym}m_t+b_y) yt​=ϕ(Wym​mt​+by​)

where the W terms denote weight matrices (e.g. Wix is the matrix of weights from the input gate to the input), W i c W_{ic} Wic​, W f c W_{fc} Wfc​, W o c W_{oc} Woc​ are diagonal weight matrices for peephole connections, the b terms denote bias vectors (bi is the input gate bias vector), σ is
the logistic sigmoid function, and i, f, o and c are respectively the input gate, forget gate, output gate and cell activation vectors, all of which are the same size as the cell output activation vector m, is the element-wise product of the vectors, g and h are the cell input and cell output activation functions, generally and in this paper tanh, and φ is the network output activation function, softmax in this paper.
—from Long Short-Term Memory Recurrent Neural Network Architectures
for Large Scale Acoustic Modeling
其实最原始的paper是long short-term memory based recurrent neural network architectures for large vocabulary speech recognition,但是你会发现在介绍参数的时候还没有上一篇讲的清楚。这里注意下 W i c W_{ic} Wic​, W f c W_{fc} Wfc​, W o c W_{oc} Woc​都是对角矩阵!!!

细心的读者肯定能发现这个表达形式和我之前介绍LSTM博客在表达形式和结构上都有些区别,对于表达形式都是换汤不换药,但是这里结构上也发生了一定的变化,增加了几项: + W i c c t − 1 、 + W f c c t − 1 、 + W o c c t +W_{ic}c_{t-1} 、+W_{fc}c_{t-1} 、 +W_{oc}c_{t} +Wic​ct−1​、+Wfc​ct−1​、+Woc​ct​
根据这个变化,我们再来回忆一下两个结构的不同:

基础版

升级版(peephole connections)

例如:即便是LSTM也有很多个变种。一个变种方式是调控门的输入。例如下面两种gate:
g = s i g m o i d ( W x g ⋅ x t + W h g ⋅ h t − 1 + b ) g= sigmoid(W_{xg} \cdot x_t + W_{hg} \cdot h_{t-1} + {b}) g=sigmoid(Wxg​⋅xt​+Whg​⋅ht−1​+b):
这种gate的输入有当前的输入 x t x_t xt​和上一时刻的隐藏状态 h t − 1 h_{t-1} ht−1​, 表示gate是将这两个信息流作为控制依据而产生输出的。
g = s i g m o i d ( W x g ⋅ x t + W h g ⋅ h t − 1 + W c g ⋅ c t − 1 + b ) g= sigmoid(W_{xg} \cdot x_t + W_{hg} \cdot h_{t-1} +W_{cg} \cdot c_{t-1}+ {b}) g=sigmoid(Wxg​⋅xt​+Whg​⋅ht−1​+Wcg​⋅ct−1​+b):
这种gate的输入有当前的输入 x t x_t xt​ 和上一时刻的隐藏状态 h t − 1 h_{t-1} ht−1​,以及上一时刻的cell状态 c t − 1 c_{t-1} ct−1​, 表示gate是将这三个信息流作为控制依据而产生输出的。这种方式的LSTM叫做peephole connections

上面两幅图很清晰,就是在用细胞状态(就是图里面的 C t C t − 1 C_t C_{t-1} Ct​Ct−1​,也就是最上面的额那条信息流)的时候有些区别。(leader说第二种的这种结构效果会比之前博客中写到的简单LSTM效果要好,我并没有试过。。。)

在下面的讲解中用到的就是第二种这个,来对比参数等一些区别。

为什么要用projection layer

首先在LSTM中的Projection layer是为了减少计算量的,它的作用和全连接layer很像,就是对输出向量做一下压缩,从而能把高纬度的信息降维,减小cell unit的维度,从而减小相关参数矩阵的参数数目!
一个很好的解释,What is the meaning of ‘projection layer’ in lstm?

传统LSTM

如上面所列出的一样,传统的LSTM的结构为:
i t = δ ( W i x x t + W i m m t − 1 + W i c c t − 1 + b i ) i_t=\delta(W_{ix}x_t+W_{im}m_{t-1}+W_{ic}c_{t-1}+b_i) it​=δ(Wix​xt​+Wim​mt−1​+Wic​ct−1​+bi​)

f t = δ ( W f x x t + W f m m t − 1 + W f c c t − 1 + b i ) f_t=\delta(W_{fx}x_t+W_{fm}m_{t-1}+W_{fc}c_{t-1}+b_i) ft​=δ(Wfx​xt​+Wfm​mt−1​+Wfc​ct−1​+bi​)

c t = f t ⊙ c t − 1 + i t ⊙ g ( W c x x t + W c m m t − 1 + b c ) c_t=f_t\odot c_{t-1}+i_t\odot g(W_{cx}x_t+W_{cm}m_{t-1}+b_c) ct​=ft​⊙ct−1​+it​⊙g(Wcx​xt​+Wcm​mt−1​+bc​)

o t = δ ( W o x x t + W o m m t − 1 + W o c c t + b o ) o_t=\delta(W_{ox}x_t+W_{om}m_{t-1}+W_{oc}c_{t}+b_o) ot​=δ(Wox​xt​+Wom​mt−1​+Woc​ct​+bo​)

m t = o t ⊙ h ( c t ) m_t=o_t\odot h(c_t) mt​=ot​⊙h(ct​)

y t = ϕ ( W y m m t + b y ) y_t=\phi (W_{ym}m_t+b_y) yt​=ϕ(Wym​mt​+by​)
最后的 y t y_t yt​是输出,公式里面的所有 m t m_t mt​表示的是图中的 h t h_t ht​,其他所有的m变成h即可,只不过写法不同而已!
那么如果不计算里面的bias(也就是 b i b c b o b_i b_c b_o bi​bc​bo​),那么最终的参数数目是:
W = n c ∗ n c ∗ 4 + n i ∗ n c ∗ 4 + n c ∗ n o + n c ∗ 3 W=n_c*n_c*4+n_i*n_c*4+n_c*n_o+n_c*3 W=nc​∗nc​∗4+ni​∗nc​∗4+nc​∗no​+nc​∗3
其中 n c n_c nc​表示cell units的大小,也就是隐层的维度, n i n_i ni​是当前输入向量的维度, n o n_o no​表示的是最终 m t m_t mt​得到的最终输出y的维度
n c ∗ n c ∗ 4 n_c*n_c*4 nc​∗nc​∗4表示的是 W i m W_{im} Wim​、 W f m W_{fm} Wfm​、 W c m W_{cm} Wcm​、 W o m W_{om} Wom​的参数个数
n i ∗ n c ∗ 4 n_i*n_c*4 ni​∗nc​∗4表示的是 W i x W_{ix} Wix​、 W f x W_{fx} Wfx​、 W c x W_{cx} Wcx​、 W o x W_{ox} Wox​的参数个数
n c ∗ n o n_c*n_o nc​∗no​表示的是 W y m W_{ym} Wym​输出的时候的全连接层的参数个数
n c ∗ 3 n_c*3 nc​∗3表示的是 W i c W_{ic} Wic​、 W f c W_{fc} Wfc​、 W o c W_{oc} Woc​这几个对角矩阵的参数个数

LSTM projection layer 结构(LSTMP)

绿色的地方就是projection的地方,也就是在传递时 r t r_t rt​,也就是cell里面的 h t h_t ht​都换成了 r t r_t rt​,因此如果 r t r_t rt​的维度如果比 h t h_t ht​小的话,里面的整个运算的变量就小了,速度也就快了。

而加入projection layer改进后的公式结构如下:
i t = δ ( W i x x t + W i r r t − 1 + W i c c t − 1 + b i ) i_t=\delta(W_{ix}x_t+W_{ir}r_{t-1}+W_{ic}c_{t-1}+b_i) it​=δ(Wix​xt​+Wir​rt−1​+Wic​ct−1​+bi​)

f t = δ ( W f x x t + W f r r t − 1 + W f c c t − 1 + b i ) f_t=\delta(W_{fx}x_t+W_{fr}r_{t-1}+W_{fc}c_{t-1}+b_i) ft​=δ(Wfx​xt​+Wfr​rt−1​+Wfc​ct−1​+bi​)

c t = f t ⊙ c t − 1 + i t ⊙ g ( W c x x t + W c r r t − 1 + b c ) c_t=f_t\odot c_{t-1}+i_t\odot g(W_{cx}x_t+W_{cr}r_{t-1}+b_c) ct​=ft​⊙ct−1​+it​⊙g(Wcx​xt​+Wcr​rt−1​+bc​)

o t = δ ( W o x x t + W o r r t − 1 + W o c c t + b o ) o_t=\delta(W_{ox}x_t+W_{or}r_{t-1}+W_{oc}c_{t}+b_o) ot​=δ(Wox​xt​+Wor​rt−1​+Woc​ct​+bo​)

m t = o t ⊙ h ( c t ) m_t=o_t\odot h(c_t) mt​=ot​⊙h(ct​)

r t = W r m m t r_t=W_{rm}m_t rt​=Wrm​mt​

y t = ϕ ( W y r r t + b y ) y_t=\phi (W_{yr}r_t+b_y) yt​=ϕ(Wyr​rt​+by​)
这里的参数个数发生了变化,参数为:
W = n c ∗ n r ∗ 4 + n i ∗ n c ∗ 4 + n r ∗ n o + n c ∗ n r + n c ∗ 3 W=n_c*n_r*4+n_i*n_c*4+n_r*n_o+n_c*n_r+n_c*3 W=nc​∗nr​∗4+ni​∗nc​∗4+nr​∗no​+nc​∗nr​+nc​∗3
其中 n c n_c nc​表示cell units的大小,也就是隐层的维度, n i n_i ni​是当前输入向量的维度, n o n_o no​表示的是最终 m t m_t mt​得到的最终输出y的维度, n r n_r nr​表示的是projection layer的输出维度。
n c ∗ n r ∗ 4 n_c*n_r*4 nc​∗nr​∗4表示的是 W i m W_{im} Wim​、 W f m W_{fm} Wfm​、 W c m W_{cm} Wcm​、 W o m W_{om} Wom​的参数个数
n i ∗ n c ∗ 4 n_i*n_c*4 ni​∗nc​∗4表示的是 W i x W_{ix} Wix​、 W f x W_{fx} Wfx​、 W c x W_{cx} Wcx​、 W o x W_{ox} Wox​的参数个数
n c ∗ n o n_c*n_o nc​∗no​表示的是 W y m W_{ym} Wym​输出的时候的全连接层的参数个数
n c ∗ n r n_c*n_r nc​∗nr​表示的是projection layer的参数矩阵
n c ∗ 3 n_c*3 nc​∗3表示的是 W i c W_{ic} Wic​、 W f c W_{fc} Wfc​、 W o c W_{oc} Woc​这几个对角矩阵的参数个数

所以最终在举个例子,假设我们的cell units的大小是256,输入的维数是100,输出的维数是30,那么传统的LSTM参数个数是:
W = n c ∗ n c ∗ 4 + n i ∗ n c ∗ 4 + n c ∗ n o + n c ∗ 3 = 256 ∗ 256 ∗ 4 + 100 ∗ 256 ∗ 4 + 256 ∗ 30 + 256 ∗ 3 = 256 ∗ 1457 W=n_c*n_c*4+n_i*n_c*4+n_c*n_o+n_c*3 = 256*256*4+100*256*4+256*30+256*3 = 256*1457 W=nc​∗nc​∗4+ni​∗nc​∗4+nc​∗no​+nc​∗3=256∗256∗4+100∗256∗4+256∗30+256∗3=256∗1457
但是加入projection layer之后,假设输出维度为128,的参数个数是:
W = n c ∗ n r ∗ 4 + n i ∗ n c ∗ 4 + n r ∗ n o + n c ∗ n r + n c ∗ 3 = 256 ∗ 128 ∗ 4 + 100 ∗ 256 ∗ 4 + 128 ∗ 30 + 256 ∗ 128 + 256 ∗ 3 W=n_c*n_r*4+n_i*n_c*4+n_r*n_o+n_c*n_r+n_c*3 = 256*128*4 + 100*256*4+128*30+256*128+256*3 W=nc​∗nr​∗4+ni​∗nc​∗4+nr​∗no​+nc​∗nr​+nc​∗3=256∗128∗4+100∗256∗4+128∗30+256∗128+256∗3
所以减少的个数可以很容易看出来了。
相关细节未完待续!

参考网址

  • lstm(三) 模型压缩lstmp
  • Long Short-Term Memory Recurrent Neural Network Architectures for Large Scale Acoustic Modeling
  • long short-term memory based recurrent neural network architectures for large vocabulary speech recognition
  • What is the meaning of ‘projection layer’ in lstm?

LSTM trick之LSTMP相关推荐

  1. 干货丨各种机器学习任务的顶级结果(论文)汇总

    小编在 Github 上发现了一个良心项目:RedditSota 统计了各种机器学习任务的最顶级研究成果(论文),方便大家索引查阅. 项目地址:https://github.com//RedditSo ...

  2. 文本分类模型的训练、调优、蒸馏

    目前已经完成的事情 目前有12个模型. 其中BERT的七个变种模型,因为使用的是BERT-base-chinese的预训练模型和词表,所以不要自己做embedding,换上数据集和标签就可以跑 Ber ...

  3. 【deep_thoughts】30_PyTorch LSTM和LSTMP的原理及其手写复现

    文章目录 LSTM API 手写 lstm_forward 函数 LSTMP 修改 lstm_forward 函数 视频链接: 30.PyTorch LSTM和LSTMP的原理及其手写复现_哔哩哔哩_ ...

  4. 【串讲总结】RNN、LSTM、GRU、ConvLSTM、ConvGRU、ST-LSTM

    前言 平时很少写总结性的文章,感觉还是需要阶段性总结一些可以串在一起的知识点,所以这次写了下.因为我写的内容主要在时序.时空预测这个方向,所以主要还是把rnn,lstm,gru,convlstm,co ...

  5. 工业界如何解决NER问题?12个trick,与你分享~

    NER是一个已经解决了的问题吗?或许,一切才刚刚开始. 例如,面对下面笔者在工作中遇到的12个关于NER的系列问题,你有什么好的trick呢?不着急,让我们通过本篇文章,逐一解答- Q1.如何快速有效 ...

  6. 有哪些LSTM(Long Short Term Memory)和RNN(Recurrent)网络的教程?

    知乎用户,阿里巴巴数据应用部门长期招聘「算法,分- 500 人赞同 刚好毕设相关,论文写完顺手就答了 先给出一个最快的了解+上手的教程: 直接看theano官网的LSTM教程+代码:LSTM Netw ...

  7. 【图文并茂】RNN、LSTM、GRU、ConvLSTM、ConvGRU、ST-LSTM的总结

    前言 平时很少写总结性的文章,感觉还是需要阶段性总结一些可以串在一起的知识点,所以这次写了下.因为我写的内容主要在时序.时空预测这个方向,所以主要还是把rnn,lstm,gru,convlstm,co ...

  8. lstm 文本纠错_工业界纠错系统

    本篇文章,主要来唠嗑下工业界的纠错系统怎么设计?包括:基于规则的纠错系统(简单的英文纠错.复杂的中文纠错).基于NN的纠错系统.当然,在成熟的工业界纠错系统中,最好是结合规则&&NN方 ...

  9. 深度学习(二十三)——Fast Image Processing, SVDF, LCNN, LSTM进阶

    https://antkillerfarm.github.io/ Fast Image Processing(续) 这个课题一般使用MIT-Adobe FiveK Dataset作为基准数据集.网址: ...

最新文章

  1. 基于Android5.0的Camera Framework源码分析 (三)
  2. [OI]Noip 2018 题解总结(普及)
  3. Docker 下安装 Spark
  4. WebAPI(part9)--下拉菜单及留言案例
  5. Attachment assignment block里选择的文件是如何传到Netweaver服务器的
  6. matlab gui教程 计算器,matlab gui编写的计算器程序
  7. Freebsd屏幕字体的调节
  8. 谷粒商城:04. 逆向工程完善微服务系统
  9. c语言求婚代码大全,求一个C语言表白的代码
  10. mysql 查询条件为空则_MySql当查询条件为空时不作为条件查询
  11. 被割裂的数据思维(古代战争中的应用)
  12. 计算机视觉作业(三)Scene Recognition with Bag of Words
  13. 计算机软件企业申请商标,软件商标注册申请流程
  14. 初学者-CSS思维导图(上)
  15. 黑马JavaWeb全功能综合案例(element-ui+mybatis+Vue+ajax)
  16. 网络安全架构部署:Fail Closed,Fail Open,Fail safe,Fail over是什么意思?
  17. 练习-编写求阶乘函数
  18. 点击DIV显示改变边框颜色
  19. 图片的混合空间增强操作Opencv-python实现
  20. lowCode与D2C

热门文章

  1. 代码编辑器--5.21
  2. 【转】ATF中SMC深入理解
  3. 通过伴随矩阵怎么求逆矩阵
  4. Java 添加、替换、删除PDF中的图片
  5. LTE/EPC中,MME怎么找到UE的HSS的?
  6. 单元格等于计算机日期,Excel相邻单元格快速填入相同日期的几种方法
  7. packetdrill 深入理解内核网络协议栈的工具集
  8. 数通基础-TCPIP参考模型
  9. 论文笔记(二):基于卷积神经网络的高分辨率遥感图像上的水体识别技术
  10. html 如何去掉超链接下的下划线