相关背景

深度学习的encoder都是基于大规模的未标注数据,但是这些encoder是否完整利用了语料的所有信息,这是未被证实的。类似于Bert的这些预训练模型使用的是文本的最小单位——字。但是中文的最小单位并不是字,中文的语义和N-gram有很大的关系。

目前模型的缺陷

  1. 基于word masking,encoder只能学习到已有的词和句的信息
  2. 基于mask的方法在pre-train和fine-tune阶段mismatch。因为预训练过程中遮盖存在但是fine-tune阶段遮盖不存在。
  3. 错误的分词或实体识别会影响到encoder的通用能力

因此论文提出ZEN-基于N-gram的中文encoder

ZEN有以下特点

  1. 1引入N-gram编码方式,方便模型识别出可能的字的组合
  2. 虽然引入了N-gram但是encoder的输出还是按照Bert那样逐字输出不会影响下有任务。

ZEN的预训练过程基于中文维基百科训练,微调是基于其他下游的中文任务。

下面我们来具体了解一下ZEN

ZEN

N-Gram

1. N-gram的提取

N-gram的提取分为两步,第一步是根据现有语料基于频率生成N-gram词表Lexicon, 请注意这些N-gram可能是包含关系,例如里面同时存在的粤港澳港澳。第二步是根据此表生成训练数据的N-gram matrix,如下图所示。

N-gram matrix是一个

的矩阵,其中
是句子中包含的字数,
是句子可以提取的N-gram的数量。
表示第i个词是否属于第j个N-gram

这里N-Gram矩阵的生成非常朴素,代码位置examples.utils_sequence_level_tasks中, 在函数convert_examples_to_features中。这个函数主要是将输入的batch rokenize 之后转化成word id,以及label进行处理,同时对N-Gram进行编码。其他过程我们这里不再多说,主要看一下N-Gram矩阵这部分的逻辑。

# ----------- code for ngram BEGIN-----------

需要注意的ngram_dict是提前生成的,每一句话我们先遍历每一种组合,生成所有可能的ngram,并记录他们的长度和起始位置。ngram_positions_matrix就是我们需要的N-Gram matrix,他是一个max_seq_length*max_ngram_in_seq的矩阵,其中max_seq_length是输入的词的长度,max_ngram_in_seq是一个句子中最多的N-Gram组合的数量,默认是128,然后遍历赋值。需要注意当一个word被mask掉他的N-gram也不再考虑。

2. N-gram编码

N-gram encoder的结构如下图所示,文章中采用多层transformer结构来对N-gram进行编码,因为N-gram的顺序不需要考虑所以position encoding。N-gram encoder对于模型效率的提升是有很大影响的,为什么嘞,因为N-gram encoder可以学习到一些句子中重要的词组,从而提升模型的效率。这里面输入的N-gram embedding可以理解为Word embedding,

代码里N-Gram Embedding的编码方式也和Word Embedding相差不多。如下分别是ZEN的Word Emebedding和N-Gram Emebedding的生成方式。

class BertEmbeddings(nn.Module):"""Construct the embeddings from word, position and token_type embeddings."""def __init__(self, config):super(BertEmbeddings, self).__init__()self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load# any TensorFlow checkpoint fileself.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)self.dropout = nn.Dropout(config.hidden_dropout_prob)def forward(self, input_ids, token_type_ids=None):seq_length = input_ids.size(1)position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)position_ids = position_ids.unsqueeze(0).expand_as(input_ids)if token_type_ids is None:token_type_ids = torch.zeros_like(input_ids)words_embeddings = self.word_embeddings(input_ids)position_embeddings = self.position_embeddings(position_ids)token_type_embeddings = self.token_type_embeddings(token_type_ids)embeddings = words_embeddings + position_embeddings + token_type_embeddingsembeddings = self.LayerNorm(embeddings)embeddings = self.dropout(embeddings)return embeddingsclass BertWordEmbeddings(nn.Module):"""Construct the embeddings from ngram, position and token_type embeddings."""def __init__(self, config):super(BertWordEmbeddings, self).__init__()self.word_embeddings = nn.Embedding(config.word_size, config.hidden_size, padding_idx=0)self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load# any TensorFlow checkpoint fileself.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)self.dropout = nn.Dropout(config.hidden_dropout_prob)def forward(self, input_ids, token_type_ids=None):if token_type_ids is None:token_type_ids = torch.zeros_like(input_ids)words_embeddings = self.word_embeddings(input_ids)token_type_embeddings = self.token_type_embeddings(token_type_ids)embeddings = words_embeddings + token_type_embeddingsembeddings = self.LayerNorm(embeddings)embeddings = self.dropout(embeddings)return embeddings

3. N-gram进行预训练

模型结构如下所示。

ZEN模型将对字和其有关的N-gram进行编码,这个该如何结合呢,就是将矩阵相加。

  • 是character_encoder第l层输出的第i个character的hidden output
  • 是第l层和第i个character有关的第k个N-gram。需要注意的是这里一个字可以被包含到多个N-gram中,例如 粤港澳大湾区和港澳

那么对于第l层encoder这种增强可以表示为

  • 是这一层的embedding matrix
  • 是character-N-gram相关矩阵
  • M是N-gram matrix

需要注意的是如果这个字被masked掉了,那么这个字的N-gram就不会被加进去了。

ZEN Encoder的代码如下,其中hidden_states加上了N-Gram经过attention的结果。

class ZenEncoder(nn.Module):def __init__(self, config, output_attentions=False, keep_multihead_output=False):super(ZenEncoder, self).__init__()self.output_attentions = output_attentionslayer = BertLayer(config, output_attentions=output_attentions,keep_multihead_output=keep_multihead_output)self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])self.word_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_word_layers)])self.num_hidden_word_layers = config.num_hidden_word_layersdef forward(self, hidden_states, ngram_hidden_states, ngram_position_matrix, attention_mask,ngram_attention_mask,output_all_encoded_layers=True, head_mask=None):# Need to check what is the attention masking doing hereall_encoder_layers = []all_attentions = []num_hidden_ngram_layers = self.num_hidden_word_layersfor i, layer_module in enumerate(self.layer):hidden_states = layer_module(hidden_states, attention_mask, head_mask[i])if i < num_hidden_ngram_layers:ngram_hidden_states = self.word_layers[i](ngram_hidden_states, ngram_attention_mask, head_mask[i])if self.output_attentions:ngram_attentions, ngram_hidden_states = ngram_hidden_statesif self.output_attentions:attentions, hidden_states = hidden_statesall_attentions.append(attentions)hidden_states += torch.bmm(ngram_position_matrix.float(), ngram_hidden_states.float())if output_all_encoded_layers:all_encoder_layers.append(hidden_states)if not output_all_encoded_layers:all_encoder_layers.append(hidden_states)if self.output_attentions:return all_attentions, all_encoder_layersreturn all_encoder_layers

实验结果

1. 实验设置

论文使用了中文wiki作为语料,并去除了标点符号,进行了简体转化,对英文字母统一转为小写的数据清洗。

N-gram词典是根据训练语料,对N-gram按照词频排序并设置阈值,频率低于阈值的N-gram将会被剔除。最终的N-gram包含17.9万~6.4万之间。N-gram embedding是随机初始化的,模型结构和Bert结构相同,采用12层12个muti-head attention结构,hidden size大小为768。预训练也和Bert相同采用MLM和NSP任务。

2. 实验效果

模型的实验效果如下图所示,其实R表示模型参数随机加载,P表示模型参数根据谷歌的Bert模型初始化,B表示用的是Bert Base,L表示Bert Large。可以看出ZEN在多个模型上取得了当前比较好的效果。

相关分析

文中还进行了一些分析。

1. 小规模语料上进行预训练

当前的预训练模型大都是在大型数据集上进行实验,对于部分领域大规模数据集很难收集,于是本文抽取了1/10大小的维基语料进行预训练,模型参数采取随机初始化。可以看出ZEN在小规模数据集上的效果要稍稍优于Bert。应该是因为N-gram对embedding进行了增强,这表示ZEN在小规模数据集的场景要优于Bert。

2. 收敛速度

下图展示了ZEN在CWS(Chinese word segmentation)和SA(Sentiment analysis)任务上的不同训练epoch的表现。可以看出相同的epochZEN的效果比Bert的更好,同事ZEN比Bert收敛更快。

3. N-gram Threshold

文中对我们提取N-gram频率的阈值进行了分析,发现阈值在10~20时候效果最好。同时论文对使用最多的N-gram的数量也进行了分析,发现随着N-gram数量的增多模型效果有了部分提升。

4. 热力图分析

论文对encoder的N-gram也进行了热力图分析,如下图所示,是两句话在1~7层中每个N-gram的weight。可以看出,“有意义”的N-gram所占的权重比“无意义”的N-gram权重要高,例如“提高”和“波士顿”比“会提高”和“士顿”的权重要高。这表ZEN会在N-gram中注重语义,选择比较合适的词组。同时我们发现较长的词组在比较高的层中获得权重比较大,这也表示这些比较长的词组对模型理解语句有比较重要的影响。

相关资料

ZEN: Pre-training Chinese Text Encoder Enhanced by N-gram Representations​arxiv.orgZEN-torch​github.com

gram矩阵_ZEN-基于N-gram的中文Encoder相关推荐

  1. Gram矩阵+Gram矩阵和协方差矩阵的关系

    目录 Gram矩阵简介 协方差矩阵 Gram矩阵 和 协方差矩阵的关系 Gram Matrix代码 Gram矩阵简介 gram矩阵是计算每个通道 i 的feature map与每个通道 j 的feat ...

  2. 风格迁移-风格损失函数(Gram矩阵)理解

    吴恩达教授卷积神经网络的课程中,有一部分是关于风格迁移的内容.组合内容图片,风格图片生成新的图片 主体思路是: 随机生成一张图片(可以基于原内容图片生成,从而加速训练) 计算其与内容图片之间的内容损失 ...

  3. 如何对batch的数据求Gram矩阵

    Gram矩阵概念和理解 在风格迁移中,我们要比较生成图片和风格图片的相似性,评判标准就是通过计算Gram矩阵得到的.关于Gram矩阵的定义,可以参考[1]. 由这个矩阵的样子,很容易就想到协方差矩阵. ...

  4. Gram 矩阵性质及应用

    v1,v2,-,vnv_1,v_2,\ldots,v_n 是内积空间的一组向量,Gram 矩阵定义为: Gij=⟨vi,vj⟩G_{ij}=\langle v_i,v_j\rangle,显然其是对称矩 ...

  5. (高能预警!)为什么Gram矩阵可以代表图像风格?带你揭开图像风格迁移的神秘面纱!

    文章目录 (高能预警)为什么Gram矩阵可以代表图像风格 简介 风格迁移概述 领域适应 相关知识 Gram矩阵 特征值分解 核函数 希尔伯特空间 可再生核希尔伯特空间 最大平均差异(MMD) 图像风格 ...

  6. Gram矩阵和核函数

    Gram矩阵定义 内积空间中的一组向量 v 1 , v 2 , ⋯   , v n \bm v_1,\bm v_2,\cdots,\bm v_n v1​,v2​,⋯,vn​的Gram矩阵是内积的Her ...

  7. Gram矩阵及其实际含义

    1.Gram矩阵的定义 2.意义 格拉姆矩阵可以看做feature之间的偏心协方差矩阵(即没有减去均值的协方差矩阵),在feature map中,每个数字都来自于一个特定滤波器在特定位置的卷积,因此每 ...

  8. Gram矩阵与卷积网络中的卷积的直观理解

    Gram矩阵其实是一种度量矩阵.矩阵分析中有这样的定义. 设 V V是nn维欧式空间 ϵ1,⋯,ϵn \mathbf{\epsilon_1, \cdots, \epsilon_n }是它的一个基, g ...

  9. gram矩阵的性质_矩阵分析(九)Gram矩阵

    欧氏空间 $V$是$\mathbb{R}$上的线性空间,定义映射 $$ \sigma: V\times V \to \mathbb{R} $$ 对于$\alpha, \beta \in V$,将$\s ...

最新文章

  1. Idea用maven给springboot打jar包
  2. ACMNO.17C语言-筛法求素数 用筛法求之N内的素数。
  3. IntelliJ IDEA 2018.1新特性
  4. linux 云主机 卡顿 排查过程
  5. 《Linux内核设计与实现》内存管理札记
  6. action与servlet用法区别
  7. QT的QSemaphoreReleaser类的使用
  8. Python 奇技淫巧
  9. Python数据结构与算法--数据类型
  10. query row php,php – 如何在Codeigniter上使用$query- row获取类对象
  11. BZOJ4358: permu(带撤销并查集 不删除莫队)
  12. 服务器系统分区 是啥,服务器系统盘分区
  13. 句句真研—每日长难句打卡Day7
  14. poj 1679(次小生成树)
  15. 怎么查看台式计算机网络密码,台式电脑怎么查看wifi密码_台式机如何看wifi密码?-192路由网...
  16. Codeforces 833D Red-Black Cobweb [点分治]
  17. 四棱锥和三棱锥重叠求面数
  18. XP系统封装-2011年
  19. C语言Printf格式大全(各种%输出形式)
  20. 分享125个ASP源码,总有一款适合你

热门文章

  1. linux shell合并文件命令paste
  2. Ubuntu中防火墙设置
  3. 人群计数--Mixture of Counting CNNs
  4. 去水印--《On the Effectiveness of Visible Watermarks》
  5. 行人检测--What Can Help Pedestrian Detection?
  6. 批量残差网络-Aggregated Residual Transformations for Deep Neural Networks
  7. jquery 实现仿QQ右下角弹出框
  8. 解决报错: MobaXterm X11 proxy: Unsupported authorisation protocol
  9. 电路非门_【连载】电路和维修基础之门电路、转换器
  10. C++显示转换、dynamic_cast重点