文章目录

  • 一、为什么使用Pointer Network?
  • 二、Pointer Network的结构
    • 1.Pointer Network网络如何解决OOV问题
  • 三、如何通过结合Pointer Network处理语言生成?
    • 1.语言生成中的生成重复词的问题的解决办法
  • 四、PGN网络代码实现

一、为什么使用Pointer Network?

  传统的seq2seq模型是无法解决输出序列的词汇表会随着输入序列长度的改变而改变的问题的(解空间固定)。Pointer Network可以通过给输入序列的元素予一个指针,从而使得解空间不固定,可以解决OOV问题。总结来说,传统的seq2seq模型会要求有一个固定空间的大小,如果我们从不同两者之间做维度的切换(解空间发生变化时),就会出现OOV问题。

ConvexHullConvex \ HullConvexHull
如寻找凸包等。因为对于这类问题,输出往往是输入集合的子集。基于这种特点,作者考虑能不能找到一种结构类似编程语言中的指针,每个指针对应输入序列的一个元素, 从而我们可以直接操作输入序列而不需要特意设定输出词汇表 。
  在NLG场景,当我们面对不同的输入维度时,我们利用固定词表或者原来的解空间不足够解决问题,我们可以利用Pointer Network来指出是原来的哪些输入做的映射。

二、Pointer Network的结构

  传统的Attention结构为:

其中eje_jej是encoder的隐状态,而did_idi是decoder的隐状态,v,W1,W2v,W_1,W_2v,W1,W2都是可学习的参数

由传统Attention公式可以得到,以Decoder层的第一个隐状态(标的物)为例,对于Encoder层的隐状态都有一个权重aj1a_j^1aj1,指针指向权重最大的点即会把权重最大的点作为当前的输出,可以将这个输出作为Decoder中下一个神经元的输入,这就是为Pointer Network。

  在Decoder层(LSTM/RNN)中,会将当前的输出作为下一层的输入,能够保证模型学到序列的特征,同时,又能够保证输出是来自于全集的。Pointer Network网络解决:无论维度有多大,我都能够解决你的问题,因为根本没有建立词表,而利用attention的思想在你的网络里来指一个输入(基于attention在输入中选择一个概率最大的词),假设输入是10000个,就可以指这10000个中之一的词。

1.Pointer Network网络如何解决OOV问题

  假设传统的seq2seq模型,在之前的场景中,词典大小为50维,如果来到一个新的场景,输入的词典大小为10000维,那么在softmax中剩余的9500个词都是OOV。
  假设使用的是Pointer Network网络,输入时10000维,每次只需要在输入中找就可以,不再建立词典,也不需要做softmax从而映射到OOV。
  总结来说,传统seq2seq模型是从词表里面挑,需要做softmax;而Pointer Network网络是从输入中挑,不需要做softmax。

三、如何通过结合Pointer Network处理语言生成?

   上面已经介绍了Pointer Network ,那么如何通过结合Pointer Network处理语言生成呢?Language Model是自由,灵活,不可控;Pointer Net 是相对可控,信息来源于输入的信息范围;Pointer Net是天生的复制粘贴利器,这是一种抽取的方式。抽取式比较死板,所以,我们可以利用抽取与language Model结合到一起得到更好的生成的网络。
  下面通过《Get To The Point: Summarization with Pointer-Generator Networks》这篇文章来学习如何通过结合Pointer Network处理语言生成的。
  传统的seq2seq+Attention网络结构如下:

Baselineseq2seq+AttentionBaseline \ seq2seq+AttentionBaselineseq2seq+Attention
Decoder层的输出为词表大小的维度,输出为经softmax后概率最大的词。
  下面综合Pointer Network与language generation Model

  • language generation Model
    生成时会有Context Vector,然后将Context Vector投影到Vocabulary distribution上面去
  • Pointer Network
    在上面的Pointer Network中,我们选取的是attention weight的最大值对应的输入作为输出;在这篇论文中,我们选取的是attention distribution。


综合Pointer Network与language generation Model的关键是将Vocabulary distribution与attention distribution进行加权平均,权重分别为1−pgen与pgen1-p_{gen}与p_{gen}1pgenpgen

在每一个time step中,权重是不一样的。

这种情况下,无论输入中有多少OOV,都可以加到最终的Final Distribution中。例如:“2:0”。
注意事项:由于输入在字典中映射时,2:0仍然会被映射为OOV的indice,输入为OOV的词向量,但是会做记录,记录这是"2:0"的OOV。在Final Distribution中,“2:0"不需要indice,因为是直接从输入中复制的,可以直接知道是"2:0”。而不需要先知道index,再映射得到word。

  下面通过《Multi-Source Pointer Network for Product Title Summarization》这篇文章来学习如何通过结合Pointer Network处理语言生成的。现在,电商中经常会出现如下场景:利用商品名、商品描述生成商品title(一种summarization)。

计算步骤如下:

  • LSTM encode the knowledge and title
  • initialize
  • Attention
  • decoder output

  • 最终的分布:

    其中,λ\lambdaλ为可学习参数
  • 损失函数:

注意事项:两个Attention Distribution都不是固定词典大小的分布,相当于从两个词典中抽取组成新的组合。
  下面学习CopyNet网络。在某些场景下,是不能对Name、Place、Organization等这类情况进行翻译或回复的,

如上图,由于不能对某句话进行理解,机器会回复what do you mean by 某句话?针对这样的场景,就会用到CopyNet思想。下面有这样一个图,利用基于Pointer Network的CopyNet与language generation思想,


来自于language generation Model与CopyNet的结合。
  下面对XXXVVV的进行分析

图中XXX部分是输入序列的词汇集合,而VVV部分代表输出词汇集合,当然, 通常情况下两个集合都存在交集,而在X和V并集之外的部分就是未知词 汇:UNKUNKUNK

  • 当某个词是输入序列独有的,则该词的生成概率为0,复制概率不变;
  • 若某个词是输出词汇表独有的,则该词的复制概率为0,而生成概率不变;
  • 若某个词既存在于输入序列又存在于输出词汇表,则生成概率和复制概率都不变。 最后,将生成概率和复制概率加和得到最终的概率。

总结:
  这三篇文章的中心思想都是不在利用词典,而是利用指针的方式来指出一个值,然后加在一起得到Final Distribution。

1.语言生成中的生成重复词的问题的解决办法

  • 维持一个词典,词典里面装出现过的词的频率,对于出现过的词给予惩罚,这是针对decoding的多样性的优化
  • 训练过程中,将重复问题纳入loss function,加入重复的惩罚,比如:训练模型N次,基于现在的模型做预测时,预测结果是计算重复概率,然后利用它来更新模型。

四、PGN网络代码实现

  从上面,我们理解到:

  • 需要Pointer Network网络的根本原因在于我们面临着不同的解空间时,解空间的变化差异导致我们在Decoder中经softmax出现大量OOV,并且在特定场景下,对于Encoder中的OOV或特殊字符,我们认为是非常重要的。所以,我们需要利用Pointer Network中指针的思想,结合Attention Distribution与Vocabulary Distribution得到Final Distribution,相当于可以从输入中挑选词,就解决了OOV的问题。
  • 这部分网络包含几部分:一部分是attention model(Pointer Network),另一部分是语言模型,同时,需要权重PgenP_{gen}Pgen来决定Pointer Network与语言模型的weight,最后,将Coverage mechanism加入到模型中解决词被得到更多关注的问题。

下面代码是基于论文《Get To The Point: Summarization with Pointer-Generator Networks》的实现:GitHub上面已经开源,PGN开源实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from data_util import config
from numpy import randomclass Encoder(nn.Module):"""Encoder部分利用了embedding,lstm以及线性变化"""def __init__(self):super(Encoder, self).__init__()# embeddingself.embedding = nn.Embedding(config.vocab_size, config.emb_dim)init_wt_normal(self.embedding.weight)# lstmself.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1, batch_first=True, bidirectional=True)init_lstm_wt(self.lstm)# 线性self.W_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2, bias=False)#seq_lens should be in descending orderdef forward(self, input, seq_lens):embedded = self.embedding(input)# 使用rnn、lstm时,一般需要使用pack_padded_sequencepacked = pack_padded_sequence(embedded, seq_lens, batch_first=True)output, hidden = self.lstm(packed)# 使用完后需要解包encoder_outputs, _ = pad_packed_sequence(output, batch_first=True)  # h dim = B x t_k x nencoder_outputs = encoder_outputs.contiguous()encoder_feature = encoder_outputs.view(-1, 2*config.hidden_dim)  # B * t_k x 2*hidden_dimencoder_feature = self.W_h(encoder_feature)return encoder_outputs, encoder_feature, hiddenclass ReduceState(nn.Module):"""线性变换与非线性变化部分"""def __init__(self):super(ReduceState, self).__init__()self.reduce_h = nn.Linear(config.hidden_dim * 2, config.hidden_dim)init_linear_wt(self.reduce_h)self.reduce_c = nn.Linear(config.hidden_dim * 2, config.hidden_dim)init_linear_wt(self.reduce_c)def forward(self, hidden):h, c = hidden # h, c dim = 2 x b x hidden_dimh_in = h.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)hidden_reduced_h = F.relu(self.reduce_h(h_in))c_in = c.transpose(0, 1).contiguous().view(-1, config.hidden_dim * 2)hidden_reduced_c = F.relu(self.reduce_c(c_in))return (hidden_reduced_h.unsqueeze(0), hidden_reduced_c.unsqueeze(0)) # h, c dim = 1 x b x hidden_dimclass Attention(nn.Module):"""attention"""def __init__(self):super(Attention, self).__init__()# 涉及到coverageif config.is_coverage:self.W_c = nn.Linear(1, config.hidden_dim * 2, bias=False)self.decode_proj = nn.Linear(config.hidden_dim * 2, config.hidden_dim * 2)self.v = nn.Linear(config.hidden_dim * 2, 1, bias=False)def forward(self, s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage):b, t_k, n = list(encoder_outputs.size())dec_fea = self.decode_proj(s_t_hat) # B x 2*hidden_dimdec_fea_expanded = dec_fea.unsqueeze(1).expand(b, t_k, n).contiguous() # B x t_k x 2*hidden_dimdec_fea_expanded = dec_fea_expanded.view(-1, n)  # B * t_k x 2*hidden_dimatt_features = encoder_feature + dec_fea_expanded # B * t_k x 2*hidden_dim# 如果用到coverage,会作为attention的额外输入部分if config.is_coverage:coverage_input = coverage.view(-1, 1)  # B * t_k x 1coverage_feature = self.W_c(coverage_input)  # B * t_k x 2*hidden_dimatt_features = att_features + coverage_featuree = F.tanh(att_features) # B * t_k x 2*hidden_dimscores = self.v(e)  # B * t_k x 1scores = scores.view(-1, t_k)  # B x t_kattn_dist_ = F.softmax(scores, dim=1)*enc_padding_mask # B x t_knormalization_factor = attn_dist_.sum(1, keepdim=True)attn_dist = attn_dist_ / normalization_factorattn_dist = attn_dist.unsqueeze(1)  # B x 1 x t_kc_t = torch.bmm(attn_dist, encoder_outputs)  # B x 1 x nc_t = c_t.view(-1, config.hidden_dim * 2)  # B x 2*hidden_dimattn_dist = attn_dist.view(-1, t_k)  # B x t_kif config.is_coverage:coverage = coverage.view(-1, t_k)coverage = coverage + attn_distreturn c_t, attn_dist, coverageclass Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.attention_network = Attention()# decoderself.embedding = nn.Embedding(config.vocab_size, config.emb_dim)init_wt_normal(self.embedding.weight)self.x_context = nn.Linear(config.hidden_dim * 2 + config.emb_dim, config.emb_dim)self.lstm = nn.LSTM(config.emb_dim, config.hidden_dim, num_layers=1, batch_first=True, bidirectional=False)init_lstm_wt(self.lstm)if config.pointer_gen:self.p_gen_linear = nn.Linear(config.hidden_dim * 4 + config.emb_dim, 1)#p_vocabself.out1 = nn.Linear(config.hidden_dim * 3, config.hidden_dim)self.out2 = nn.Linear(config.hidden_dim, config.vocab_size)init_linear_wt(self.out2)def forward(self, y_t_1, s_t_1, encoder_outputs, encoder_feature, enc_padding_mask,c_t_1, extra_zeros, enc_batch_extend_vocab, coverage, step):if not self.training and step == 0:h_decoder, c_decoder = s_t_1s_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dimc_t, _, coverage_next = self.attention_network(s_t_hat, encoder_outputs, encoder_feature,enc_padding_mask, coverage)coverage = coverage_nexty_t_1_embd = self.embedding(y_t_1)x = self.x_context(torch.cat((c_t_1, y_t_1_embd), 1))lstm_out, s_t = self.lstm(x.unsqueeze(1), s_t_1)h_decoder, c_decoder = s_ts_t_hat = torch.cat((h_decoder.view(-1, config.hidden_dim),c_decoder.view(-1, config.hidden_dim)), 1)  # B x 2*hidden_dimc_t, attn_dist, coverage_next = self.attention_network(s_t_hat, encoder_outputs, encoder_feature,enc_padding_mask, coverage)if self.training or step > 0:coverage = coverage_nextp_gen = Noneif config.pointer_gen:p_gen_input = torch.cat((c_t, s_t_hat, x), 1)  # B x (2*2*hidden_dim + emb_dim)p_gen = self.p_gen_linear(p_gen_input)p_gen = F.sigmoid(p_gen)output = torch.cat((lstm_out.view(-1, config.hidden_dim), c_t), 1) # B x hidden_dim * 3output = self.out1(output) # B x hidden_dim#output = F.relu(output)output = self.out2(output) # B x vocab_sizevocab_dist = F.softmax(output, dim=1)if config.pointer_gen:vocab_dist_ = p_gen * vocab_distattn_dist_ = (1 - p_gen) * attn_distif extra_zeros is not None:vocab_dist_ = torch.cat([vocab_dist_, extra_zeros], 1)final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)else:final_dist = vocab_distreturn final_dist, s_t, c_t, attn_dist, p_gen, coverageclass Model(object):def __init__(self, model_file_path=None, is_eval=False):encoder = Encoder()decoder = Decoder()reduce_state = ReduceState()# shared the embedding between encoder and decoderdecoder.embedding.weight = encoder.embedding.weightif is_eval:encoder = encoder.eval()decoder = decoder.eval()reduce_state = reduce_state.eval()if use_cuda:encoder = encoder.cuda()decoder = decoder.cuda()reduce_state = reduce_state.cuda()self.encoder = encoderself.decoder = decoderself.reduce_state = reduce_stateif model_file_path is not None:state = torch.load(model_file_path, map_location= lambda storage, location: storage)self.encoder.load_state_dict(state['encoder_state_dict'])self.decoder.load_state_dict(state['decoder_state_dict'], strict=False)self.reduce_state.load_state_dict(state['reduce_state_dict'])

这里面涉及到OOV的处理如下:

def article2ids(article_words, vocab):"""输入的词(source)构建与id的映射,包括OOV词"""ids = []oovs = []unk_id = vocab.word2id(UNKNOWN_TOKEN)for w in article_words:i = vocab.word2id(w)if i == unk_id: # If w is OOVif w not in oovs: # Add to list of OOVsoovs.append(w)oov_num = oovs.index(w) # This is 0 for the first article OOV, 1 for the second article OOV...ids.append(vocab.size() + oov_num) # This is e.g. 50000 for the first article OOV, 50001 for the second...else:ids.append(i)return ids, oovsdef abstract2ids(abstract_words, vocab, article_oovs):"""生成的词构建与id的映射,包括OOV词"""ids = []unk_id = vocab.word2id(UNKNOWN_TOKEN)for w in abstract_words:i = vocab.word2id(w)if i == unk_id: # If w is an OOV wordif w in article_oovs: # If w is an in-article OOVvocab_idx = vocab.size() + article_oovs.index(w) # Map to its temporary article OOV numberids.append(vocab_idx)else: # If w is an out-of-article OOVids.append(unk_id) # Map to the UNK token idelse:ids.append(i)return ids

比如:原先词表的维度为10000维,“2:0”属于原先词表的UNKNOWN_TOKEN即OOV词,需要将“2:0”添加到oovs表中,此时,“2:0”的index为词表的维度加上“2:0”在oovs表的索引。这里的“2:0”是临时存储,而不是一直储存。意思就是每一个句子有一个专门的oovs表,不同的句子构成的oovs表可能词相同但索引不同。


如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!


NLP学习—14.Pointer Generator Network(指针)及代码实现相关推荐

  1. javaScript基础学习 - 14 - JavaScript内置对象 -案例代码

    javaScript基础学习 - 14 - JavaScript内置对象 -案例代码 1. Math对象最大值 2. 封装自己的数学对象 3. Math绝对值和三个取整方法 4. Math对象随机数方 ...

  2. 文本生成任务之营销文本生成(Seq2seq+attention、Pointer Generator Network、Converage、Beam Search、优化技巧、文本增强)

    文章目录 引言 项目任务简介 0. 数据预处理 0.1 将json文件转化成txt文件 0.2 词典处理 0.3 自定义数据集SampleDataset(Dataset类) 0.4 生成Dataloa ...

  3. Java学习-14 CSS与CSS3美化页面及网页布局

    Java学习-14 CSS与CSS3美化页面及网页布局 1. CSS简介 什么是CSS? CSS 指层叠样式表 (Cascading Style Sheets) 样式定义如何显示控制 HTML 元素, ...

  4. C++“准”标准库Boost学习指南(1):智能指针Boost.smart_ptr

    我们学习C++都知道智能指针,例如STL中的std::auto_ptr,但是为什么要使用智能指针,使用它能带给我们什么好处呢? 最简单的使用智能指针可以不会因为忘记delete指针而造成内存泄露.还有 ...

  5. NLP学习实践天池新人赛打卡第一天

    NLP学习实践天池新人赛打卡第一天 Task1 赛题理解 学习目标 赛题数据 数据标签 评测指标 数据读取 解题思路 Task1 赛题理解 赛题名称:零基础入门NLP之新闻文本分类 赛题目标:通过这道 ...

  6. NLP学习-Task 3: 子词模型Subword Models

    NLP学习 更新流程↓ Task 1: 简介和词向量Word Vectors Task 2: 词向量和词义Word Senses Task 3: 子词模型Subword Models Task 4: ...

  7. 天池NLP学习赛(1)赛题理解

    天池NLP学习赛(1)赛题理解 题目 题目类型:新闻文本分类(字符识别问题)链接 数据: 赛题数据为新闻文本,并按照字符级别进行匿名处理,数字编码形式呈现.整合划分出14个候选分类类别:财经.彩票.房 ...

  8. 深度学习-14:知名的深度学习开源架构和项目

    深度学习-14:知名的深度学习开源架构和项目 深度学习原理与实践(开源图书)-总目录 人工智能artificial intelligence,AI是科技研究中最热门的方向之一.像IBM.谷歌.微软.F ...

  9. 利用计算机技术实现对文本篇章,自然语言处理NLP学习笔记一:概念与模型初探...

    前言 先来看一些demo,来一些直观的了解. 自然语言处理: 可以做中文分词,词性分析,文本摘要等,为后面的知识图谱做准备. 知识图谱: 还有2个实际应用的例子,加深对NLP的理解 九歌机器人: 微软 ...

  10. 曝!BAT大厂NLP学习进阶之法~

    "语言理解是人工智能领域皇冠上的明珠." --比尔盖茨 自然语言处理是一门综合性的学问,它远远不止机器学习算法.相比图像或语音,文本的变化更加复杂,例如从预处理来看,NLP 就要求 ...

最新文章

  1. Java中测长函数_Core Java测试题
  2. 知识图谱(二)——知识表示
  3. linux下安装 ping 命令
  4. ASP.net 连接interbase数据库
  5. 停止MySQL正在执行的SQL语句
  6. 撤销工作表保护原密码_批量解除工作表保护,和批量执行保护一样简单
  7. 广数系统加工中心编程_编程十五年,谈谈对加工中心编程的一些看法...
  8. crm客户管理软件的精髓
  9. l440加装固态硬盘ngff_[转载]Thinkpad E431装NGFF固态硬盘图文详解
  10. 审批流程展示html,审批流程图怎么绘制?不懂可以看这里
  11. 弘扬时代新风建设网络文明,小趣带你揭秘肾透明细胞癌致瘤机制
  12. 服务器金属外壳刮花了怎么修复,pc拉杆箱被磨了怎么办?3方法快速修复(附防刮方式)...
  13. Xftp的介绍及下载安装教程
  14. 常用的RAID模式及特点
  15. 2013年中国城市及省份GDP排名
  16. OXFeeeFeee指针的含义
  17. 三菱FX MOV k2m0 k2y00 指令
  18. linux实时realtime,康佳特与OSADL携手优化 Real-Time Linux 的支持 顺利实现硬实时
  19. Windows Server 2003上搭建FTP服务器(IIS同理)
  20. 华为od统一考试B卷【阿里巴巴找黄金宝箱】Python 实现

热门文章

  1. C# Windows服务自动安装与注册
  2. STM32:配置定时器为PWM输出模式以及编码器接口模式
  3. 20190901 On Java8 第十五章 异常
  4. 解决用root用户及密码可以直接登陆某LINUX系统,但是用ssh登陆,系统却总是提示密码不对...
  5. 又是一个秋天~~~~
  6. 深度学习模型的可视化技术总结
  7. Linux_Shell符号及各种解释对照表
  8. 查询记录rs.previous()使用
  9. table 谷歌下不出现滚动条
  10. StringJoiner 拯救那些性能低下的字符串拼装代码(转)