关于Bert模型参数的分布 简单的数学题,另外 M代表百万个_12-layer, 768-hidden, 12-heads, 110m parameters_愚昧之山绝望之谷开悟之坡的博客-程序员宅基地

技术标签: NLP基础知识  工具  自然语言处理  bert  神经网络  

参数分布

Bert模型的版本如下:

BERT-Base, Uncased: 12-layer, 768-hidden, 12-heads, 110M parameters

BERT-Large, Uncased: 24-layer, 1024-hidden, 16-heads, 340M parameters

BERT-Base, Cased: 12-layer, 768-hidden, 12-heads , 110M parameters

BERT-Large, Cased: 24-layer, 1024-hidden, 16-heads, 340M parameters

BERT-Base, Chinese: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters

所有的输入长度最长为512,英文版的"vocab_size": 30522,中文版的"vocab_size": 21128

其中的parameters的数量很少有人关注过,下面就Bert模型中的参数进行一个简单的分析。

Bert模型的输入要做一个向量化,提取每个词的三维信息,如图:

因为,句子的长度不一样,而我们向量化的时候把每个输入都做成了512的长度,这里就需要一个类似补0的操作,我认为这里的表达形式也是一种参数的学习。因此在Embedding层也是有一个权重参数的。

 

然后就把处理好的向量输入到12个transformer blocks中,每一个transformer block(bert只是使用到transformer的encoder)中包含了encoder模块,具体的构成参考transformer相关文献。在encoder中主要是包含了self-Attention,前馈神经网络和归一化功能。

Self-Attention的主要功能是在编码当前单词的时候能够同时关注到上下文中和它有关的单词,在实现层面上,简单的说就是3个矩阵的自乘,矩阵计算形式:

这个Q、k和V都是要经过和网络的学习得到的权重矩阵做运算得到的,这里很有可能就涉及到大量的参数。

 

另外前馈神经网络做了全连接,每个连接需要对应的一个权重值——也就是权重参数。

 

根据具体的实现代码:logits_lm = self.decoder(h_masked)得知,最后在transformer blocks的后面还接入了一个全连接。这里也有一部分参数。

总体来说bert模型的参数主要包含3部分:Embeddding层的参数,transformer blocks的参数和最后输出的全连接参数。

第一部分的参数:

30522*768+512*768+4*768

第二部分参数:

【(768*768+768)*4+(768*2)+(3072*768*2+3072)+768*3】*12

第三部分参数:

768*768+768

 

参数个数总计:109482240~1.09亿

而BERT-Base, Chinese BERT-Base, Chinese总是约为1.02亿。

 

代码:


  
  
   
  1. num_weights= 0
  2. for name, param in model.state_dict().items(): #model为任意加载进来的一个bert模型
  3.     if len(param.shape)== 1:
  4.         num_weights+=param.shape[ 0]
  5.         print(name, param.shape,end= '  ')
  6.         print( '参数个数为:',param.shape[ 0])
  7.     else:
  8.         num_weights += param.shape[ 0]*param.shape[ 1]
  9.         print(name, param.shape, end= '  ')
  10.         print( '参数个数为:', param.shape[ 0]*param.shape[ 1])
  11. print( '参数总数:',num_weights)

附录——部分参数:

bert.embeddings.word_embeddings.weight torch.Size([30522, 768])  参数个数为: 23440896

bert.embeddings.position_embeddings.weight torch.Size([512, 768])  参数个数为: 393216

bert.embeddings.token_type_embeddings.weight torch.Size([2, 768])  参数个数为: 1536

bert.embeddings.LayerNorm.weight torch.Size([768])  参数个数为: 768

bert.embeddings.LayerNorm.bias torch.Size([768])  参数个数为: 768

bert.encoder.layer.0.attention.self.query.weight torch.Size([768, 768])  参数个数为: 589824

bert.encoder.layer.0.attention.self.query.bias torch.Size([768])  参数个数为: 768

bert.encoder.layer.0.attention.self.key.weight torch.Size([768, 768])  参数个数为: 589824

bert.encoder.layer.0.attention.self.key.bias torch.Size([768])  参数个数为: 768

bert.encoder.layer.0.attention.self.value.weight torch.Size([768, 768])  参数个数为: 589824

bert.encoder.layer.0.attention.self.value.bias torch.Size([768])  参数个数为: 768

bert.encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])  参数个数为: 589824

bert.encoder.layer.0.attention.output.dense.bias torch.Size([768])  参数个数为: 768

bert.encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])  参数个数为: 768

bert.encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])  参数个数为: 768

bert.encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])  参数个数为: 2359296

bert.encoder.layer.0.intermediate.dense.bias torch.Size([3072])  参数个数为: 3072

bert.encoder.layer.0.output.dense.weight torch.Size([768, 3072])  参数个数为: 2359296

bert.encoder.layer.0.output.dense.bias torch.Size([768])  参数个数为: 768

bert.encoder.layer.0.output.LayerNorm.weight torch.Size([768])  参数个数为: 768

bert.encoder.layer.0.output.LayerNorm.bias torch.Size([768])  参数个数为: 768

 

参考文献:

https://www.cnblogs.com/jiangxinyang/p/11422975.html

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_15821487/article/details/120501017

智能推荐

linux运维脚本样例,Linux运维处理及监控脚本【Linux运维之道之脚本案例】-程序员宅基地

Linux运维处理及监控脚本【Linux运维之道之脚本案例】Linux日常运维工作一个一个命令搞是一件苦事情,作为新一代IT运维工作者,在工作中不断探索提升效率方法和经验,摸索出不少脚本,减轻日常工作量。脚本的出现后给运维工作带来一扇曙光。以下让我们一起走进Linux脚本的世界,一起分享这份经验。分享场景一:运维过程通过脚本批量删除文件。运维时遇到在tmp文件目录下存放大最的ppxx__*的临时..._[场景一]假设客户现场的应用服务器出现问题或故障,设计一组脚本对该服务器进行一

コミュニティパートナユーザ レコード個別共有 共有セット_セックスコミック共有-程序员宅基地

【Salesforce】コミュニティユーザライセンスについて 共有セットを用いたレコードアクセス制御 Salesforce: Communityの共有セットでのデータ共有で気をつけること コミュニティユーザライセンスhttps://help.salesforce.com/articleView?id=users_license_types_communities.htm&type=5パートナーユーザロールhttps://help.salesforce.com/a..._セックスコミック共有

自定义MVC详解_servlet 自定义mvc-程序员宅基地

J2EE目录J2EE1、什么是mvc2、mvc组成部分3、搭建mvc(代码)1、中央控制器:DispathServlet2、模型驱动:ModelDriver3、子控制器接口:Action4、增强子控制器:ActionSupport5、具体子控制器案例BookServlet6、Dao层7、XML配置文件8、web.xml4、界面效果图5、mvc的执行流程1、简单流程,首先大致感受一下他的流程2、详细流程,建议按照这个流程图来理解每个类之间的关_servlet 自定义mvc

安卓-控制EditText的光标的位置_edittext控制光标位置-程序员宅基地

EditText光标的位置的控制,主要是依赖于属性setSelection,传入光标的位置索引即可。下面分三种情况测试:1)设置光标在文字的起始位置2)设置光标在文字的中间位置3)设置光标在文字的末尾位置布局文件activity_edit_text_cursor.xml

从RNN到LSTM--公式描述_rnn lstm 公式公式-程序员宅基地

好久没写博客了,近几天重新看LSTM,发现有很多细节之前没有理解到位,并且至今有一些疑惑。接下来从RNN谈起,利用公式描述,并结合tensorflow函数说明两个容易混淆的细节。之后讲解LSTM,主要参考自大神Alex Graves论文《supervised sequence labelling with recurrent neural networks》,同时加上自己的理解,如果不对,请大家指..._rnn lstm 公式公式

Java基础入门 窗体事件_java窗体事件-程序员宅基地

大部分GUI应用程序都需要Window窗体对象作为最外层容器,可以说窗体对象是所有GUI应用程序的基础。在JDK中提供了一个类WindowEvent用于表示窗体事件,在应用程序中当对窗体时间进行处理时,首先需要定义一个类实现WindowListener接口作为窗体监听器,然后通过addWindowListener()方法将窗体对象和窗体监听器绑定。接下来通过一个案例演示:import..._java窗体事件

随便推点

Shell编程-程序员宅基地

Shell编程 作者:Ackarlix 在DOS 中,你可能会从事一些例行的重覆性工作,此时你会将这些重覆性的命令写成批次档,只要执行这个批次档就等於执行这些命令。大家会问在UNIX中是否有批次处理这个东东,答案是有的。在UNIX中不只有如DOS 的批次处理,它的功能比起DOS 更强大,相对地也较复杂,已经和一般的高阶语言不相上下。在UNIX中大家都不叫做批次档,而叫做She

移动互联网之路——Axure RP 8.0网站与APP原型设计从入门到精通-程序员宅基地

1.1  了解 Axure RPAxure RP 是一个专业的快速原型设计工具。 Axure(Ack-sure) 代表美国 Axure 公司; RP则是 Rapid Prototypi...

Project ERROR:Unknown module(s) in QT: charts解决办法_我的qt为什么没有qchart-程序员宅基地

这种错误是因为找不到charts的modules,通常出现在移植的时候。这一般是因为在安装QT的时候没有安装Qt Charts,因为Qt Charts默认不安装已经安装过QT的,可以运行MaintenanceTool添加组件,也可以卸载后重新安装。安装时勾选Qt Charts,安装Charts组件安装过charts组件后就可以成功构建啦!_我的qt为什么没有qchart

Python读取配置文件_python怎么读取配置文件-程序员宅基地

本文章介绍的是Python中,后缀是yml、ini和py三种方式的配置文件1)、py -- 即是把配置项的内容写在Python文件,这种方式的配置文件,读取的方式最直接、也是最简单配置文件 -- config.py :# -*- coding:utf-8 -*-'''DATABASE'''db_host = '127.0.0.1'db_port = 3306db_userna..._python怎么读取配置文件

c语言在数组输出字母,c语言字符数组与字符串的使用详解_伊噜咔的博客-程序员宅基地

1、字符数组的定义与初始化字符数组的初始化,最容易理解的方式就是逐个字符赋给数组中各元素。char str[10]={ 'I',' ','a','m',' ',‘h','a','p','p','y'};即把10个字符分别赋给str[0]到str[9]10个元素如果花括号中提供的字符个数大于数组长度,则按语法错误处理;若小于数组长度,则只将这些字符数组中前面那些元素,其余的元素自动定为空字符(即 '...