KBQA-Bert学习记录-CRF模型
目录
一、整体架构
1.定义CRF类,初始化相关参数
2.定义forward函数
3.forword调用的函数:_validate
4.forward调用的函数:_conputer_score
5.forward调用的函数:_compute_normalizer
6.forward调用的函数:_viterbi_decode
7.外部调用的函数:decode
该项目中,使用BERT+CRF进行NER任务,因此先构造CRF模型。具体实现过程中需要注意的细节已在代码中包含。
一、整体架构
通过bert生成序列之后(其他的模型比如LSTM什么的也一样,都会生成一个预测序列),我们得到了形状是(batch_size, sentence_length, number_of_tags)的结果,也就是,对每一句话,的每一个字,有number_of_tags这么多的预测结果。假如我们的实体类型有三个"B", "I", "O",一个batch有32句话,一句话被统一成了64个单词,那么生成的结果就是(32, 64, 3)。注意这里的batch_size和sentence_length的位置,可能会由于代码的不同,调换顺序。
生成的结果就是我们的发射分数。要计算损失,我们还需要计算发射分数中,正确路径对应的分数;以及发射分数中,所有路径合起来的分数。
同时,我们还需要对所有路径合起来的分数进行处理,由于计算损失的时候,会让这个总分作为分母,因此,采用的是先取exp(),再求和sum(),再取对数log(),而这个运算只需要pytorch的一行代码即可完成:torch.logsumexp()。
最后,我们还希望得到一条最佳路径,于是需要维特比解码得到。
因此,在这个类中,我们需要定义不同的函数来实现不同的功能:
1. __init__必须定义,初始化参数
2.forward必须定义,前向传播,得到损失值。这里面会调用其他函数,用于计算损失。
3.计算正确路径分数的函数
4.计算所有路径总分的函数
5.维特比解码函数
6.能让外接调用,得到最佳路径的函数
注意:下面所有函数,都在CRF类里面,这里以分段的形式记录。
1.定义CRF类,初始化相关参数
class CRF(nn.Module):def __init__(self, num_tags : int = 2, batch_first : bool = True) -> None:super(CRF, self).__init__()self.num_tags = num_tagsself.batch_first = batch_first# start到其他(不含end)的得分self.start_transitions = nn.Parameter(torch.empty(num_tags))# 其他(不含start)到end的得分self.end_transitions = nn.Parameter(torch.empty(num_tags))# 转移分数矩阵self.transitions = nn.Parameter(torch.empty((num_tags, num_tags)))self.reset_parameters()def reset_parameters(self):'''将初始化的分数限定在-0.1到0.1之间'''init_range = 0.1nn.init.uniform_(self.start_transitions, -init_range, init_range)nn.init.uniform_(self.end_transitions, -init_range, init_range)nn.init.uniform_(self.transitions, -init_range, init_range)
2.定义forward函数
forward函数所需要的其他函数,后面补充。通过forward函数之后,返回的是我们所需要的损失值。
def forward(self, emissions: torch.Tensor,tags: torch.Tensor = None,mask: Optional[torch.ByteTensor] = None,reduction: str = 'mean') -> torch.Tensor:self._validate(emissions, tags=tags, mask=mask)# reduction:损失值模式,是均值还是求和作为损失reduction = reduction.lower()if reduction not in ('none', 'sum', 'mean', 'token_mean'):raise ValueError(f"invalid reduction {reduction}")if mask is None:mask = torch.ones_like(tags, dtype=torch.uint8)if self.batch_first:# 发射分数形状:(seq_length, batch_size, tag_num)emissions = emissions.transpose(0, 1)tags = tags.transpose(0, 1)mask = mask.transpose(0, 1)# 计算正确标签序列的发射分数和转移分数之和, shape: (batch_size, )numerator = self._cumputer_score(emissions=emissions, tags=tags, mask=mask)# 计算所有序列发射分数和转移分数之和, shape: (batch_size, )denominator = self._compute_normalizer(emissions=emissions, mask=mask)# 二者相减, shape: (batch_size, )llh = denominator - numerator# 根据不同的设定返回不同形式的分数if reduction == 'none':return llhif reduction == 'sum':return llh.sum()if reduction == 'mean':return llh.mean()if reduction == 'token_mean':return llh.sum() / mask.float().sum()
3.forword调用的函数:_validate
主要是用来确保所有输入数据的维度应该是我们所要求的维度。
def _validate(self, emissions: torch.Tensor,tags: Optional[torch.LongTensor] = None,mask: Optional[torch.ByteTensor] = None) -> None:if emissions.dim() != 3:raise ValueError(f"emissions must have dimension of 3, got{emissions.dim()}")if emissions.size(2) != self.num_tags:raise ValueError(f"expected last dimission of emission is {self.num_tags},"f"got {emissions.size(2)}")if tags is not None:if emissions.shape[:2] != mask.shape:raise ValueError(f"the first two dimensions of mask and emissions must match"f"got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}")no_empty_seq = not self.batch_first and mask[0].all()no_empty_seq_bf = self.batch_first and mask[:, 0].all()if not no_empty_seq and not no_empty_seq_bf:raise ValueError('mask of the first timestep must all be on.')
4.forward调用的函数:_computer_score
该函数用来计算最佳路径的分数。
def _computer_score(self, emissions: torch.Tensor,tags: torch.LongTensor,mask: torch.ByteTensor) -> torch.Tensor:# batch secondassert emissions.dim() == 3 and tags.dim() == 2assert emissions.shape[:2] == tags.shapeassert emissions.size(2) == self.num_tagsassert mask.shape == tags.shape# 每个mask,开头一定是1,否则相当于句子就没了。assert mask[0].all()seq_length, batch_size = tags.shapemask = mask.float()# start,转移到其他所有标签的分数,不包含end# 根据实际的tag的开头的词,得到从start到每句话开头的类型的分数。# 这里是start到第一个词的转移分数,shape: (batch_size,)score = self.start_transitions[tags[0]]# 接下来是预测的每句话的开头应当是什么tag,如果有3个tag,那么每个词都会有对应的三个分数,分别对应每一个tag# 但是我们实际的tag是在tags[0]里面的,而预测的三个值,分数不一定是多少# 比如实际的第一个词tag是B,预测的BIO的三个分数分别为:(0.1,0.5,04)# 那么我们把0.1这个分数加上。这个就是发射分数,也就是预测的分数。score += emissions[0, torch.arange(batch_size), tags[0]]# 至此,我们完成了从start转移到第一个词的转移分数+发射分数# 接下来是每个词到下一个词的转移分数+发射分数,全加到一块for i in range(1, seq_length):# 转移分数score += self.transitions[tags[i-1], tags[i]] * mask[i]# 发射分数score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]# 取到最后一个词的tag# 使用的mask是形如:[1,1,1,1,0,0,0],后面的0是padding的,因此没字了# 因此通过下面的方式,取到1的和,减去1,就是最后一个词的索引了。seq_end = mask.long().sum(dim=0) - 1last_tag = tags[seq_end, torch.arange(batch_size)]# 最后一个词转移到end的分数score += self.end_transitions[last_tag]return score
5.forward调用的函数:_compute_normalizer
这里计算所有路径的分数之和。并取一个logsumexp
def _compute_normalizer(self, emissions: torch.Tensor,mask: torch.ByteTensor) -> torch.Tensor:# emissions: (seq_length, batch_size, num_tags)# mask: (seq_length, batch_size)assert emissions.dim() == 3 and mask.dim() == 2assert emissions.shape[:2] == mask.shapeassert emissions.size(2) == self.num_tagsassert mask[0].all()seq_length = emissions.size(0)# emissions[0],因为第一个维度是句子长度,因此emissions[0]就是每一个句子的开头的词,对应的发射分数# 并且每一个分数是有num_tags这么多。因此emissions[0]就是对所有开头的词,对每一个标签预测的分数。# 再加上start标志到每一个标签的分数,就是一个整体的开头分数之和。score = self.start_transitions + emissions[0]# 接下来把所有的转移分数,发射分数全部加起来。for i in range(1, seq_length):# 原来是(batch_size, num_tags), 现在是(batch_size, num_tags, 1)broadcast_score = score.unsqueeze(dim=2)# 对于第i个词,原来是(batch_size, num_tags), 现在是(batch_size, 1, num_tags)broadcast_emission = emissions[i].unsqueeze(1)# 先把开头的分数和转移矩阵加起来,便得到了开头的每一个tag,转移到其他每一个tag的概率# 再把发射矩阵加上,便得到了该单词的总分数。其中会自动使用broad cast机制next_score = broadcast_score + self.transitions + broadcast_emission# 对总分数,在第二个维度求一个对数域的分数。第二个维度,也就是转移矩阵的行# 我们求的是所有路径的总分数,要对这个分数求和。# 假如对第二个词来说,可能由第一个词的num_tags那么多的可能性过来,那么就把所有的可能性加起来# 这样得到的就是对于第二个词来说的总分数。因此,把转移矩阵的行,也就是前一个词可能的tag,全部加起来即可# 也就是在第二个维度上求和。这样就得到了总分数,我们对这个总分数进行对数域计算即可(取e,求和,取对数)。next_score = torch.logsumexp(next_score, dim=1)# 通过mask,如果对应的单词位置有值,也就是我们需要这个分数,那么就使用next_score# 如果对应的位置没值,那么这个分数不需要加上,就取原来的scorescore = torch.where(mask[i].unsqueeze(1), next_score, score)# 最后把单词转移到end的分数加上score += self.end_transitions# 返回值取对数域的值,把所有的词的分数再求和一遍return torch.logsumexp(score, dim=1)
6.forward调用的函数:_viterbi_decode
维特比解码,得到最佳路径。
def _viterbi_decode(self, emissions: torch.FloatTensor,mask: torch.ByteTensor) -> List[List[int]]:assert emissions.dim() == 3 and mask.dim() == 2assert emissions.shape[:2] == mask.shapeassert emissions.size(2) == self.num_tagsassert mask[0].all()seq_length, batch_size = mask.shapescore = self.start_transitions + emissions[0]history = []for i in range(1, seq_length):broadcast_score = score.unsqueeze(2)broadcast_emission = emissions[i].unsqueeze(1)next_score = broadcast_score + self.transitions + broadcast_emission# 在第一个维度上面求最大,消掉第一个维度,那么剩下的就是"到下一个类型概率最大的那个类型"# 这个max返回值有2个,一个是求完最大值后的结果,形状是(B, tag_num),一个是每个最大值所在的索引# 两个返回结果形状一致# 选最好的转移分数next_score, indices = next_score.max(dim=1)score = torch.where(mask[i].unsqueeze(1), next_score, score)# 上一个词转移到这个词时,分数最高的那些值的索引history.append(indices)score += self.end_transitionsseq_ends = mask.long().sum(dim=0) - 1best_tags_list = []for idx in range(batch_size):# 取到分数最高的标签,就是最后一个词的标签的索引# 选最好的发射分数_, best_last_tag = score[idx].max(dim=0)best_tags = [best_last_tag.item()]# seq_ends存了每个句子序列的最后一个词的索引。for hist in reversed(history[:seq_ends[idx]]):best_last_tag = hist[idx][best_tags[-1]]best_tags.append(best_last_tag.item())best_tags.reverse()best_tags_list.append(best_tags)return best_tags_list
7.外部调用的函数:decode
该函数调用了上面的维特比解码,外部可通过model.decode调用,返回最佳路径。
def decode(self, emissions: torch.Tensor,mask: Optional[torch.ByteTensor]=None) -> List[List[int]]:self._validate(emissions=emissions, mask=mask)if mask is None:mask = emissions.new_ones(emissions.shape[:2],dtype=torch.uint8)if self.batch_first:emissions = emissions.transpose(0, 1)mask = mask.transpose(0, 1)return self._viterbi_decode(emissions, mask)
KBQA-Bert学习记录-CRF模型相关推荐
- 数学建模学习记录——数学规划模型
数学建模学习记录--数学规划模型 一.线性规划问题 MatLab中线性规划的标准型 MatLab中求解线性规划的命令 二.整数线性规划问题 三.非线性规划问题 MatLab中非线性规划的标准型 Mat ...
- seq2seq模型_Pytorch学习记录-Seq2Seq模型对比
Pytorch学习记录-torchtext和Pytorch的实例4 0. PyTorch Seq2Seq项目介绍 在完成基本的torchtext之后,找到了这个教程,<基于Pytorch和tor ...
- DAB-Deformable-DETR代码学习记录之模型构建
DAB-DETR的作者在Deformable-DETR基础上,将DAB-DETR的思想融入到了Deformable-DETR中,取得了不错的成绩.今天博主通过源码来学习下DAB-Deformable- ...
- 【美赛学习记录】模型
美赛学习记录-2022年2月7日 代码! 线性回归 数据拟合 插值 最优化求极值 插值 ARIMA 复杂网络实验 模型验证 K-Fold Cross-validation k折交叉验证 [基础模型] ...
- DAB-Deformable-DETR源码学习记录之模型构建(二)
书接上回,上篇博客中我们学习到了Encoder模块,接下来我们来学习Decoder模块其代码是如何实现的. 其实Deformable-DETR最大的创新在于其提出了可变形注意力模型以及多尺度融合模块: ...
- 学习记录——Pytorch模型移植Android小例子
提示:注意文章时效性,2022.04.02. 目录 前言 零.使用的环境 一.模型准备 1.导出模型 2.错误记录 2.1要载入完整模型(网络结构+权重参数) 2.2导出的模型文件格式 二.Andro ...
- 采用 ALSTM 模型的温度和降雨关联预测研究论文学习记录
为了准确和及时预测局部区域的降雨及温度,提出了一种基于 Attention 和 LSTM 组合模型( ALSTM) 的关联多值预测算法.该算法利用天气时间序列中 的前期数据,对下一小时的降雨量和温度进 ...
- [Django]模型学习记录篇--基础
模型学习记录篇,仅仅自己学习时做的记录!!! 实现模型变更的三个步骤: 修改你的模型(在models.py文件中). 运行python manage.py makemigrations ,为这些修改创 ...
- 西瓜书学习记录-模型评估与选择(第二章)
西瓜书学习记录-模型评估与选择 第二章啦 整个过程可以描述为在训练集上去训练,在验证集上去调参,调完参之后再到训练集上去训练,直到结果满意,最后到测试集上去测试. 例子(反例): 上图选择蓝色的线,坏 ...
- Unity学习记录——模型与动画
Unity学习记录--模型与动画 前言 本文是中山大学软件工程学院2020级3d游戏编程与设计的作业7 编程题:智能巡逻兵 1.学习参考 除去老师在课堂上讲的内容,本次作业代码与操作主要参考了 ...
最新文章
- 找java培训机构如何挑选
- 【渝粤题库】广东开放大学 文化活动策划与组织 形成性考核
- java 2d划线 刷子_月光软件站 - 编程文档 - Java - Java图形设计中,利用Bresenham算法实现直线线型,线宽的控制(NO 2D GRAPHICS)...
- MySQL学习笔记17:别名
- 使用TensorFlow.js在浏览器中进行深度学习入门
- java http data chunk_HTTP协议之Chunked解析
- 即将上线的Kafka 集群(用CM部署的)无法使用“--bootstrap-server”进行消费,怎么破?...
- 支付1000元咨询费,如何让PB编写的程序不能被反编译?
- java proj4j 兰勃特投影设置地球半径 (+R )无效问题
- APUE 第四章总结
- 8 随机积分与随机微分方程
- imagine php,使用imagine/imagine实现制作一个图片
- 玩客云刷Armbian详细教程
- 【随笔杂记】电脑断电自启+远程控制自启
- Sass 3 的环境搭建及开发
- echarts-特殊需求
- Licode—基于webrtc的SFU/MCU实现
- matplotlib 进阶之Artist tutorial(如何操作Atrist和定制)
- 什么是LIDAR(激光雷达),如何标注激光点云数据?
- 今年十月最新语言排行榜
热门文章
- ppt讲解中的过渡_ppt过渡页的设计技巧
- 小说关于计算机名称,小说取名和人名取名太纠结了,感觉橙瓜码字的自动取名还不错...
- python内嵌浏览器_内嵌web浏览器
- python pycharm anaconda需要都下载吗_Anaconda下载与安装、PyCharm下载与安装
- 三菱plc pwm指令_三菱PLC必会编程指令汇总,收藏这些就够了!
- IDEA 当前项目jdk版本查看
- 基于单片机USB接口的温度控制器
- 线性系统和非线性系统
- Android四大组件的作用
- 微信小程序demo汇总