目录

一、整体架构

1.定义CRF类,初始化相关参数

2.定义forward函数

3.forword调用的函数:_validate

4.forward调用的函数:_conputer_score

5.forward调用的函数:_compute_normalizer

6.forward调用的函数:_viterbi_decode

7.外部调用的函数:decode


该项目中,使用BERT+CRF进行NER任务,因此先构造CRF模型。具体实现过程中需要注意的细节已在代码中包含。

一、整体架构

通过bert生成序列之后(其他的模型比如LSTM什么的也一样,都会生成一个预测序列),我们得到了形状是(batch_size, sentence_length, number_of_tags)的结果,也就是,对每一句话,的每一个字,有number_of_tags这么多的预测结果。假如我们的实体类型有三个"B", "I", "O",一个batch有32句话,一句话被统一成了64个单词,那么生成的结果就是(32, 64, 3)。注意这里的batch_size和sentence_length的位置,可能会由于代码的不同,调换顺序。

生成的结果就是我们的发射分数。要计算损失,我们还需要计算发射分数中,正确路径对应的分数;以及发射分数中,所有路径合起来的分数。

同时,我们还需要对所有路径合起来的分数进行处理,由于计算损失的时候,会让这个总分作为分母,因此,采用的是先取exp(),再求和sum(),再取对数log(),而这个运算只需要pytorch的一行代码即可完成:torch.logsumexp()。

最后,我们还希望得到一条最佳路径,于是需要维特比解码得到。

因此,在这个类中,我们需要定义不同的函数来实现不同的功能:

1. __init__必须定义,初始化参数

2.forward必须定义,前向传播,得到损失值。这里面会调用其他函数,用于计算损失。

3.计算正确路径分数的函数

4.计算所有路径总分的函数

5.维特比解码函数

6.能让外接调用,得到最佳路径的函数

注意:下面所有函数,都在CRF类里面,这里以分段的形式记录。

1.定义CRF类,初始化相关参数

class CRF(nn.Module):def __init__(self, num_tags : int = 2, batch_first : bool = True) -> None:super(CRF, self).__init__()self.num_tags = num_tagsself.batch_first = batch_first# start到其他(不含end)的得分self.start_transitions = nn.Parameter(torch.empty(num_tags))# 其他(不含start)到end的得分self.end_transitions = nn.Parameter(torch.empty(num_tags))# 转移分数矩阵self.transitions = nn.Parameter(torch.empty((num_tags, num_tags)))self.reset_parameters()def reset_parameters(self):'''将初始化的分数限定在-0.1到0.1之间'''init_range = 0.1nn.init.uniform_(self.start_transitions, -init_range, init_range)nn.init.uniform_(self.end_transitions, -init_range, init_range)nn.init.uniform_(self.transitions, -init_range, init_range)

2.定义forward函数

forward函数所需要的其他函数,后面补充。通过forward函数之后,返回的是我们所需要的损失值。

    def forward(self, emissions: torch.Tensor,tags: torch.Tensor = None,mask: Optional[torch.ByteTensor] = None,reduction: str = 'mean') -> torch.Tensor:self._validate(emissions, tags=tags, mask=mask)# reduction:损失值模式,是均值还是求和作为损失reduction = reduction.lower()if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f"invalid reduction {reduction}")if mask is None:mask = torch.ones_like(tags, dtype=torch.uint8)if self.batch_first:# 发射分数形状:(seq_length, batch_size, tag_num)emissions = emissions.transpose(0, 1)tags = tags.transpose(0, 1)mask = mask.transpose(0, 1)# 计算正确标签序列的发射分数和转移分数之和, shape: (batch_size, )numerator = self._cumputer_score(emissions=emissions, tags=tags, mask=mask)# 计算所有序列发射分数和转移分数之和, shape: (batch_size, )denominator = self._compute_normalizer(emissions=emissions, mask=mask)# 二者相减, shape: (batch_size, )llh = denominator - numerator# 根据不同的设定返回不同形式的分数if reduction == 'none':return llhif reduction == 'sum':return llh.sum()if reduction == 'mean':return llh.mean()if reduction == 'token_mean':return llh.sum() / mask.float().sum()

3.forword调用的函数:_validate

主要是用来确保所有输入数据的维度应该是我们所要求的维度。

    def _validate(self, emissions: torch.Tensor,tags: Optional[torch.LongTensor] = None,mask: Optional[torch.ByteTensor] = None) -> None:if emissions.dim() != 3:raise ValueError(f"emissions must have dimension of 3, got{emissions.dim()}")if emissions.size(2) != self.num_tags:raise ValueError(f"expected last dimission of emission is {self.num_tags},"f"got {emissions.size(2)}")if tags is not None:if emissions.shape[:2] != mask.shape:raise ValueError(f"the first two dimensions of mask and emissions must match"f"got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}")no_empty_seq = not self.batch_first and mask[0].all()no_empty_seq_bf = self.batch_first and mask[:, 0].all()if not no_empty_seq and not no_empty_seq_bf:raise ValueError('mask of the first timestep must all be on.')

4.forward调用的函数:_computer_score

该函数用来计算最佳路径的分数。

    def _computer_score(self, emissions: torch.Tensor,tags: torch.LongTensor,mask: torch.ByteTensor) -> torch.Tensor:# batch secondassert emissions.dim() == 3 and tags.dim() == 2assert emissions.shape[:2] == tags.shapeassert emissions.size(2) == self.num_tagsassert mask.shape == tags.shape# 每个mask,开头一定是1,否则相当于句子就没了。assert mask[0].all()seq_length, batch_size = tags.shapemask = mask.float()# start,转移到其他所有标签的分数,不包含end# 根据实际的tag的开头的词,得到从start到每句话开头的类型的分数。# 这里是start到第一个词的转移分数,shape: (batch_size,)score = self.start_transitions[tags[0]]# 接下来是预测的每句话的开头应当是什么tag,如果有3个tag,那么每个词都会有对应的三个分数,分别对应每一个tag# 但是我们实际的tag是在tags[0]里面的,而预测的三个值,分数不一定是多少# 比如实际的第一个词tag是B,预测的BIO的三个分数分别为:(0.1,0.5,04)# 那么我们把0.1这个分数加上。这个就是发射分数,也就是预测的分数。score += emissions[0, torch.arange(batch_size), tags[0]]# 至此,我们完成了从start转移到第一个词的转移分数+发射分数# 接下来是每个词到下一个词的转移分数+发射分数,全加到一块for i in range(1, seq_length):# 转移分数score += self.transitions[tags[i-1], tags[i]] * mask[i]# 发射分数score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]# 取到最后一个词的tag# 使用的mask是形如:[1,1,1,1,0,0,0],后面的0是padding的,因此没字了# 因此通过下面的方式,取到1的和,减去1,就是最后一个词的索引了。seq_end = mask.long().sum(dim=0) - 1last_tag = tags[seq_end, torch.arange(batch_size)]# 最后一个词转移到end的分数score += self.end_transitions[last_tag]return score

5.forward调用的函数:_compute_normalizer

这里计算所有路径的分数之和。并取一个logsumexp

    def _compute_normalizer(self, emissions: torch.Tensor,mask: torch.ByteTensor) -> torch.Tensor:# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)assert emissions.dim() == 3 and mask.dim() == 2assert emissions.shape[:2] == mask.shapeassert emissions.size(2) == self.num_tagsassert mask[0].all()seq_length = emissions.size(0)# emissions[0],因为第一个维度是句子长度,因此emissions[0]就是每一个句子的开头的词,对应的发射分数# 并且每一个分数是有num_tags这么多。因此emissions[0]就是对所有开头的词,对每一个标签预测的分数。# 再加上start标志到每一个标签的分数,就是一个整体的开头分数之和。score = self.start_transitions + emissions[0]# 接下来把所有的转移分数,发射分数全部加起来。for i in range(1, seq_length):# 原来是(batch_size, num_tags), 现在是(batch_size, num_tags, 1)broadcast_score = score.unsqueeze(dim=2)# 对于第i个词,原来是(batch_size, num_tags), 现在是(batch_size, 1, num_tags)broadcast_emission = emissions[i].unsqueeze(1)# 先把开头的分数和转移矩阵加起来,便得到了开头的每一个tag,转移到其他每一个tag的概率# 再把发射矩阵加上,便得到了该单词的总分数。其中会自动使用broad cast机制next_score = broadcast_score + self.transitions + broadcast_emission# 对总分数,在第二个维度求一个对数域的分数。第二个维度,也就是转移矩阵的行# 我们求的是所有路径的总分数,要对这个分数求和。# 假如对第二个词来说,可能由第一个词的num_tags那么多的可能性过来,那么就把所有的可能性加起来# 这样得到的就是对于第二个词来说的总分数。因此,把转移矩阵的行,也就是前一个词可能的tag,全部加起来即可# 也就是在第二个维度上求和。这样就得到了总分数,我们对这个总分数进行对数域计算即可(取e,求和,取对数)。next_score = torch.logsumexp(next_score, dim=1)# 通过mask,如果对应的单词位置有值,也就是我们需要这个分数,那么就使用next_score# 如果对应的位置没值,那么这个分数不需要加上,就取原来的scorescore = torch.where(mask[i].unsqueeze(1), next_score, score)# 最后把单词转移到end的分数加上score += self.end_transitions# 返回值取对数域的值,把所有的词的分数再求和一遍return torch.logsumexp(score, dim=1)

6.forward调用的函数:_viterbi_decode

维特比解码,得到最佳路径。

    def _viterbi_decode(self, emissions: torch.FloatTensor,mask: torch.ByteTensor) -> List[List[int]]:assert emissions.dim() == 3 and mask.dim() == 2assert emissions.shape[:2] == mask.shapeassert emissions.size(2) == self.num_tagsassert mask[0].all()seq_length, batch_size = mask.shapescore = self.start_transitions + emissions[0]history = []for i in range(1, seq_length):broadcast_score = score.unsqueeze(2)broadcast_emission = emissions[i].unsqueeze(1)next_score = broadcast_score + self.transitions + broadcast_emission# 在第一个维度上面求最大,消掉第一个维度,那么剩下的就是"到下一个类型概率最大的那个类型"# 这个max返回值有2个,一个是求完最大值后的结果,形状是(B, tag_num),一个是每个最大值所在的索引# 两个返回结果形状一致# 选最好的转移分数next_score, indices = next_score.max(dim=1)score = torch.where(mask[i].unsqueeze(1), next_score, score)# 上一个词转移到这个词时,分数最高的那些值的索引history.append(indices)score += self.end_transitionsseq_ends = mask.long().sum(dim=0) - 1best_tags_list = []for idx in range(batch_size):# 取到分数最高的标签,就是最后一个词的标签的索引# 选最好的发射分数_, best_last_tag = score[idx].max(dim=0)best_tags = [best_last_tag.item()]# seq_ends存了每个句子序列的最后一个词的索引。for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(best_last_tag.item())best_tags.reverse()best_tags_list.append(best_tags)return best_tags_list

7.外部调用的函数:decode

该函数调用了上面的维特比解码,外部可通过model.decode调用,返回最佳路径。

    def decode(self, emissions: torch.Tensor,mask: Optional[torch.ByteTensor]=None) -> List[List[int]]:self._validate(emissions=emissions, mask=mask)if mask is None:mask = emissions.new_ones(emissions.shape[:2],dtype=torch.uint8)if self.batch_first:emissions = emissions.transpose(0, 1)mask = mask.transpose(0, 1)return self._viterbi_decode(emissions, mask)

KBQA-Bert学习记录-CRF模型相关推荐

  1. 数学建模学习记录——数学规划模型

    数学建模学习记录--数学规划模型 一.线性规划问题 MatLab中线性规划的标准型 MatLab中求解线性规划的命令 二.整数线性规划问题 三.非线性规划问题 MatLab中非线性规划的标准型 Mat ...

  2. seq2seq模型_Pytorch学习记录-Seq2Seq模型对比

    Pytorch学习记录-torchtext和Pytorch的实例4 0. PyTorch Seq2Seq项目介绍 在完成基本的torchtext之后,找到了这个教程,<基于Pytorch和tor ...

  3. DAB-Deformable-DETR代码学习记录之模型构建

    DAB-DETR的作者在Deformable-DETR基础上,将DAB-DETR的思想融入到了Deformable-DETR中,取得了不错的成绩.今天博主通过源码来学习下DAB-Deformable- ...

  4. 【美赛学习记录】模型

    美赛学习记录-2022年2月7日 代码! 线性回归 数据拟合 插值 最优化求极值 插值 ARIMA 复杂网络实验 模型验证 K-Fold Cross-validation k折交叉验证 [基础模型] ...

  5. DAB-Deformable-DETR源码学习记录之模型构建(二)

    书接上回,上篇博客中我们学习到了Encoder模块,接下来我们来学习Decoder模块其代码是如何实现的. 其实Deformable-DETR最大的创新在于其提出了可变形注意力模型以及多尺度融合模块: ...

  6. 学习记录——Pytorch模型移植Android小例子

    提示:注意文章时效性,2022.04.02. 目录 前言 零.使用的环境 一.模型准备 1.导出模型 2.错误记录 2.1要载入完整模型(网络结构+权重参数) 2.2导出的模型文件格式 二.Andro ...

  7. 采用 ALSTM 模型的温度和降雨关联预测研究论文学习记录

    为了准确和及时预测局部区域的降雨及温度,提出了一种基于 Attention 和 LSTM 组合模型( ALSTM) 的关联多值预测算法.该算法利用天气时间序列中 的前期数据,对下一小时的降雨量和温度进 ...

  8. [Django]模型学习记录篇--基础

    模型学习记录篇,仅仅自己学习时做的记录!!! 实现模型变更的三个步骤: 修改你的模型(在models.py文件中). 运行python manage.py makemigrations ,为这些修改创 ...

  9. 西瓜书学习记录-模型评估与选择(第二章)

    西瓜书学习记录-模型评估与选择 第二章啦 整个过程可以描述为在训练集上去训练,在验证集上去调参,调完参之后再到训练集上去训练,直到结果满意,最后到测试集上去测试. 例子(反例): 上图选择蓝色的线,坏 ...

  10. Unity学习记录——模型与动画

    Unity学习记录--模型与动画 前言 ​ 本文是中山大学软件工程学院2020级3d游戏编程与设计的作业7 编程题:智能巡逻兵 1.学习参考 ​ 除去老师在课堂上讲的内容,本次作业代码与操作主要参考了 ...

最新文章

  1. 找java培训机构如何挑选
  2. 【渝粤题库】广东开放大学 文化活动策划与组织 形成性考核
  3. java 2d划线 刷子_月光软件站 - 编程文档 - Java - Java图形设计中,利用Bresenham算法实现直线线型,线宽的控制(NO 2D GRAPHICS)...
  4. MySQL学习笔记17:别名
  5. 使用TensorFlow.js在浏览器中进行深度学习入门
  6. java http data chunk_HTTP协议之Chunked解析
  7. 即将上线的Kafka 集群(用CM部署的)无法使用“--bootstrap-server”进行消费,怎么破?...
  8. 支付1000元咨询费,如何让PB编写的程序不能被反编译?
  9. java proj4j 兰勃特投影设置地球半径 (+R )无效问题
  10. APUE 第四章总结
  11. 8 随机积分与随机微分方程
  12. imagine php,使用imagine/imagine实现制作一个图片
  13. 玩客云刷Armbian详细教程
  14. 【随笔杂记】电脑断电自启+远程控制自启
  15. Sass 3 的环境搭建及开发
  16. echarts-特殊需求
  17. Licode—基于webrtc的SFU/MCU实现
  18. matplotlib 进阶之Artist tutorial(如何操作Atrist和定制)
  19. 什么是LIDAR(激光雷达),如何标注激光点云数据?
  20. 今年十月最新语言排行榜

热门文章

  1. ppt讲解中的过渡_ppt过渡页的设计技巧
  2. 小说关于计算机名称,小说取名和人名取名太纠结了,感觉橙瓜码字的自动取名还不错...
  3. python内嵌浏览器_内嵌web浏览器
  4. python pycharm anaconda需要都下载吗_Anaconda下载与安装、PyCharm下载与安装
  5. 三菱plc pwm指令_三菱PLC必会编程指令汇总,收藏这些就够了!
  6. IDEA 当前项目jdk版本查看
  7. 基于单片机USB接口的温度控制器
  8. 线性系统和非线性系统
  9. Android四大组件的作用
  10. 微信小程序demo汇总