Bert-BiLSTM-CRF pytorch 代码解析-3:def _viterbi_decode
本文用于理解 GitHub 上的代码:Bert-BiLSTM-CRF-pytorch。
Github 相关链接: link.
这部分用于解码(预测)阶段:给定每个位置的发射分数和 CRF 转移矩阵,用 Viterbi 算法求出每个句子得分最高的标签序列。下面先给出原函数,再用一个小例子逐步打印中间结果。
def _viterbi_decode(self, feats, mask=None):"""Args:feats: size=(batch_size, seq_len, self.target_size+2)mask: size=(batch_size, seq_len)Returns:decode_idx: (batch_size, seq_len), viterbi decode结果path_score: size=(batch_size, 1), 每个句子的得分"""batch_size = feats.size(0)seq_len = feats.size(1)tag_size = feats.size(-1)length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()mask = mask.transpose(1, 0).contiguous()ins_num = seq_len * batch_sizefeats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)scores = scores.view(seq_len, batch_size, tag_size, tag_size)seq_iter = enumerate(scores)# record the position of the best scoreback_points = list()partition_history = list()mask = (1 - mask.long()).byte()try:_, inivalues = seq_iter.__next__()except:_, inivalues = seq_iter.next()partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)partition_history.append(partition)for idx, cur_values in seq_iter:cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)partition, cur_bp = torch.max(cur_values, 1)partition_history.append(partition.unsqueeze(-1))cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)back_points.append(cur_bp)partition_history = torch.cat(partition_history).view(seq_len, batch_size, -1).transpose(1, 0).contiguous()last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1)last_values = last_partition.expand(batch_size, tag_size, tag_size) + \self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)_, last_bp = torch.max(last_values, 1)pad_zero = Variable(torch.zeros(batch_size, tag_size)).long()if self.use_cuda:pad_zero = 
pad_zero.cuda()back_points.append(pad_zero)back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)pointer = last_bp[:, self.END_TAG_IDX]insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)back_points = back_points.transpose(1, 0).contiguous()back_points.scatter_(1, last_position, insert_last)back_points = back_points.transpose(1, 0).contiguous()decode_idx = Variable(torch.LongTensor(seq_len, batch_size))if self.use_cuda:decode_idx = decode_idx.cuda()decode_idx[-1] = pointer.datafor idx in range(len(back_points)-2, -1, -1):pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))decode_idx[idx] = pointer.view(-1).datapath_score = Nonedecode_idx = decode_idx.transpose(1, 0)return path_score, decode_idx
import torch

# Positions of the start/end markers inside the tag dimension.
START_TAG_IDX, END_TAG_IDX = -2, -1

# Toy emission scores: 2 sentences, 3 steps, 4 tags.
feats = torch.FloatTensor([
    [[0.1938, -0.0033, -0.0786, 0.1115],
     [-0.0450, -0.1575, 0.0550, -0.1546],
     [-0.0271, -0.0669, -0.0533, -0.1674]],
    [[-0.0269, -0.1714, -0.0775, -0.0791],
     [-0.0745, -0.2008, -0.1868, 0.2168],
     [0.0703, 0.0196, 0.0457, 0.0400]],
])
# The first sentence has 2 real tokens, the second has 3.
mask = torch.FloatTensor([[1, 1, 0], [1, 1, 1]])
# tags=torch.FloatTensor([[2, 0, 1],
#                         [0, 1, 3]])

# transitions.shape: torch.Size([4, 4]); the 4 tags include the start and end markers.
transitions = torch.Tensor([
    [7, 3, 0, 2],
    [2, 1, -2, 5],
    [1, 3, -50, 30],
    [-1000, -1000, -1000, -1000],
])

batch_size = feats.size(0)
seq_len = feats.size(1)
tag_size = feats.size(-1)

# Number of real tokens per sentence, shape (batch_size, 1).
length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
# Time-major mask, shape (seq_len, batch_size).
mask = mask.transpose(1, 0).contiguous()
ins_num = seq_len * batch_size
# Repeat each emission row tag_size times so it can be added to the transition matrix.
feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)

# scores[t, b, i, j] = emission of tag j at step t + transitions[i, j]
scores = feats + transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
scores = scores.view(seq_len, batch_size, tag_size, tag_size)

print ('length_mask:\n',length_mask)
print ('mask:\n',mask)
print ('ins_num = ',ins_num)
# print ('\nfeats',feats.shape,'\n', feats)
print ('\nscores',scores.shape,'\n',scores)
'''
length_mask:tensor([[2],[3]])
mask:tensor([[1., 1.],[1., 1.],[0., 1.]])
ins_num = 6scores torch.Size([3, 2, 4, 4]) tensor([[[[ 7.1938e+00, 2.9967e+00, -7.8600e-02, 2.1115e+00],[ 2.1938e+00, 9.9670e-01, -2.0786e+00, 5.1115e+00],[ 1.1938e+00, 2.9967e+00, -5.0079e+01, 3.0111e+01],[-9.9981e+02, -1.0000e+03, -1.0001e+03, -9.9989e+02]],[[ 6.9731e+00, 2.8286e+00, -7.7500e-02, 1.9209e+00],[ 1.9731e+00, 8.2860e-01, -2.0775e+00, 4.9209e+00],[ 9.7310e-01, 2.8286e+00, -5.0077e+01, 2.9921e+01],[-1.0000e+03, -1.0002e+03, -1.0001e+03, -1.0001e+03]]],[[[ 6.9550e+00, 2.8425e+00, 5.5000e-02, 1.8454e+00],[ 1.9550e+00, 8.4250e-01, -1.9450e+00, 4.8454e+00],[ 9.5500e-01, 2.8425e+00, -4.9945e+01, 2.9845e+01],[-1.0000e+03, -1.0002e+03, -9.9995e+02, -1.0002e+03]],[[ 6.9255e+00, 2.7992e+00, -1.8680e-01, 2.2168e+00],[ 1.9255e+00, 7.9920e-01, -2.1868e+00, 5.2168e+00],[ 9.2550e-01, 2.7992e+00, -5.0187e+01, 3.0217e+01],[-1.0001e+03, -1.0002e+03, -1.0002e+03, -9.9978e+02]]],[[[ 6.9729e+00, 2.9331e+00, -5.3300e-02, 1.8326e+00],[ 1.9729e+00, 9.3310e-01, -2.0533e+00, 4.8326e+00],[ 9.7290e-01, 2.9331e+00, -5.0053e+01, 2.9833e+01],[-1.0000e+03, -1.0001e+03, -1.0001e+03, -1.0002e+03]],[[ 7.0703e+00, 3.0196e+00, 4.5700e-02, 2.0400e+00],[ 2.0703e+00, 1.0196e+00, -1.9543e+00, 5.0400e+00],[ 1.0703e+00, 3.0196e+00, -4.9954e+01, 3.0040e+01],[-9.9993e+02, -9.9998e+02, -9.9995e+02, -9.9996e+02]]]])
'''
seq_iter = enumerate(scores)
# record the position of the best score
back_points = list()
partition_history = list()
# Invert the mask so that True marks padding; a boolean mask is what
# masked_fill_ expects.
mask = (1 - mask.long()).bool()

# 1. partition: score of moving from START into each tag at step 0
#    (emission scores + the START row of the transition matrix).
# Sample 1: [ 0.1938, -0.0033, -0.0786, 0.1115] + [ 1, 3, -50, 30]
_, inivalues = next(seq_iter)
partition = inivalues[:, START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
partition_history.append(partition)

print ('partition:\n',partition)
'''
partition:tensor([[[ 1.1938],[ 2.9967],[-50.0786],[ 30.1115]],[[ 0.9731],[ 2.8286],[-50.0775],[ 29.9209]]])
'''
for idx, cur_values in seq_iter:print ('\n\nidx=',idx, '\n1. cur_values:第',idx,'个step 中 LSTM输出分数 + 转移分数\n',cur_values)par_ = partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)cur_values = cur_values + par_print ('\n2. 上一个step的分数(partition):\n',par_)print ('\n3. added cur_values: LSTM输出分数 + 转移分数 + 上一个step的分数\n',cur_values)partition, cur_bp = torch.max(cur_values, 1)print ('\n4. 取出到当前步累积分数的最大值 partition 及对应的标签预测')print (' partition\n',partition)print (' cur_bp:\n', cur_bp)partition_history.append(partition.unsqueeze(-1))# print (' partition_history:\n',partition_history)cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)back_points.append(cur_bp)print ('cur_bp.masked_fill_:\n',cur_bp)print ('back_points:\n',back_points)'''idx= 1
1. cur_values:第 1 个step 中 LSTM输出分数 + 转移分数tensor([[[ 6.9550e+00, 2.8425e+00, 5.5000e-02, 1.8454e+00],[ 1.9550e+00, 8.4250e-01, -1.9450e+00, 4.8454e+00],[ 9.5500e-01, 2.8425e+00, -4.9945e+01, 2.9845e+01],[-1.0000e+03, -1.0002e+03, -9.9995e+02, -1.0002e+03]],[[ 6.9255e+00, 2.7992e+00, -1.8680e-01, 2.2168e+00],[ 1.9255e+00, 7.9920e-01, -2.1868e+00, 5.2168e+00],[ 9.2550e-01, 2.7992e+00, -5.0187e+01, 3.0217e+01],[-1.0001e+03, -1.0002e+03, -1.0002e+03, -9.9978e+02]]])2. 上一个step的分数(partition):tensor([[[ 1.1938, 1.1938, 1.1938, 1.1938],[ 2.9967, 2.9967, 2.9967, 2.9967],[-50.0786, -50.0786, -50.0786, -50.0786],[ 30.1115, 30.1115, 30.1115, 30.1115]],[[ 0.9731, 0.9731, 0.9731, 0.9731],[ 2.8286, 2.8286, 2.8286, 2.8286],[-50.0775, -50.0775, -50.0775, -50.0775],[ 29.9209, 29.9209, 29.9209, 29.9209]]])3. added cur_values: LSTM输出分数 + 转移分数 + 上一个step的分数tensor([[[ 8.1488e+00, 4.0363e+00, 1.2488e+00, 3.0392e+00],[ 4.9517e+00, 3.8392e+00, 1.0517e+00, 7.8421e+00],[-4.9124e+01, -4.7236e+01, -1.0002e+02, -2.0233e+01],[-9.6993e+02, -9.7005e+02, -9.6983e+02, -9.7004e+02]],[[ 7.8986e+00, 3.7723e+00, 7.8630e-01, 3.1899e+00],[ 4.7541e+00, 3.6278e+00, 6.4180e-01, 8.0454e+00],[-4.9152e+01, -4.7278e+01, -1.0026e+02, -1.9861e+01],[-9.7015e+02, -9.7028e+02, -9.7027e+02, -9.6986e+02]]])4. 取出到当前步累积分数的最大值 partition 及对应的标签预测partitiontensor([[8.1488, 4.0363, 1.2488, 7.8421],[7.8986, 3.7723, 0.7863, 8.0454]])cur_bp:tensor([[0, 0, 0, 1],[0, 0, 0, 1]])
cur_bp.masked_fill_:tensor([[0, 0, 0, 1],[0, 0, 0, 1]])
back_points:[tensor([[0, 0, 0, 1],[0, 0, 0, 1]])]idx= 2
1. cur_values:第 2 个step 中 LSTM输出分数 + 转移分数tensor([[[ 6.9729e+00, 2.9331e+00, -5.3300e-02, 1.8326e+00],[ 1.9729e+00, 9.3310e-01, -2.0533e+00, 4.8326e+00],[ 9.7290e-01, 2.9331e+00, -5.0053e+01, 2.9833e+01],[-1.0000e+03, -1.0001e+03, -1.0001e+03, -1.0002e+03]],[[ 7.0703e+00, 3.0196e+00, 4.5700e-02, 2.0400e+00],[ 2.0703e+00, 1.0196e+00, -1.9543e+00, 5.0400e+00],[ 1.0703e+00, 3.0196e+00, -4.9954e+01, 3.0040e+01],[-9.9993e+02, -9.9998e+02, -9.9995e+02, -9.9996e+02]]])2. 上一个step的分数(partition):tensor([[[8.1488, 8.1488, 8.1488, 8.1488],[4.0363, 4.0363, 4.0363, 4.0363],[1.2488, 1.2488, 1.2488, 1.2488],[7.8421, 7.8421, 7.8421, 7.8421]],[[7.8986, 7.8986, 7.8986, 7.8986],[3.7723, 3.7723, 3.7723, 3.7723],[0.7863, 0.7863, 0.7863, 0.7863],[8.0454, 8.0454, 8.0454, 8.0454]]])3. added cur_values: LSTM输出分数 + 转移分数 + 上一个step的分数tensor([[[ 15.1217, 11.0819, 8.0955, 9.9814],[ 6.0092, 4.9694, 1.9830, 8.8689],[ 2.2217, 4.1819, -48.8045, 31.0814],[-992.1850, -992.2248, -992.2112, -992.3253]],[[ 14.9689, 10.9182, 7.9443, 9.9386],[ 5.8426, 4.7919, 1.8180, 8.8123],[ 1.8566, 3.8059, -49.1680, 30.8263],[-991.8843, -991.9350, -991.9089, -991.9146]]])4. 取出到当前步累积分数的最大值 partition 及对应的标签预测partitiontensor([[15.1217, 11.0819, 8.0955, 31.0814],[14.9689, 10.9182, 7.9443, 30.8263]])cur_bp:tensor([[0, 0, 0, 2],[0, 0, 0, 2]])
cur_bp.masked_fill_:tensor([[0, 0, 0, 0],[0, 0, 0, 2]])
back_points:[tensor([[0, 0, 0, 1],[0, 0, 0, 1]]), tensor([[0, 0, 0, 0],[0, 0, 0, 2]])]'''
partition_history = torch.cat(partition_history).view(seq_len, batch_size, -1).transpose(1, 0).contiguous()
print ('partition_history:\n',partition_history)last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1)
print ('last_position:\n',last_position)
print ('last_partition\n',last_partition)'''
partition_history:tensor([[[ 1.1938, 2.9967, -50.0786, 30.1115],[ 8.1488, 4.0363, 1.2488, 7.8421],[ 15.1217, 11.0819, 8.0955, 31.0814]],[[ 0.9731, 2.8286, -50.0775, 29.9209],[ 7.8986, 3.7723, 0.7863, 8.0454],[ 14.9689, 10.9182, 7.9443, 30.8263]]])
last_position:tensor([[[1, 1, 1, 1]],[[2, 2, 2, 2]]])
last_partitiontensor([[[ 8.1488],[ 4.0363],[ 1.2488],[ 7.8421]],[[14.9689],[10.9182],[ 7.9443],[30.8263]]])
'''
# Add the transition scores once more so the column for END can be read off.
last_values = last_partition.expand(batch_size, tag_size, tag_size) + \
    transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
_, last_bp = torch.max(last_values, 1)
# Extra all-zero row so that back_points ends up with seq_len rows.
pad_zero = torch.zeros(batch_size, tag_size, dtype=torch.long)
print ('last_values:\n',last_values)
print('last_bp\n',last_bp)

back_points.append(pad_zero)
back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
print ('back_points:',back_points)
'''
last_values:tensor([[[ 15.1488, 11.1488, 8.1488, 10.1488],[ 6.0363, 5.0363, 2.0363, 9.0363],[ 2.2488, 4.2488, -48.7512, 31.2488],[-992.1579, -992.1579, -992.1579, -992.1579]],[[ 21.9689, 17.9689, 14.9689, 16.9689],[ 12.9182, 11.9182, 8.9182, 15.9182],[ 8.9443, 10.9443, -42.0557, 37.9443],[-969.1737, -969.1737, -969.1737, -969.1737]]])
last_bptensor([[0, 0, 0, 2],[0, 0, 0, 2]])
back_points: tensor([[[0, 0, 0, 1],[0, 0, 0, 1]],[[0, 0, 0, 0],[0, 0, 0, 2]],[[0, 0, 0, 0],[0, 0, 0, 0]]])
'''
# Best final tag per sentence: the previous tag that maximises the move into END.
pointer = last_bp[:, END_TAG_IDX]
print('pointer', pointer)

# Repeat that tag across the tag axis so it can overwrite a whole row.
insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
print('insert_last', insert_last)

# Switch to batch-major layout so dim 1 is the time axis for scatter_.
back_points = back_points.transpose(0, 1).contiguous()
print('1. back_points', back_points)

# Overwrite the row at each sentence's last real position with its best final tag.
back_points.scatter_(1, last_position, insert_last)
print('2. back_points', back_points)

# Back to time-major layout for the backtracking loop.
back_points = back_points.transpose(0, 1).contiguous()
print('3. back_points', back_points)
'''
pointer tensor([2, 2])
insert_last tensor([[[2, 2, 2, 2]],[[2, 2, 2, 2]]])
1. back_points tensor([[[0, 0, 0, 1],[0, 0, 0, 0],[0, 0, 0, 0]],[[0, 0, 0, 1],[0, 0, 0, 2],[0, 0, 0, 0]]])
2. back_points tensor([[[0, 0, 0, 1],[2, 2, 2, 2],[0, 0, 0, 0]],[[0, 0, 0, 1],[0, 0, 0, 2],[2, 2, 2, 2]]])
3. back_points tensor([[[0, 0, 0, 1],[0, 0, 0, 1]],[[2, 2, 2, 2],[0, 0, 0, 2]],[[0, 0, 0, 0],[2, 2, 2, 2]]])
'''
# Zero-initialised so the first print below does not show leftover memory.
decode_idx = torch.zeros(seq_len, batch_size, dtype=torch.long)
decode_idx[-1] = pointer.data
print ('decode_idx',decode_idx)

# Walk the back pointers from the second-to-last step down to step 0.
for idx in range(len(back_points)-2, -1, -1):
    print ('\nidx:',idx)
    pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
    decode_idx[idx] = pointer.view(-1).data
    print ('pointer:',pointer)
    print ('decode_idx:',decode_idx)

path_score = None
# Return layout is (batch_size, seq_len).
decode_idx = decode_idx.transpose(1, 0)
print ('\nreturn decode_idx',decode_idx)
'''
decode_idx tensor([[ 0, 0],[973144064, 0],[ 2, 2]])idx: 1
pointer: tensor([[2],[0]])
decode_idx: tensor([[0, 0],[2, 0],[2, 2]])idx: 0
pointer: tensor([[0],[0]])
decode_idx: tensor([[0, 0],[2, 0],[2, 2]])return decode_idx tensor([[0, 2, 2],[0, 0, 2]])
'''
Bert-BiLSTM-CRF pytorch 代码解析-3:def _viterbi_decode相关推荐
- Bert+BiLSTM+CRF实体抽取
文章目录 一.环境 二.预训练词向量 三.模型 1.BiLSTM - 不使用预训练字向量 - 使用预训练字向量 2.CRF 3.BiLSTM + CRF - 不使用预训练词向量 - 使用预训练词向量 ...
- 基于BERT+BiLSTM+CRF的中文景点命名实体识别
赵平, 孙连英, 万莹, 葛娜. 基于BERT+BiLSTM+CRF的中文景点命名实体识别. 计算机系统应用, 2020, 29(6): 169-174.http://www.c-s-a.org.cn ...
- bert+crf可以做NER,那么为什么还有bert+bi-lstm+crf ?
我在自己人工标注的一份特定领域的数据集上跑过,加上bert确实会比只用固定的词向量要好一些,即使只用BERT加一个softmax层都比不用bert的bilstm+crf强.而bert+bilstm+c ...
- pytorch代码解析:loss = y_hat - y.view(y_hat.size())
pytorch代码解析:pytorch中loss = y_hat - y.view(y_hat.size()) import torchy_hat = torch.tensor([[-0.0044], ...
- Bert模型介绍及代码解析(pytorch)
Bert(预训练模型) 动机 基于微调的NLP模型 预训练的模型抽取了足够多的信息 新的任务只需要增加一个简单的输出层 注:bert相当于只有编码器的transformer 基于transformer ...
- mapbox 修改初始位置_一行代码教你如何随心所欲初始化Bert参数(附Pytorch代码详细解读)...
微信公众号:NLP从入门到放弃 微信文章在这里(排版更漂亮,但是内置链接不太行,看大家喜欢哪个点哪个看吧): 一行代码带你随心所欲重新初始化bert的参数(附Pytorch代码详细解读)mp.wei ...
- 自然语言处理(三):传统RNN(NvsN,Nvs1,1vsN,NvsM)pytorch代码解析
文章目录 1.预备知识:深度神经网络(DNN) 2.RNN出现的意义与基本结构 3.根据输入和输出数量的网络结构分类 3.1 N vs N(输入和输出序列等长) 3.2 N vs 1(多输入单输出) ...
- ResNet论文笔记及Pytorch代码解析
注:个人学习记录 感谢B站up主"同济子豪兄"的精彩讲解,参考视频的记录 [精读AI论文]ResNet深度残差网络_哔哩哔哩_bilibili 算法的意义(大概介绍) CV史上的技 ...
- 基于BERT+BiLSTM+CRF模型与新预处理方法的古籍自动标点
摘要 古文相较于现代文不仅在用词.语法等方面存在巨大差异,还缺少标点,使人难以理解语义.采用人工方式对古文进行标点既需要有较高的文学水平,还需要对历史文化有一定了解.为提高古文自动标点的准确率,将深层 ...
最新文章
- 树上问题 ---- Codeforces Round #722 (Div. 1) C. Trees of Tranquillity [dfs序区间的性质+最大不相交区间的性质]
- qps是什么意思_面试官:说说你之前负责的系统,QPS 能达到多少?
- 如何为计算机视觉任务选择正确的标注类型
- RESTful 架构详解
- 【OpenGL从入门到精通(二)】绘制一个点
- idea远程调试修改代码_IDEA远程调试(Remote Debug)Java代码指南
- Linux C高级编程——网络编程之UDP(4)
- 外设驱动库开发笔记12:TSEV01CL55红外温度传感器驱动
- [深度学习]为什么梯度反方向是函数值下降最快的方向?
- xyntservice
- 远程控制软件老是断线怎么解决?
- HTML 计算机代码
- Vue实例对象中的属性与方法---kalrry
- PTA 7-1 输入名字,输出问候语
- 简单的创建一个小型服务器
- 重磅:阿里开启大规模校招,传已启动保密项目
- 高德地图api比例尺
- iherb中文海淘攻略-- IHERB目前的优惠
- 6n137光耦怎么测好坏_817A光耦怎么测好坏,光耦合器
- 《AR与VR开发实战》——第1章AR技术简介
热门文章
- 游戏AI—行为树研究及实现(转自月夜魔术师 https://segmentfault.com/a/1190000012397660)
- It‘s not a Bug, it‘s a Feature
- Java跳出双层for循环
- vue中实现锚点定位
- 苹果手机投影_苹果12pro有没有指纹解锁 带屏下指纹解锁具体说明
- HTML5画布椭圆形教程
- python写生日祝福语_脱单狗福利,100行Python代码,每天不同时间段定时给女友发消息...
- 电商系统设计之运费模板(上)
- Nginx (engine x) 介绍
- 根据元素ID遍历树形结构,查找到所有父元素ID。