Understanding the GitHub code: Bert-BiLSTM-CRF-pytorch
Related GitHub link: link.

This part handles the decoding stage.
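For reference (my summary, not stated in the original repo), the function below implements the standard first-order Viterbi recursion over tag scores. Writing $E_t(j)$ for the (Bi)LSTM emission score of tag $j$ at step $t$ and $T_{i,j}$ for the transition score from tag $i$ to tag $j$:

$$\delta_1(j) = T_{\mathrm{START},j} + E_1(j), \qquad \delta_t(j) = \max_i\big[\delta_{t-1}(i) + T_{i,j}\big] + E_t(j), \qquad bp_t(j) = \arg\max_i\big[\delta_{t-1}(i) + T_{i,j}\big]$$

The best path is read off by adding the transition into END at each sentence's last real step and then following the backpointers $bp$ backwards.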

def _viterbi_decode(self, feats, mask=None):
    """
    Args:
        feats: size=(batch_size, seq_len, self.target_size+2)
        mask: size=(batch_size, seq_len)
    Returns:
        decode_idx: (batch_size, seq_len), viterbi decode result
        path_score: size=(batch_size, 1), score of each sentence
    """
    batch_size = feats.size(0)
    seq_len = feats.size(1)
    tag_size = feats.size(-1)
    length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
    mask = mask.transpose(1, 0).contiguous()
    ins_num = seq_len * batch_size
    feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
    scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
    scores = scores.view(seq_len, batch_size, tag_size, tag_size)
    seq_iter = enumerate(scores)
    # record the position of the best score
    back_points = list()
    partition_history = list()
    mask = (1 - mask.long()).byte()
    try:
        _, inivalues = seq_iter.__next__()
    except:
        _, inivalues = seq_iter.next()
    partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
    partition_history.append(partition)
    for idx, cur_values in seq_iter:
        cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        partition, cur_bp = torch.max(cur_values, 1)
        partition_history.append(partition.unsqueeze(-1))
        cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
        back_points.append(cur_bp)
    partition_history = torch.cat(partition_history).view(seq_len, batch_size, -1).transpose(1, 0).contiguous()
    last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
    last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1)
    last_values = last_partition.expand(batch_size, tag_size, tag_size) + \
        self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
    _, last_bp = torch.max(last_values, 1)
    pad_zero = Variable(torch.zeros(batch_size, tag_size)).long()
    if self.use_cuda:
        pad_zero = pad_zero.cuda()
    back_points.append(pad_zero)
    back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
    pointer = last_bp[:, self.END_TAG_IDX]
    insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
    back_points = back_points.transpose(1, 0).contiguous()
    back_points.scatter_(1, last_position, insert_last)
    back_points = back_points.transpose(1, 0).contiguous()
    decode_idx = Variable(torch.LongTensor(seq_len, batch_size))
    if self.use_cuda:
        decode_idx = decode_idx.cuda()
    decode_idx[-1] = pointer.data
    for idx in range(len(back_points)-2, -1, -1):
        pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
        decode_idx[idx] = pointer.view(-1).data
    path_score = None
    decode_idx = decode_idx.transpose(1, 0)
    return path_score, decode_idx
import torch
from torch.autograd import Variable

START_TAG_IDX, END_TAG_IDX = -2, -1
feats = torch.FloatTensor([[[ 0.1938, -0.0033, -0.0786,  0.1115],
                            [-0.0450, -0.1575,  0.0550, -0.1546],
                            [-0.0271, -0.0669, -0.0533, -0.1674]],
                           [[-0.0269, -0.1714, -0.0775, -0.0791],
                            [-0.0745, -0.2008, -0.1868,  0.2168],
                            [ 0.0703,  0.0196,  0.0457,  0.0400]]])
mask = torch.FloatTensor([[1, 1, 0],
                          [1, 1, 1]])
# tags = torch.FloatTensor([[2, 0, 1],
#                           [0, 1, 3]])
transitions = torch.Tensor([[    7,     3,    0,     2],
                            [    2,     1,   -2,     5],
                            [    1,     3,  -50,    30],
                            [-1000, -1000, -1000, -1000]])
# transitions.shape: torch.Size([4, 4]); the 4 tags include the 2 start/end tags
batch_size = feats.size(0)
seq_len = feats.size(1)
tag_size = feats.size(-1)
length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
mask = mask.transpose(1, 0).contiguous()
ins_num = seq_len * batch_size
feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
scores = feats + transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
print('length_mask:\n', length_mask)
print('mask:\n', mask)
print('ins_num = ', ins_num)
# print('\nfeats', feats.shape, '\n', feats)
print('\nscores', scores.shape, '\n', scores)
'''
length_mask:
 tensor([[2],
        [3]])
mask:
 tensor([[1., 1.],
        [1., 1.],
        [0., 1.]])
ins_num =  6

scores torch.Size([3, 2, 4, 4])
 tensor([[[[ 7.1938e+00,  2.9967e+00, -7.8600e-02,  2.1115e+00],
          [ 2.1938e+00,  9.9670e-01, -2.0786e+00,  5.1115e+00],
          [ 1.1938e+00,  2.9967e+00, -5.0079e+01,  3.0111e+01],
          [-9.9981e+02, -1.0000e+03, -1.0001e+03, -9.9989e+02]],

         [[ 6.9731e+00,  2.8286e+00, -7.7500e-02,  1.9209e+00],
          [ 1.9731e+00,  8.2860e-01, -2.0775e+00,  4.9209e+00],
          [ 9.7310e-01,  2.8286e+00, -5.0077e+01,  2.9921e+01],
          [-1.0000e+03, -1.0002e+03, -1.0001e+03, -1.0001e+03]]],

        [[[ 6.9550e+00,  2.8425e+00,  5.5000e-02,  1.8454e+00],
          [ 1.9550e+00,  8.4250e-01, -1.9450e+00,  4.8454e+00],
          [ 9.5500e-01,  2.8425e+00, -4.9945e+01,  2.9845e+01],
          [-1.0000e+03, -1.0002e+03, -9.9995e+02, -1.0002e+03]],

         [[ 6.9255e+00,  2.7992e+00, -1.8680e-01,  2.2168e+00],
          [ 1.9255e+00,  7.9920e-01, -2.1868e+00,  5.2168e+00],
          [ 9.2550e-01,  2.7992e+00, -5.0187e+01,  3.0217e+01],
          [-1.0001e+03, -1.0002e+03, -1.0002e+03, -9.9978e+02]]],

        [[[ 6.9729e+00,  2.9331e+00, -5.3300e-02,  1.8326e+00],
          [ 1.9729e+00,  9.3310e-01, -2.0533e+00,  4.8326e+00],
          [ 9.7290e-01,  2.9331e+00, -5.0053e+01,  2.9833e+01],
          [-1.0000e+03, -1.0001e+03, -1.0001e+03, -1.0002e+03]],

         [[ 7.0703e+00,  3.0196e+00,  4.5700e-02,  2.0400e+00],
          [ 2.0703e+00,  1.0196e+00, -1.9543e+00,  5.0400e+00],
          [ 1.0703e+00,  3.0196e+00, -4.9954e+01,  3.0040e+01],
          [-9.9993e+02, -9.9998e+02, -9.9995e+02, -9.9996e+02]]]])
'''
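# Sanity check (my addition, not in the original post): each slice scores[t, b, i, :] should be
# the emission feats[t, b, :] of step t plus row i of the transition matrix.
# For t=0, b=0 and previous tag i=2 (the START row in this toy setup):
# [0.1938, -0.0033, -0.0786, 0.1115] + [1, 3, -50, 30] = [1.1938, 2.9967, -50.0786, 30.1115]
emit_00 = torch.FloatTensor([0.1938, -0.0033, -0.0786, 0.1115])    # first token of sample 1
print(torch.allclose(scores[0, 0, 2], emit_00 + transitions[2]))   # expect: True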
seq_iter = enumerate(scores)
# record the position of the best score
back_points = list()
partition_history = list()
mask = (1 - mask.long()).byte()
# 1. partition: score of going from START to each tag
#    (feats + the START row of the transition matrix)
# sample 1: [ 0.1938, -0.0033, -0.0786,  0.1115] + [    1,     3,  -50,    30]
try:
    _, inivalues = seq_iter.__next__()
except:
    _, inivalues = seq_iter.next()
partition = inivalues[:, START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
partition_history.append(partition)
print('partition:\n', partition)
'''
partition:
 tensor([[[  1.1938],
         [  2.9967],
         [-50.0786],
         [ 30.1115]],

        [[  0.9731],
         [  2.8286],
         [-50.0775],
         [ 29.9209]]])
'''
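# Quick check (my addition): the initial partition is exactly the START row of the step-0
# scores, i.e. the emission of the first token plus the transition from START to each tag.
print(torch.allclose(partition.squeeze(-1), scores[0, :, START_TAG_IDX, :]))   # expect: True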
for idx, cur_values in seq_iter:
    print('\n\nidx=', idx, '\n1. cur_values: LSTM emission scores + transition scores at step', idx, '\n', cur_values)
    par_ = partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
    cur_values = cur_values + par_
    print('\n2. scores from the previous step (partition):\n', par_)
    print('\n3. added cur_values: LSTM emission + transition + previous-step scores\n', cur_values)
    partition, cur_bp = torch.max(cur_values, 1)
    print('\n4. max accumulated score up to this step (partition) and the corresponding predicted tags (backpointers)')
    print('  partition\n', partition)
    print('  cur_bp:\n', cur_bp)
    partition_history.append(partition.unsqueeze(-1))
    # print('  partition_history:\n', partition_history)
    cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
    back_points.append(cur_bp)
    print('cur_bp.masked_fill_:\n', cur_bp)
    print('back_points:\n', back_points)
'''
idx= 1
1. cur_values: LSTM emission scores + transition scores at step 1
 tensor([[[ 6.9550e+00,  2.8425e+00,  5.5000e-02,  1.8454e+00],
         [ 1.9550e+00,  8.4250e-01, -1.9450e+00,  4.8454e+00],
         [ 9.5500e-01,  2.8425e+00, -4.9945e+01,  2.9845e+01],
         [-1.0000e+03, -1.0002e+03, -9.9995e+02, -1.0002e+03]],

        [[ 6.9255e+00,  2.7992e+00, -1.8680e-01,  2.2168e+00],
         [ 1.9255e+00,  7.9920e-01, -2.1868e+00,  5.2168e+00],
         [ 9.2550e-01,  2.7992e+00, -5.0187e+01,  3.0217e+01],
         [-1.0001e+03, -1.0002e+03, -1.0002e+03, -9.9978e+02]]])

2. scores from the previous step (partition):
 tensor([[[  1.1938,   1.1938,   1.1938,   1.1938],
         [  2.9967,   2.9967,   2.9967,   2.9967],
         [-50.0786, -50.0786, -50.0786, -50.0786],
         [ 30.1115,  30.1115,  30.1115,  30.1115]],

        [[  0.9731,   0.9731,   0.9731,   0.9731],
         [  2.8286,   2.8286,   2.8286,   2.8286],
         [-50.0775, -50.0775, -50.0775, -50.0775],
         [ 29.9209,  29.9209,  29.9209,  29.9209]]])

3. added cur_values: LSTM emission + transition + previous-step scores
 tensor([[[ 8.1488e+00,  4.0363e+00,  1.2488e+00,  3.0392e+00],
         [ 4.9517e+00,  3.8392e+00,  1.0517e+00,  7.8421e+00],
         [-4.9124e+01, -4.7236e+01, -1.0002e+02, -2.0233e+01],
         [-9.6993e+02, -9.7005e+02, -9.6983e+02, -9.7004e+02]],

        [[ 7.8986e+00,  3.7723e+00,  7.8630e-01,  3.1899e+00],
         [ 4.7541e+00,  3.6278e+00,  6.4180e-01,  8.0454e+00],
         [-4.9152e+01, -4.7278e+01, -1.0026e+02, -1.9861e+01],
         [-9.7015e+02, -9.7028e+02, -9.7027e+02, -9.6986e+02]]])

4. max accumulated score up to this step (partition) and the corresponding predicted tags (backpointers)
  partition
 tensor([[8.1488, 4.0363, 1.2488, 7.8421],
        [7.8986, 3.7723, 0.7863, 8.0454]])
  cur_bp:
 tensor([[0, 0, 0, 1],
        [0, 0, 0, 1]])
cur_bp.masked_fill_:
 tensor([[0, 0, 0, 1],
        [0, 0, 0, 1]])
back_points:
 [tensor([[0, 0, 0, 1],
        [0, 0, 0, 1]])]


idx= 2
1. cur_values: LSTM emission scores + transition scores at step 2
 tensor([[[ 6.9729e+00,  2.9331e+00, -5.3300e-02,  1.8326e+00],
         [ 1.9729e+00,  9.3310e-01, -2.0533e+00,  4.8326e+00],
         [ 9.7290e-01,  2.9331e+00, -5.0053e+01,  2.9833e+01],
         [-1.0000e+03, -1.0001e+03, -1.0001e+03, -1.0002e+03]],

        [[ 7.0703e+00,  3.0196e+00,  4.5700e-02,  2.0400e+00],
         [ 2.0703e+00,  1.0196e+00, -1.9543e+00,  5.0400e+00],
         [ 1.0703e+00,  3.0196e+00, -4.9954e+01,  3.0040e+01],
         [-9.9993e+02, -9.9998e+02, -9.9995e+02, -9.9996e+02]]])

2. scores from the previous step (partition):
 tensor([[[8.1488, 8.1488, 8.1488, 8.1488],
         [4.0363, 4.0363, 4.0363, 4.0363],
         [1.2488, 1.2488, 1.2488, 1.2488],
         [7.8421, 7.8421, 7.8421, 7.8421]],

        [[7.8986, 7.8986, 7.8986, 7.8986],
         [3.7723, 3.7723, 3.7723, 3.7723],
         [0.7863, 0.7863, 0.7863, 0.7863],
         [8.0454, 8.0454, 8.0454, 8.0454]]])

3. added cur_values: LSTM emission + transition + previous-step scores
 tensor([[[  15.1217,   11.0819,    8.0955,    9.9814],
         [   6.0092,    4.9694,    1.9830,    8.8689],
         [   2.2217,    4.1819,  -48.8045,   31.0814],
         [-992.1850, -992.2248, -992.2112, -992.3253]],

        [[  14.9689,   10.9182,    7.9443,    9.9386],
         [   5.8426,    4.7919,    1.8180,    8.8123],
         [   1.8566,    3.8059,  -49.1680,   30.8263],
         [-991.8843, -991.9350, -991.9089, -991.9146]]])

4. max accumulated score up to this step (partition) and the corresponding predicted tags (backpointers)
  partition
 tensor([[15.1217, 11.0819,  8.0955, 31.0814],
        [14.9689, 10.9182,  7.9443, 30.8263]])
  cur_bp:
 tensor([[0, 0, 0, 2],
        [0, 0, 0, 2]])
cur_bp.masked_fill_:
 tensor([[0, 0, 0, 0],
        [0, 0, 0, 2]])
back_points:
 [tensor([[0, 0, 0, 1],
        [0, 0, 0, 1]]), tensor([[0, 0, 0, 0],
        [0, 0, 0, 2]])]
'''
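# Note on the mask (my addition): `mask` was inverted above, so mask[idx] is 1 exactly at padded
# positions. At idx=2 only sample 1 (length 2) is padding, so its backpointers are zeroed while
# sample 2 keeps [0, 0, 0, 2]:
print(mask[2])   # expect: tensor([1, 0], dtype=torch.uint8)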
partition_history = torch.cat(partition_history).view(seq_len, batch_size, -1).transpose(1, 0).contiguous()
print('partition_history:\n', partition_history)
last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1)
print('last_position:\n', last_position)
print('last_partition\n', last_partition)
'''
partition_history:
 tensor([[[  1.1938,   2.9967, -50.0786,  30.1115],
         [  8.1488,   4.0363,   1.2488,   7.8421],
         [ 15.1217,  11.0819,   8.0955,  31.0814]],

        [[  0.9731,   2.8286, -50.0775,  29.9209],
         [  7.8986,   3.7723,   0.7863,   8.0454],
         [ 14.9689,  10.9182,   7.9443,  30.8263]]])
last_position:
 tensor([[[1, 1, 1, 1]],

        [[2, 2, 2, 2]]])
last_partition
 tensor([[[ 8.1488],
         [ 4.0363],
         [ 1.2488],
         [ 7.8421]],

        [[14.9689],
         [10.9182],
         [ 7.9443],
         [30.8263]]])
'''
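# Equivalent, more explicit view of the gather above (my addition): sample 1 ends at step 1
# (length 2) and sample 2 ends at step 2 (length 3), so last_partition is just each sample's
# partition row at its own final real step.
print(torch.allclose(last_partition[0].squeeze(-1), partition_history[0, 1]))   # expect: True
print(torch.allclose(last_partition[1].squeeze(-1), partition_history[1, 2]))   # expect: True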
last_values = last_partition.expand(batch_size, tag_size, tag_size) + \
    transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
_, last_bp = torch.max(last_values, 1)
pad_zero = Variable(torch.zeros(batch_size, tag_size)).long()
print('last_values:\n', last_values)
print('last_bp\n', last_bp)
back_points.append(pad_zero)
back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
print('back_points:', back_points)
'''
last_values:
 tensor([[[  15.1488,   11.1488,    8.1488,   10.1488],
         [   6.0363,    5.0363,    2.0363,    9.0363],
         [   2.2488,    4.2488,  -48.7512,   31.2488],
         [-992.1579, -992.1579, -992.1579, -992.1579]],

        [[  21.9689,   17.9689,   14.9689,   16.9689],
         [  12.9182,   11.9182,    8.9182,   15.9182],
         [   8.9443,   10.9443,  -42.0557,   37.9443],
         [-969.1737, -969.1737, -969.1737, -969.1737]]])
last_bp
 tensor([[0, 0, 0, 2],
        [0, 0, 0, 2]])
back_points: tensor([[[0, 0, 0, 1],
         [0, 0, 0, 1]],

        [[0, 0, 0, 0],
         [0, 0, 0, 2]],

        [[0, 0, 0, 0],
         [0, 0, 0, 0]]])
'''
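# Equivalently (my addition): the best final tag per sample is the argmax over previous tags of
# (accumulated score at the last real step) + (transition into the END tag). This matches the
# `pointer` computed just below.
print(torch.argmax(last_partition.squeeze(-1) + transitions[:, END_TAG_IDX], dim=1))   # expect: tensor([2, 2])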
pointer = last_bp[:, END_TAG_IDX]
insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
print ('pointer',pointer)
print ('insert_last',insert_last)
back_points = back_points.transpose(1, 0).contiguous()
print ('1. back_points',back_points)
back_points.scatter_(1, last_position, insert_last)
print ('2. back_points',back_points)
back_points = back_points.transpose(1, 0).contiguous()
print ('3. back_points',back_points)
'''
pointer tensor([2, 2])
insert_last tensor([[[2, 2, 2, 2]],

        [[2, 2, 2, 2]]])
1. back_points tensor([[[0, 0, 0, 1],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[0, 0, 0, 1],
         [0, 0, 0, 2],
         [0, 0, 0, 0]]])
2. back_points tensor([[[0, 0, 0, 1],
         [2, 2, 2, 2],
         [0, 0, 0, 0]],

        [[0, 0, 0, 1],
         [0, 0, 0, 2],
         [2, 2, 2, 2]]])
3. back_points tensor([[[0, 0, 0, 1],
         [0, 0, 0, 1]],

        [[2, 2, 2, 2],
         [0, 0, 0, 2]],

        [[0, 0, 0, 0],
         [2, 2, 2, 2]]])
'''
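# Why the scatter_ matters (my addition): for the shorter sample 1 (length 2), the row at its
# last real step now holds the final pointer 2 in every column, so the backward pass that starts
# from the padded step still lands on the correct final tag for that sentence.
print(back_points[1, 0])   # expect: tensor([2, 2, 2, 2])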
decode_idx = Variable(torch.LongTensor(seq_len, batch_size))
decode_idx[-1] = pointer.data
print('decode_idx', decode_idx)
for idx in range(len(back_points)-2, -1, -1):
    print('\nidx:', idx)
    pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
    decode_idx[idx] = pointer.view(-1).data
    print('pointer:', pointer)
    print('decode_idx:', decode_idx)
path_score = None
decode_idx = decode_idx.transpose(1, 0)
print('\nreturn decode_idx', decode_idx)
'''
decode_idx tensor([[        0,         0],
        [973144064,         0],
        [        2,         2]])

idx: 1
pointer: tensor([[2],
        [0]])
decode_idx: tensor([[0, 0],
        [2, 0],
        [2, 2]])

idx: 0
pointer: tensor([[0],
        [0]])
decode_idx: tensor([[0, 0],
        [2, 0],
        [2, 2]])

return decode_idx tensor([[0, 2, 2],
        [0, 0, 2]])
'''
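# Brute-force cross-check (my addition, not in the original post): enumerate every length-3 tag
# sequence for sample 2, score it as START-transition + emissions + transitions + END-transition,
# and confirm the best path matches the Viterbi result [0, 0, 2] for that sentence.
from itertools import product

emits = torch.FloatTensor([[-0.0269, -0.1714, -0.0775, -0.0791],
                           [-0.0745, -0.2008, -0.1868,  0.2168],
                           [ 0.0703,  0.0196,  0.0457,  0.0400]])   # sample 2 emissions from feats
best_path, best_score = None, float('-inf')
for path in product(range(4), repeat=3):
    s = transitions[2, path[0]] + emits[0, path[0]]                  # START (index 2) -> first tag
    for t in range(1, 3):
        s = s + transitions[path[t - 1], path[t]] + emits[t, path[t]]
    s = s + transitions[path[-1], 3]                                 # last tag -> END (index 3)
    if s.item() > best_score:
        best_path, best_score = path, s.item()
print(best_path, round(best_score, 4))   # expect: (0, 0, 2) with score ≈ 37.9443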
