1.前置小知识

1)log_sum_exp

这个是升级版的softmax,防止出现上溢或下溢,详见关于LogSumExp - 知乎

# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):max_score = vec[0, argmax(vec)]max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])return max_score + \torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

2)维特比算法

这个是已知发射矩阵和转移矩阵,算经过一个图最大得分的,主要就是每两步就删一些必不可能得到最大分的路经,如图

都是到达第二列A点的(第一列是之前的总分)这五条路必有一条最高分,那么我们只取最高分的到达A的路经就好了,再加上A节点本身的分数,作为新的A节点的总分,BCDE四点也如法炮制,于是第二列的总分就算出来了,再找到它们到第三列一点的最大分数,一次类推,于是算整条路经的最大得分时,只要做一个循环就够了,复杂度大大降低,这里有个小细节,因为句子总会从<start>开始,所以最开始的第一列总分我们除把<start>那个点设为0外其余都设一个很大的负数,来保证第一列引到第二列的最短路径都是从<start>开始的

 def _viterbi_decode(self, feats):backpointers = []# Initialize the viterbi variables in log spaceinit_vvars = torch.full((1, self.tagset_size), -10000.)init_vvars[0][self.tag_to_ix[START_TAG]] = 0# forward_var at step i holds the viterbi variables for step i-1forward_var = init_vvarsfor feat in feats:bptrs_t = []  # holds the backpointers for this stepviterbivars_t = []  # holds the viterbi variables for this stepfor next_tag in range(self.tagset_size):# next_tag_var[i] holds the viterbi variable for tag i at the# previous step, plus the score of transitioning# from tag i to next_tag.# We don't include the emission scores here because the max# does not depend on them (we add them in below)next_tag_var = forward_var + self.transitions[next_tag]best_tag_id = argmax(next_tag_var)bptrs_t.append(best_tag_id)viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))# Now add in the emission scores, and assign forward_var to the set# of viterbi variables we just computedforward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)backpointers.append(bptrs_t)# Transition to STOP_TAGterminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]best_tag_id = argmax(terminal_var)path_score = terminal_var[0][best_tag_id]# Follow the back pointers to decode the best path.best_path = [best_tag_id]for bptrs_t in reversed(backpointers):best_tag_id = bptrs_t[best_tag_id]best_path.append(best_tag_id)# Pop off the start tag (we dont want to return that to the caller)start = best_path.pop()assert start == self.tag_to_ix[START_TAG]  # Sanity checkbest_path.reverse()return path_score, best_path

3)我们要干啥

给定一个句子,找到概率最大的label组合,如图

2 正文

BiLSTM_CRF模型共分两个部分,一部分是BiLSTM部分,一部分是CRF部分

先从简单的BiLSTM部分开始吧,与之有关的主要代码如下

在初始化中,建立一个正常的BiLstm模型,输出的大小为target(标签)的种类多少

而将每一个字的输出都连起来,拼成一个矩阵,就是所谓的发射矩阵,如图,代码如下:

    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):super(BiLSTM_CRF, self).__init__()self.embedding_dim = embedding_dimself.hidden_dim = hidden_dimself.vocab_size = vocab_sizeself.tag_to_ix = tag_to_ixself.tagset_size = len(tag_to_ix)self.word_embeds = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,num_layers=1, bidirectional=True)# Maps the output of the LSTM into tag space.self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
    def _get_lstm_features(self, sentence):#这个函数已经算出状态方程(发射矩阵了),就是feats(len(sentence),tag_num)self.hidden = self.init_hidden()embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)lstm_out, self.hidden = self.lstm(embeds, self.hidden)lstm_out = lstm_out.view(len(sentence), self.hidden_dim)lstm_feats = self.hidden2tag(lstm_out)return lstm_feats

而BiLSTM模型需要给出根据X预测的输出Y的序列进行更新,即需要前文讲到的维特比算法进行更新,什么,你问正常的BiLSTM算法不是由softmax前向更新吗,这当然是因为BiLSTM不再是独立的了,而是隐藏在CRF之下,也就是它有权利给一个在CRF看起来不可能产生的单词一个很大的权重,比如跳跳真可爱但看跳这个字判为O的可能性很大,LSTM也有权利这样做,反正最后维特比路经不会选它而已,如果只是分立的用BiLSTM训练,那么就有可能告诉模型跳是B(句子标注)干扰LSTM层,代码如下

    def forward(self, sentence):  # dont confuse this with _forward_alg above.# Get the emission scores from the BiLSTMlstm_feats = self._get_lstm_features(sentence)# Find the best path, given the features.score, tag_seq = self._viterbi_decode(lstm_feats)return score, tag_seq

现在我们有了发射矩阵(x--yi关系),其实就相当于把CRF部分的X部分需要的数据准备好了,而我们还差一个转移矩阵,即(yi-1--yi)的关系,其建立代码如下

        #转移矩阵,从一个标签转移到到另一个标签的概率self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))# 行(第一个参数是被转移到的),列(第二个参数是向外转移的),开始不可能被任何标签转移,结束不可能转移到任何标签self.transitions.data[tag_to_ix[START_TAG], :] = -10000self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000

是的,你没有看错,它是先随机建立的,但要规定两个参数,即<start>标签永远不可能被转移到,<stop>标签永远不可能转移到其他标签,人工给他们设低,如图

好了,现在我们两个矩阵都有了,那怎么计算最大的p(y/x)呢,即CRF模型前向计算,代码如下

 def _forward_alg(self, feats):# Do the forward algorithm to compute the partition functioninit_alphas = torch.full((1, self.tagset_size), -10000.)#([[-10000,-10000,-10000,-10000,-10000]])# START_TAG has all of the score.init_alphas[0][self.tag_to_ix[START_TAG]] = 0.#([0,-10000,-10000,-10000,-10000])#总的来说就是让不从start开始的分数低# Wrap in a variable so that we will get automatic backpropforward_var = init_alphas# Iterate through the sentencefor feat in feats:alphas_t = []  # The forward tensors at this timestepfor next_tag in range(self.tagset_size):# broadcast the emission score: it is the same regardless of# the previous tagemit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size)# the ith entry of trans_score is the score of transitioning to# next_tag from i计算发射分数trans_score = self.transitions[next_tag].view(1, -1)#看next_tag被转移到# The ith entry of next_tag_var is the value for the# edge (i -> next_tag) before we do log-sum-expnext_tag_var = forward_var + trans_score + emit_score# The forward variable for this tag is log-sum-exp of all the# scores.alphas_t.append(log_sum_exp(next_tag_var).view(1))forward_var = torch.cat(alphas_t).view(1, -1)terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]alpha = log_sum_exp(terminal_var)return alpha

至于这段代码的原理,前辈已经解释的很清楚啦,详见下面学习日志中的已知模型求边缘概率[学习日志]白板推导-条件随机场 CRF Conditional Random Field_烫烫烫烫的若愚的博客-CSDN博客h

好了,现在就只剩损失函数没求了,作者将它定义为按标注的tag计算的p(y/x)与算得的p(y/x)的差

代码如图

   def neg_log_likelihood(self, sentence, tags):feats = self._get_lstm_features(sentence)forward_score = self._forward_alg(feats)gold_score = self._score_sentence(feats, tags)return forward_score - gold_score

写在后面:深度学习小白一只,读了很多关于CRF的介绍,都觉得公式很多,萌新看的有点吃力,故废了大力气自认些许明白以后想写一篇文章节省后来人的时间,因为全靠自己理解,必有差错之处,希望路过的大佬指正!

Advanced: Making Dynamic Decisions and the Bi-LSTM CRF(对官网发的代码的一些理解~,来自看不懂英文也没学过前置课程的小白两周的心血)相关推荐

  1. 3.1 TMO MATLAB 框架(Advanced High Dynamic Range Imaging )

    3.1 TMO MATLAB 框架(Advanced High Dynamic Range Imaging ) 通常,无论属于哪一类TMO,都有两个共同的步骤. 本节描述了大多数但不是全部TMO使用的 ...

  2. 3.2.3 Quantization Techniques(HDR量化)(Advanced High Dynamic Range Imaging)Schlick TMO

    3.2.3 Quantization Techniques(HDR量化)(Advanced High Dynamic Range Imaging)Schlick TMO Schlick [341]提出 ...

  3. 2.1.2 Capturing HDR Videos(Advanced High Dynamic Range Imaging )

    2.1.2 Capturing HDR Videos(Advanced High Dynamic Range Imaging ) 目录 2.1.2 Capturing HDR Videos(Advan ...

  4. pytorch lstm crf 代码理解 重点

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  5. pytorch lstm crf 代码理解

    好久没有写博客了,这一次就将最近看的pytorch 教程中的lstm+crf的一些心得与困惑记录下来. 原文 PyTorch Tutorials 参考了很多其他大神的博客,https://blog.c ...

  6. lstm代码_贼好理解,这个项目教你如何用百行代码搞定各类NLP模型

    机器之心报道 参与:思源.贾伟 NLP 的研究,从词嵌入到 CNN,再到 RNN,再到 Attention,以及现在正红火的 Transformer,模型已有很多,代码库也成千上万.对于初学者如何把握 ...

  7. 谷歌大脑科学家亲解 LSTM:一个关于“遗忘”与“记忆”的故事 本文作者:奕欣 2017-01-14 09:46 导语:AI科技评论保证这是相对通俗易懂的一篇入门介绍了,看不懂的话欢迎关注「AI 科技

    谷歌大脑科学家亲解 LSTM:一个关于"遗忘"与"记忆"的故事 本文作者:奕欣 2017-01-14 09:46 导语:AI科技评论保证这是相对通俗易懂的一篇入 ...

  8. 知识图谱-LSTM+CRF人物关系抽取实战

    文章目录 一.引言 二.实践简介 1.数据来源 2.预测类别(7个) 3.框架 4.模型结构 5.项目流程 三.数据标注 四.实战 1.数据预处理 1.1 词典映射 1.2 从训练文件中获取句子和标签 ...

  9. 命名实体识别学习-用lstm+crf处理conll03数据集

    title: 命名实体识别学习-用lstm+crf处理conll03数据集 date: 2020-07-18 16:32:31 tags: 命名实体识别学习-用lstm+crf处理conll03数据集 ...

最新文章

  1. php 操作分表代码
  2. python学习笔记(十一)——正则表达式
  3. 成功解决ValueError: fill value must be in categories
  4. dump java 内存_Java如何dump对象的内存
  5. 空间数据引擎oracle_GIS 与Oracle 数据库空间数据格式的转换
  6. return和break的区别
  7. bootstrap 栅格系统实现类似table跨行
  8. 时间管理类APP的Demo版本
  9. php 零宽断言,正则表达式之零宽断言
  10. jQuery中的阻止默认行为
  11. 男孩应该懂的,女孩应该懂的
  12. python thread start_Python中Thread类的start和run方法的区别
  13. 学习笔记(01):10小时掌握区块链开发教程-2小时构建以太坊智能合约-1
  14. 艺术家艾·克里斯汀·麦克拉维·佩恩gp采访23
  15. 树莓派 电脑通过界面远程控制
  16. 计算机控制技术直流电机调速控制实验报告,pid直流电机转速控制实验报告(31页)-原创力文档...
  17. 中国护照含金量再上升,Qbao Network 教你玩转全球54个国家!(二)
  18. vue项目添加百度统计
  19. 瞎折腾篇:联想笔记本外扩GTX1060——刷BIOS
  20. HTML5淡黄色企业品牌专题网站模板

热门文章

  1. 物联网卡流量套餐更换方法
  2. Catalan数简介
  3. react生命周期整理
  4. disconf java_java disconf使用详解
  5. pywinauto 获取朋友圈并将名片分享给好友
  6. 资料下载——《车载SOA软件架构技术规范1.0》
  7. 第78句 2020年地球日:霍金留给世界的遗言比以往任何时候都更有意义
  8. 正确的提问方式,值得一看
  9. TexStudio使用教程
  10. 螺毗喃/螺唔嗓/六苯基双咪哇/水杨醛缩苯胺/周蔡靛兰类染料/偶氮/稠环芳香化合物/哗嗓/俘精酸配类/二芳基乙烯化合物