BERT Model Parameter Count Analysis
This post uses the BERT model from Hugging Face Transformers to count and analyze the model's parameters.
Importing the model from Hugging Face

```python
import torch
from transformers import BertTokenizer, BertModel

bertModel = BertModel.from_pretrained("bert-base-chinese",
                                      output_hidden_states=True,
                                      output_attentions=True)
total = sum(p.numel() for p in bertModel.parameters())
print("total param:", total)
```

Output:

total param: 102267648
The code above counts the model's total parameters: 102267648.
Below, the BERT parameter count is broken down into three parts.

1. Embedding layer

BERT has three kinds of embeddings: word embeddings, position embeddings, and segment (token type) embeddings.
In bert-base-chinese, the vocabulary size is 21128, the embedding dimension is 768, and the maximum sequence length L is 512.
Word embedding parameters: 21128*768
Position embedding parameters: 512*768
Segment embedding parameters: 2*768
The embedding layer ends with a Layer Norm, which contributes 768+768 parameters: the α (scale) and β (shift) in the LN formula.
Total parameters in the embedding layer:
21128*768 + 512*768 + 2*768 + 768 + 768 = 16622592
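As a quick sanity check, the arithmetic above can be written out directly (a minimal sketch; the variable names are just for illustration):

```python
# Embedding-layer parameter count for bert-base-chinese, computed by hand.
vocab_size, hidden, max_len, n_segments = 21128, 768, 512, 2

word_emb = vocab_size * hidden   # word embeddings
pos_emb = max_len * hidden       # position embeddings
seg_emb = n_segments * hidden    # segment (token type) embeddings
layer_norm = hidden + hidden     # LN scale and shift

embedding_total = word_emb + pos_emb + seg_emb + layer_norm
print(embedding_total)  # 16622592
```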
2. Self-attention layers

There are 12 self-attention layers, each consisting of two parts: multi-head attention and a Layer Norm.
Multi-head attention has three projection matrices Q, K, V and one output projection. Each of Q, K, V contributes 768*12*64 + 768 parameters:
the first 768 is the embedding dimension, 12 is the number of heads, 64 is the per-head dimension, and the trailing 768 is the bias. After the Q, K, V projections, the per-head outputs are concatenated and passed through an output projection, which adds another 768*768 + 768 parameters.
Layer Norm parameters: 768+768
Total parameters in one self-attention layer:
(768*12*64 + 768)*3 + 768*768 + 768 + 768 + 768 = 2363904
Across all 12 layers: 2363904 * 12 = 28366848
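The same per-layer arithmetic as code (names are illustrative only):

```python
# Self-attention parameter count per layer for bert-base-chinese.
hidden, n_heads, head_dim = 768, 12, 64

qkv = 3 * (hidden * n_heads * head_dim + hidden)  # Q, K, V weights + biases
out_proj = hidden * hidden + hidden               # output projection after concat
layer_norm = hidden + hidden                      # LN scale and shift

attn_per_layer = qkv + out_proj + layer_norm
print(attn_per_layer, attn_per_layer * 12)  # 2363904 28366848
```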
3. Feedforward layers

There are 12 feedforward blocks, each consisting of two parts: the feedforward network and a Layer Norm.
The feedforward network computes W2(W1X + b1) + b2, with two linear layers: W1 maps 768 → 768*4 (3072) and W2 maps 768*4 → 768. W1 has 768*768*4 + 768*4 parameters, and W2 has 768*4*768 + 768.
Layer Norm parameters: 768+768
Total parameters in one feedforward block:
(768*768*4 + 768*4) + (768*4*768 + 768) + 768 + 768 = 4723968
Across all 12 layers: 4723968*12 = 56687616
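And the feedforward arithmetic as code (again a sketch, names for illustration only):

```python
# Feedforward parameter count per layer for bert-base-chinese.
hidden = 768
intermediate = 768 * 4  # 3072

w1 = hidden * intermediate + intermediate  # W1 weights + b1
w2 = intermediate * hidden + hidden        # W2 weights + b2
layer_norm = hidden + hidden               # LN scale and shift

ffn_per_layer = w1 + w2 + layer_norm
print(ffn_per_layer, ffn_per_layer * 12)  # 4723968 56687616
```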
Parameter totals

embedding: 16622592
self-attention: 28366848
feedforward: 56687616
After the transformer layers there is also a pooler layer with a 768*768 weight matrix (768*768 + 768 parameters, weights + bias). It takes the vector of the first special token, [CLS], which is used to compute the loss of BERT's NSP (next sentence prediction) task.
total = 16622592 + 28366848 + 56687616 + 768*768 + 768 = 102267648
This matches the PyTorch count.
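Putting all of the pieces together in one check (pure arithmetic, no model download needed):

```python
# Recompute the total from the three parts plus the pooler.
embedding = 21128 * 768 + 512 * 768 + 2 * 768 + 768 + 768
attention = (3 * (768 * 12 * 64 + 768) + 768 * 768 + 768 + 768 + 768) * 12
feedforward = ((768 * 768 * 4 + 768 * 4) + (768 * 4 * 768 + 768) + 768 + 768) * 12
pooler = 768 * 768 + 768

print(embedding + attention + feedforward + pooler)  # 102267648
```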
If anything above is unclear, it helps to look at the parameters of each layer of the BERT model.
The following prints the name and shape of every parameter:

```python
for name, param in bertModel.named_parameters():
    print(name)
    print(param.shape)
```

Output:

embeddings.word_embeddings.weight
torch.Size([21128, 768])
embeddings.position_embeddings.weight
torch.Size([512, 768])
embeddings.token_type_embeddings.weight
torch.Size([2, 768])
embeddings.LayerNorm.weight
torch.Size([768])
embeddings.LayerNorm.bias
torch.Size([768])
encoder.layer.0.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.0.attention.self.query.bias
torch.Size([768])
encoder.layer.0.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.0.attention.self.key.bias
torch.Size([768])
encoder.layer.0.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias
torch.Size([768])
encoder.layer.0.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.0.attention.output.dense.bias
torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.0.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.0.intermediate.dense.bias
torch.Size([3072])
encoder.layer.0.output.dense.weight
torch.Size([768, 3072])
encoder.layer.0.output.dense.bias
torch.Size([768])
encoder.layer.0.output.LayerNorm.weight
torch.Size([768])
encoder.layer.0.output.LayerNorm.bias
torch.Size([768])
encoder.layer.1.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.1.attention.self.query.bias
torch.Size([768])
encoder.layer.1.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.1.attention.self.key.bias
torch.Size([768])
encoder.layer.1.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.1.attention.self.value.bias
torch.Size([768])
encoder.layer.1.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.1.attention.output.dense.bias
torch.Size([768])
encoder.layer.1.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.1.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.1.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.1.intermediate.dense.bias
torch.Size([3072])
encoder.layer.1.output.dense.weight
torch.Size([768, 3072])
encoder.layer.1.output.dense.bias
torch.Size([768])
encoder.layer.1.output.LayerNorm.weight
torch.Size([768])
encoder.layer.1.output.LayerNorm.bias
torch.Size([768])
encoder.layer.2.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.2.attention.self.query.bias
torch.Size([768])
encoder.layer.2.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.2.attention.self.key.bias
torch.Size([768])
encoder.layer.2.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.2.attention.self.value.bias
torch.Size([768])
encoder.layer.2.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.2.attention.output.dense.bias
torch.Size([768])
encoder.layer.2.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.2.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.2.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.2.intermediate.dense.bias
torch.Size([3072])
encoder.layer.2.output.dense.weight
torch.Size([768, 3072])
encoder.layer.2.output.dense.bias
torch.Size([768])
encoder.layer.2.output.LayerNorm.weight
torch.Size([768])
encoder.layer.2.output.LayerNorm.bias
torch.Size([768])
encoder.layer.3.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.3.attention.self.query.bias
torch.Size([768])
encoder.layer.3.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.3.attention.self.key.bias
torch.Size([768])
encoder.layer.3.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.3.attention.self.value.bias
torch.Size([768])
encoder.layer.3.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.3.attention.output.dense.bias
torch.Size([768])
encoder.layer.3.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.3.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.3.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.3.intermediate.dense.bias
torch.Size([3072])
encoder.layer.3.output.dense.weight
torch.Size([768, 3072])
encoder.layer.3.output.dense.bias
torch.Size([768])
encoder.layer.3.output.LayerNorm.weight
torch.Size([768])
encoder.layer.3.output.LayerNorm.bias
torch.Size([768])
encoder.layer.4.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.4.attention.self.query.bias
torch.Size([768])
encoder.layer.4.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.4.attention.self.key.bias
torch.Size([768])
encoder.layer.4.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.4.attention.self.value.bias
torch.Size([768])
encoder.layer.4.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.4.attention.output.dense.bias
torch.Size([768])
encoder.layer.4.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.4.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.4.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.4.intermediate.dense.bias
torch.Size([3072])
encoder.layer.4.output.dense.weight
torch.Size([768, 3072])
encoder.layer.4.output.dense.bias
torch.Size([768])
encoder.layer.4.output.LayerNorm.weight
torch.Size([768])
encoder.layer.4.output.LayerNorm.bias
torch.Size([768])
encoder.layer.5.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.5.attention.self.query.bias
torch.Size([768])
encoder.layer.5.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.5.attention.self.key.bias
torch.Size([768])
encoder.layer.5.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.5.attention.self.value.bias
torch.Size([768])
encoder.layer.5.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.5.attention.output.dense.bias
torch.Size([768])
encoder.layer.5.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.5.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.5.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.5.intermediate.dense.bias
torch.Size([3072])
encoder.layer.5.output.dense.weight
torch.Size([768, 3072])
encoder.layer.5.output.dense.bias
torch.Size([768])
encoder.layer.5.output.LayerNorm.weight
torch.Size([768])
encoder.layer.5.output.LayerNorm.bias
torch.Size([768])
encoder.layer.6.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.6.attention.self.query.bias
torch.Size([768])
encoder.layer.6.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.6.attention.self.key.bias
torch.Size([768])
encoder.layer.6.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.6.attention.self.value.bias
torch.Size([768])
encoder.layer.6.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.6.attention.output.dense.bias
torch.Size([768])
encoder.layer.6.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.6.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.6.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.6.intermediate.dense.bias
torch.Size([3072])
encoder.layer.6.output.dense.weight
torch.Size([768, 3072])
encoder.layer.6.output.dense.bias
torch.Size([768])
encoder.layer.6.output.LayerNorm.weight
torch.Size([768])
encoder.layer.6.output.LayerNorm.bias
torch.Size([768])
encoder.layer.7.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.7.attention.self.query.bias
torch.Size([768])
encoder.layer.7.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.7.attention.self.key.bias
torch.Size([768])
encoder.layer.7.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.7.attention.self.value.bias
torch.Size([768])
encoder.layer.7.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.7.attention.output.dense.bias
torch.Size([768])
encoder.layer.7.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.7.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.7.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.7.intermediate.dense.bias
torch.Size([3072])
encoder.layer.7.output.dense.weight
torch.Size([768, 3072])
encoder.layer.7.output.dense.bias
torch.Size([768])
encoder.layer.7.output.LayerNorm.weight
torch.Size([768])
encoder.layer.7.output.LayerNorm.bias
torch.Size([768])
encoder.layer.8.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.8.attention.self.query.bias
torch.Size([768])
encoder.layer.8.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.8.attention.self.key.bias
torch.Size([768])
encoder.layer.8.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.8.attention.self.value.bias
torch.Size([768])
encoder.layer.8.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.8.attention.output.dense.bias
torch.Size([768])
encoder.layer.8.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.8.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.8.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.8.intermediate.dense.bias
torch.Size([3072])
encoder.layer.8.output.dense.weight
torch.Size([768, 3072])
encoder.layer.8.output.dense.bias
torch.Size([768])
encoder.layer.8.output.LayerNorm.weight
torch.Size([768])
encoder.layer.8.output.LayerNorm.bias
torch.Size([768])
encoder.layer.9.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.9.attention.self.query.bias
torch.Size([768])
encoder.layer.9.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.9.attention.self.key.bias
torch.Size([768])
encoder.layer.9.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.9.attention.self.value.bias
torch.Size([768])
encoder.layer.9.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.9.attention.output.dense.bias
torch.Size([768])
encoder.layer.9.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.9.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.9.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.9.intermediate.dense.bias
torch.Size([3072])
encoder.layer.9.output.dense.weight
torch.Size([768, 3072])
encoder.layer.9.output.dense.bias
torch.Size([768])
encoder.layer.9.output.LayerNorm.weight
torch.Size([768])
encoder.layer.9.output.LayerNorm.bias
torch.Size([768])
encoder.layer.10.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.10.attention.self.query.bias
torch.Size([768])
encoder.layer.10.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.10.attention.self.key.bias
torch.Size([768])
encoder.layer.10.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.10.attention.self.value.bias
torch.Size([768])
encoder.layer.10.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.10.attention.output.dense.bias
torch.Size([768])
encoder.layer.10.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.10.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.10.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.10.intermediate.dense.bias
torch.Size([3072])
encoder.layer.10.output.dense.weight
torch.Size([768, 3072])
encoder.layer.10.output.dense.bias
torch.Size([768])
encoder.layer.10.output.LayerNorm.weight
torch.Size([768])
encoder.layer.10.output.LayerNorm.bias
torch.Size([768])
encoder.layer.11.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.11.attention.self.query.bias
torch.Size([768])
encoder.layer.11.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.11.attention.self.key.bias
torch.Size([768])
encoder.layer.11.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.11.attention.self.value.bias
torch.Size([768])
encoder.layer.11.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.11.attention.output.dense.bias
torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.11.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.11.intermediate.dense.bias
torch.Size([3072])
encoder.layer.11.output.dense.weight
torch.Size([768, 3072])
encoder.layer.11.output.dense.bias
torch.Size([768])
encoder.layer.11.output.LayerNorm.weight
torch.Size([768])
encoder.layer.11.output.LayerNorm.bias
torch.Size([768])
pooler.dense.weight
torch.Size([768, 768])
pooler.dense.bias
torch.Size([768])
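A listing like the one above can be aggregated by top-level module with a small helper (group_param_counts is a hypothetical helper, shown here on a few hand-copied entries; with a loaded model you would feed it (name, tuple(param.shape)) pairs from named_parameters()):

```python
from collections import defaultdict

def group_param_counts(named_shapes):
    """Sum parameter counts per top-level module (text before the first '.')."""
    totals = defaultdict(int)
    for name, shape in named_shapes:
        count = 1
        for dim in shape:
            count *= dim
        totals[name.split(".")[0]] += count
    return dict(totals)

# A few entries copied from the listing above:
sample = [
    ("embeddings.word_embeddings.weight", (21128, 768)),
    ("embeddings.LayerNorm.weight", (768,)),
    ("embeddings.LayerNorm.bias", (768,)),
    ("pooler.dense.weight", (768, 768)),
    ("pooler.dense.bias", (768,)),
]
print(group_param_counts(sample))
# {'embeddings': 16227840, 'pooler': 590592}
```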