Bert是去年google发布的新模型,打破了11项纪录,关于模型基础部分就不在这篇文章里多说了。这次想和大家一起读的是huggingface的pytorch-pretrained-BERT代码examples里的文本分类任务run_classifier。

关于源代码可以在huggingface的github中找到。

huggingface/pytorch-pretrained-BERT​github.com

在前三篇文章中我分别介绍了数据预处理部分和部分的模型:

周剑:一起读Bert文本分类代码 (pytorch篇 一)​zhuanlan.zhihu.com

周剑:一起读Bert文本分类代码 (pytorch篇 二)​zhuanlan.zhihu.com

周剑:一起读Bert文本分类代码 (pytorch篇 三)​zhuanlan.zhihu.com

我们可以看到BertForSequenceClassification类中调用关系如下图所示。本篇文章中,我会和大家一起读BertEncoder类中调用的BertLayer,BertAttention,BertSelfAttention和BertSelfOutput这几个类的代码。

打开pytorch_pretrained_bert.modeling.py,找到BertLayer类,代码如下:

class BertLayer(nn.Module):def __init__(self, config):super(BertLayer, self).__init__()self.attention = BertAttention(config)self.intermediate = BertIntermediate(config)self.output = BertOutput(config)def forward(self, hidden_states, attention_mask):attention_output = self.attention(hidden_states, attention_mask)intermediate_output = self.intermediate(attention_output)layer_output = self.output(intermediate_output, attention_output)return layer_output

从forward开始看,依次进入BertAttention,BertIntermediate和BertOutput这三个类。

我们先找到BertAttention这个类,代码如下:

class BertAttention(nn.Module):def __init__(self, config):super(BertAttention, self).__init__()self.self = BertSelfAttention(config)self.output = BertSelfOutput(config)def forward(self, input_tensor, attention_mask):self_output = self.self(input_tensor, attention_mask)attention_output = self.output(self_output, input_tensor)return attention_output

可以看到BertAttention类是由BertSelfAttention和BertSelfOutput组成的。

我们再找到BertSelfAttention这个类,代码如下:

class BertSelfAttention(nn.Module):def __init__(self, config):super(BertSelfAttention, self).__init__()if config.hidden_size % config.num_attention_heads != 0:raise ValueError("The hidden size (%d) is not a multiple of the number of attention ""heads (%d)" % (config.hidden_size, config.num_attention_heads))self.num_attention_heads = config.num_attention_headsself.attention_head_size = int(config.hidden_size / config.num_attention_heads)self.all_head_size = self.num_attention_heads * self.attention_head_sizeself.query = nn.Linear(config.hidden_size, self.all_head_size)self.key = nn.Linear(config.hidden_size, self.all_head_size)self.value = nn.Linear(config.hidden_size, self.all_head_size)self.dropout = nn.Dropout(config.attention_probs_dropout_prob)def transpose_for_scores(self, x):new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)x = x.view(*new_x_shape)return x.permute(0, 2, 1, 3)def forward(self, hidden_states, attention_mask):mixed_query_layer = self.query(hidden_states)mixed_key_layer = self.key(hidden_states)mixed_value_layer = self.value(hidden_states)query_layer = self.transpose_for_scores(mixed_query_layer)key_layer = self.transpose_for_scores(mixed_key_layer)value_layer = self.transpose_for_scores(mixed_value_layer)# Take the dot product between "query" and "key" to get the raw attention scores.attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))attention_scores = attention_scores / math.sqrt(self.attention_head_size)# Apply the attention mask is (precomputed for all layers in BertModel forward() function)attention_scores = attention_scores + attention_mask# Normalize the attention scores to probabilities.attention_probs = nn.Softmax(dim=-1)(attention_scores)# This is actually dropping out entire tokens to attend to, which might# seem a bit unusual, but is taken from the original Transformer paper.attention_probs = self.dropout(attention_probs)context_layer = torch.matmul(attention_probs, value_layer)context_layer = context_layer.permute(0, 2, 1, 3).contiguous()new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)context_layer = context_layer.view(*new_context_layer_shape)return context_layer

可以看到,BertSelfAttention这个类终于有点真东西了。

从forward开始看。首先是query_layer,key_layer和value_layer分别是三个线形Linear层,对应进入Multi-Head Attention。下图是Transformer的encoder模型,来源于(Attention Is All You Need)这篇论文。

从图中可以看到query_layer,key_layer和value_layer三层进入Multi-Head Attention。而Multi-Head Attention内部如下图:

Multi-Head Attention内部的Scaled Dot-Product Attention结构如下图。

因此,我们可以看到BertSelfAttention类中如下代码是计算Scaled Dot-Product Attention的。

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))attention_scores = attention_scores / math.sqrt(self.attention_head_size)# Apply the attention mask is (precomputed for all layers in BertModel forward() function)attention_scores = attention_scores + attention_mask# Normalize the attention scores to probabilities.attention_probs = nn.Softmax(dim=-1)(attention_scores)# This is actually dropping out entire tokens to attend to, which might# seem a bit unusual, but is taken from the original Transformer paper.attention_probs = self.dropout(attention_probs)context_layer = torch.matmul(attention_probs, value_layer)

再接着BertSelfAttention的forward继续看。剩下下的主要是contact和tensor的shape调整。解释一下其中的一些tensor的函数。

tensor.permute()是shape位置交换函数,例如一个tensor的shape是tensor[(3, 5, 6)], tensor.permute(0, 2, 1)后,shape变为tensor[(3, 6, 5)].

contiguous:view只能用在contiguous的variable上。如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy。

在pytorch 0.4.0版本新添加了reshape函数,类似于numpy.reshape()。它大致相当于 tensor.contiguous().view().

关于tensor.view()的解释官方文档如下:

这样我们就读完了BertSelfAttention这个类,我们接下来看BertSelfOutput这个类,它的代码如下:

class BertSelfOutput(nn.Module):def __init__(self, config):super(BertSelfOutput, self).__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)self.dropout = nn.Dropout(config.hidden_dropout_prob)def forward(self, hidden_states, input_tensor):hidden_states = self.dense(hidden_states)hidden_states = self.dropout(hidden_states)hidden_states = self.LayerNorm(hidden_states + input_tensor)return hidden_states

发现就是一个线形Linear层+dropout+一个LayerNorm。BertSelfAttention和BertSelfOutput,这也就是BertAttention这个类的全部。

下一篇文章中我会带着大家继续读BertLayer类中的BertIntermediate和BertOutput类。

周剑:一起读Bert文本分类代码 (pytorch篇 五)​zhuanlan.zhihu.com

周剑:一起读Bert文本分类代码 (pytorch篇 六)​zhuanlan.zhihu.com

pytorch bert文本分类_一起读Bert文本分类代码 (pytorch篇 四)相关推荐

  1. pytorch实现文本分类_使用变形金刚进行文本分类(Pytorch实现)

    pytorch实现文本分类 'Attention Is All You Need' "注意力就是你所需要的" New deep learning models are introd ...

  2. 易语言读文本内容_易读性如何使文本易于阅读

    易语言读文本内容 Your first step in making your texts legible is to understand what legibility means. It is ...

  3. pytorch自带网络_【方家之言】一篇长文学懂 pytorch

    作为目前越来越受欢迎的深度学习框架,pytorch 基本上成了新人进入深度学习领域最常用的框架.相比于 TensorFlow,pytorch 更易学,更快上手,也可以更容易的实现自己想要的 demo. ...

  4. python的控件text的文本属性_只需6行Python代码就给图片加上水印——你一看就会了...

    大家在做项目开发的过程中,会不会经常遇到需要处理图片却没有快速有效的工具的情况呢?比如客户需要给图片加上水印,你可能会用到PS这些高级软件去处理,这样虽然有效果但是需要相对较长的时间:作为程序猿,你一 ...

  5. java如何创建一个文本框_创建一个有文本框和三个按钮的程序。当按下某个按钮时,使不同的文字(Java..._考试资料网...

    问答题创建一个有文本框和三个按钮的程序.当按下某个按钮时,使不同的文字("Java","编程","不难学")显示在文本框中.已经给出部分代码 ...

  6. 使用java怎么实现商品三级分类_如何实现列表三级分类---后端+前端

    对于分类来说,一般包括一级分类,二级分类,三级分类, 大部分网站都是左边点击二级分类,右边显示相对应商品 下面就来为大家详细分析一下该如何实现吧. 如图: 分析图 1.1后端实现:JavaBean 与 ...

  7. excel分类_最简单的Excel分类汇总教程!三分钟包学包会!

    在进行数据统计的时候我们都会经常用excel来完成,特别是数量较多较复杂的数据,通过Excel的分类汇总能够用方便快捷的处理,那么Excel的分类汇总功能到底是如何的呢?它有哪些功能?如何操作?别急, ...

  8. python提取关键词分类_用Py做文本分析5:关键词提取

    1.关键词提取 关键词指的是原始文档的和核心信息,关键词提取在文本聚类.分类.自动摘要等领域中有着重要的作用. 针对一篇语段,在不加人工干预的情况下提取出其关键词 首先进行分词处理 关键词分配:事先给 ...

  9. python 文本分析_使用Python进行文本分析–书评

    python 文本分析 This is a book review of Text Analytics with Python: A Practical Real-World Approach to ...

最新文章

  1. 历史上最伟大的方程 (托尼·赖斯 著)
  2. 灵感编程:最大公约数算法解析
  3. KVM虚拟机搭建增量镜像(一个基本镜像拷贝成无数多个子镜像)
  4. 深度学习-Tensorflow1.x-CNN中的padding参数
  5. 斯伦贝谢好进吗_在斯伦贝谢工作是怎样的体验?
  6. matlab切割肿瘤算法,ML之RF:基于Matlab利用RF算法实现根据乳腺肿瘤特征向量高精度(better)预测肿瘤的是恶性还是良性...
  7. linux shell中实现循环日期的实例代码
  8. 2.2 Mnist手写数据集
  9. istio入门(01)istio的优势在哪里?
  10. [数论] 快速傅里叶变换FFT
  11. fortan dll在本地可以运行成功,移植到其他电脑上报错Exception in thread main java.lang.UnsatisfiedLinkError: 找不到指定的模块。
  12. 《未来世界的幸存者》摘录
  13. u8转完看不到菜单_进入软件后所有菜单栏都不显示
  14. 于IIS 7的HTTP 错误 500.0 - Internal Server Error(错误代码:0x800700.
  15. 字典的添加、更新、修改
  16. 秦皇岛科学选育新品种 国稻种芯·中国水稻节:河北谱丰收曲
  17. 近代自然科学为啥未诞生在中国----中国文化的欠缺
  18. com 如何新打开ac
  19. Composite 聚合——Elasticsearch 聚合后分页新实现
  20. 一、旋转矩阵,旋转向量,单位四元数的相互转换总结

热门文章

  1. Google团队在DNN的实际应用方式的整理
  2. 堆,栈,内存泄露,内存溢出介绍
  3. oracle省市表,省市之一 创建全国省市Sql表
  4. linux添加美式键盘,win8\win server 2012添加【中文--美式键盘】
  5. pycharm 运行控制台中文乱码解决办法
  6. python pandas 读写 csv 文件
  7. linux Address already in use 端口被占用解决办法
  8. microsoft edge 打不开 csdn 博客
  9. python的代码编译、代码打包方法
  10. 图像拼接--A multiresolution spline with application to image mosaics