t5模型中encoder与decoder内容不同

  • 查看transformers库之中的encoder和decoder部分内容的不同
  • 综合分析t5LayerSelfAttention和t5LayerCrossAttention的运行的不同
  • bert4keras t5decoder解读
  • BaseModelOutputWithPastAndCrossAttention解读

关于transformers整体结构的解答可以查看相应的解析:解析网站
本质上t5使用的是编码和解码的操作,transformers的网络结构如下:
首先需要理解这个transformer对应的结构图,比如我们要想通过输入我爱中国得到输出I love China,那么Inputs输入永远是我爱中国,而Outputs刚开始为+position encoding,接下来产生预测I了之后,继续将Outputs(shifted right)变为(起始符)+I+position encoding,然后产生预测love之后,继续将Outputs(shifted right)变为(起始符)+I+love+positional encoding,以此类推。
由此可见,Inputs的部分始终不变,Outputs(shifted right)部分在不断地变化,从而引起预测结果不断地改变。
此外这种encoder-decoder结构还引出了一种attention的变化,也就是说在t5模型之中,encoder部分的attention与decoder中第二个部分的attention结构一致,decoder attention中第一个部分的attention加入了mask掩码的内容,这与bert4keras中的代码保持一致。

查看transformers库之中的encoder和decoder部分内容的不同

仔细观察发现,t5selfattention和t5crossattention的区别在于t5crossattention之中多加入了两个参数
t5selfattention的内容

self_attention_outputs = self.layer[0](hidden_states,attention_mask=attention_mask,position_bias=position_bias,layer_head_mask=layer_head_mask,past_key_value=self_attn_past_key_value,use_cache=use_cache,output_attentions=output_attentions,
)

t5crossattention的内容

cross_attention_outputs = self.layer[1](hidden_states,key_value_states=encoder_hidden_states,attention_mask=encoder_attention_mask,position_bias=encoder_decoder_position_bias,layer_head_mask=cross_attn_layer_head_mask,past_key_value=cross_attn_past_key_value,query_length=query_length,use_cache=use_cache,output_attentions=output_attentions,
)

可以看出来,上文的cross_attention_outputs之中多出了两个参数:key_value_states和query_length,所以这里重点看key_value_states和query_length对cross_attention_outputs造成的影响(即key_value_states和query_length对T5Attention造成的影响)
所以接下来,我们需要进入到代码之中,去查看key_value_states以及query_length对T5Attention造成的影响

综合分析t5LayerSelfAttention和t5LayerCrossAttention的运行的不同

1.先进行encoder的部分
t5LayerSelfAttention的输入:6个T5LayerSelfAttention
输入的内容(1,11,512)
2.再进行decoder的部分
t5LayerSelfAttention的输入:6个T5LayerSelfAttention
输入的内容(1,1,512)
t5LayerCrossAttention的输入:6个T5CrossAttention
输入的内容(1,1,512)
这里面的维度变化在T5ForConditionalGeneration中有改变过

if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:# get decoder inputs from shifting lm labels to the rightdecoder_input_ids = self._shift_right(labels)

本身decoder_input_ids = (1,11,512),经过_shift_right之后变成了(1,1,512)
T5ForConditionalGeneration->T5PreTrainedModel->PreTrainedModel->GenerationMixin->generate函数
最终找出来是在transformers中的generation_utils.py之中找出来的

if "decoder_input_ids" in model_kwargs:input_ids = model_kwargs.pop("decoder_input_ids")print('111input_ids = 111')print(input_ids)print('111111111111111111')
else:input_ids = self._prepare_decoder_input_ids_for_generation(input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id)print('222input_ids = 222')print(input_ids)print('222222222222222222')

这里输入的各种参数

input_ids =
tensor([[13959,  1566,    12,  2968,    10,    37,   629,    19,  1627,     5,1,     0,     0,     0,     0]])
decoder_start_token_id =
None
bos_token_id =
None

出来之后的

input_ids = 222
tensor([[0]])

然后经历6轮的LayerSelfAttention和LayerCrossAttention网络层部分,最后出来了(1,1,512)的tensor内容
decoder出来的内容部分如下:

decoder_outputs = self.decoder(input_ids=decoder_input_ids,attention_mask=decoder_attention_mask,inputs_embeds=decoder_inputs_embeds,past_key_values=past_key_values,encoder_hidden_states=hidden_states,encoder_attention_mask=attention_mask,head_mask=decoder_head_mask,cross_attn_head_mask=cross_attn_head_mask,use_cache=use_cache,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)sequence_output = decoder_outputs[0]

形成的sequence_output = (1,1,512)
输出的各项参数为

decoder_outputs.past_key_values = ((tensor([[[[-4.8649e-01, -2.3323e+00, -1.1428e+00,  4.3997e-01, -3.7448e+00,5.9211e-01,  1.7371e+00,  1.2648e-01, -9.5232e-01,  6.4317e-01,6.5032e-02, -2.7661e+00, -2.9257e-01, -2.2728e+00,  1.4708e+00,3.6940e+00, -5.9305e-01, -2.2253e+00, -2.2925e+00,  1.2926e+00,-1.6622e+00,  1.5806e-01, -9.8186e-01,  6.9422e-01, -2.3424e+00,-1.5638e-01, -9.2692e-01,  2.5009e+00,  1.5147e+00, -3.6560e-01,3.0006e-01,  9.5156e-01,  2.0886e+00,  3.6983e-01, -1.0588e+00,2.6796e+00, -1.4096e+00, -1.1152e+00, -2.3030e+00, -1.3433e+00,1.9916e+00,  2.5363e-03, -2.4754e+00,  7.7748e-01, -1.2229e+00,-1.9101e+00,  1.9616e+00, -1.2805e+00,  1.0394e+00,  1.6140e-01,-2.3916e-01,  2.9783e-01,  1.6426e+00, -1.3518e+00, -1.1187e+00,-1.4495e+00, -2.1039e+00,  2.9519e+00, -1.8293e+00,  1.2496e+00,6.0215e-01, -2.5693e+00, -1.7539e+00, -5.6927e-01]],............[[ 3.8271e-01,  3.9878e-01,  3.0701e-01,  ...,  2.0659e+00,1.1919e+00,  9.1220e-01],[ 2.2469e-01, -1.3852e+00, -2.3070e-01,  ..., -4.1294e+00,-4.6317e+00, -6.0171e-01],[ 4.2913e-02, -2.8669e-01,  1.4512e-01,  ...,  2.4677e-01,3.0281e-02,  6.6158e-01],...,[ 8.8233e+00,  1.4664e+00, -6.6772e+00,  ...,  5.7047e+00,3.9132e+00,  4.7790e+00],[ 5.3985e+00, -9.5581e-01, -2.2232e+00,  ...,  7.3522e+00,1.5856e+00, -7.5307e+00],[-1.8160e+00, -2.0803e+00, -9.2405e-01,  ...,  1.6660e+00,1.1615e+00, -1.7454e-01]]]])))
decoder_outputs.hidden_states = None
decoder_outputs.attentions = None
decoder_outputs.cross_attentions = None
encoder_outputs.last_hidden_state =
tensor([[[ 0.0154,  0.1263,  0.0301,  ..., -0.0117,  0.0373,  0.1015],[-0.1926, -0.1285,  0.0228,  ..., -0.0339,  0.0535,  0.1575],[ 0.0109, -0.0210,  0.0022,  ...,  0.0008, -0.0056, -0.0393],...,[-0.1581, -0.0719,  0.0208,  ..., -0.1778,  0.1037, -0.1703],[ 0.0142, -0.1430,  0.0148,  ...,  0.0224, -0.1906, -0.0547],[ 0.0756, -0.0119, -0.0273,  ..., -0.0044, -0.0505,  0.0554]]])
encoder_hidden_states =
None
encoder_attentions =
None

这里输出的encoder_output.last_hidden_state = (1,11,512)
decoder_output.past_key_values每一个的形状为(1,8,1,64)不知道是用来干什么的
仔细观察decoder_output.past_key_values,发现需要转到t5stack类别的最后面进行查看

print('t5stack past_key_values = ')
print(past_key_values)
print('--------------------------')
return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states,past_key_values=present_key_value_states,hidden_states=all_hidden_states,attentions=all_attentions,cross_attentions=all_cross_attentions,
)

在进入BaseModelOutputWithPastAndCrossAttentions之前,获得的past_key_values = [None, None, None, None, None, None]
(6个网络层)
也就是说past_key_values为进入BaseModelOutputWithPastAndCrossAttention之后才变换的,即是在modeling_t5.py之中的

return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states,past_key_values=present_key_value_states,hidden_states=all_hidden_states,attentions=all_attentions,cross_attentions=all_cross_attentions,
)

返回之前,

past_key_values = None

经过阅读代码发现,这里的decoder每次都输入的为一个一维的数值

decoder_input_ids =
tensor([[0]])
decoder_input_ids =
tensor([[644]])
decoder_input_ids =
tensor([[4598]])
decoder_input_ids =
tensor([[229]])
decoder_input_ids =
tensor([[19250]])
decoder_input_ids =
tensor([[5]])

也就是说,这里经历了

decoder_input_ids = self._shift_right(labels)

右移之后

bert4keras t5decoder解读

***inputs = ***
[<tf.Tensor 'Input-Context:0' shape=(?, ?, 768) dtype=float32>, <tf.Tensor 'Decoder-Input-Token:0' shape=(?, ?) dtype=float32>]

t5 decoder的输入为encoder输出和原始的token_ids???

BaseModelOutputWithPastAndCrossAttention解读

transformers5--t5模型中encoder与decoder内容不同解读相关推荐

  1. R语言回归模型中的Pr(>|t|)如何解读?Pr(>|t|)如何计算?

    R语言回归模型中的Pr(>|t|)如何解读?Pr(>|t|)如何计算? 目录 |t|)如何解读?Pr(>|t|)如何计算?">R语言回归模型中的Pr(>|t|) ...

  2. transformer t5代码解读4(主要内容bert4keras实现t5模型)

    继续解读t5代码之中源码的内容 回到t5的整体结构之中 T5CrossAttention网络层结构的调用 t5的Encoder和Decoder内容比对 bert4keras调用t5模型之中的encod ...

  3. encoder decoder模型_如何突破Decoder性能瓶颈?揭秘FasterTransformer的原理与应用

    位来 发自 凹非寺 量子位 报道 | 公众号 QbitAI 4月9日,英伟达x量子位分享了一期nlp线上课程,来自NVIDIA的GPU计算专家.FasterTransformer 2.0开发者之一的薛 ...

  4. 从Encoder到Decoder实现Seq2Seq模型

    首发于机器不学习 关注专栏 写文章 从Encoder到Decoder实现Seq2Seq模型 天雨粟 模型师傅 / 果粉 ​ 关注他 300 人赞同了该文章 更新:感谢@Gang He指出的代码错误.g ...

  5. encoder decoder模型_3分钟|聊一聊 Decoder 模块

    微信公众号:NLP从入门到放弃 本文大概需要阅读 4.1 分钟 聊一下对 Decoder 的个人总结和理解,我保证里面大部分内容你没在别人那看过,绝对原创. 我先说一个很小的细节点,当时花了点时间才琢 ...

  6. Transformer中的encoder和decoder在训练和推理过程中究竟是如何工作的

    Transformer中的encoder和decoder在训练和推理过程中究竟是如何工作的 苦苦冲浪,找不到答案 Transformer结构(随便冲浪均可查到) Transformer推理过程 Tra ...

  7. 终结篇:t5模型结构的阅读

    问题关键:past_key_value 模型的整体结构(由外到内) 最外层generation_utils.py之中的greedy_search调用模型解读 t5Stack模型的解读 t5block网 ...

  8. T5 模型:NLP Text-to-Text 预训练模型+数据清洗

    简单总结T5模型:         T5模型:是一个端到端,text-to-text 预训练模型         T5模型也是训练七十个模型中一个较通用的一个框架.         T5模型:可以做文 ...

  9. 模型学习之T5模型初探

    T5谷歌19年发布一个的一个模型,它也一度刷了榜,最主要的贡献是提出一个通用框架,接着进行了各种比对实验,获得一套建议参数,最后得到一个很强的 baseline.而我们之后做这方面实验就能参考它的一套 ...

最新文章

  1. ALinq 入门学习(八)--ALinq 对Vs2010 的支持
  2. 【转】Dubbo_与Zookeeper、SpringMVC整合和使用(负载均衡、容错)
  3. 高级会计师计算机考试中级,会计师需要计算机等级考试吗
  4. vue 项目加载顺序_如何提高Vue项目首页的加载速度
  5. 消防信号总线原理_AFPM100/B消防设备电源监控系统在百色市人民医院消防设备电源监控系统的应用-安科瑞 华梅超...
  6. 关于计算机的知识古人,世界仅是一串二进制编码?我们是虚拟的?古人早就给出了答案...
  7. Scratch(二十七):恐龙飞奔
  8. 平面杆系结构有限元分析C++程序设计思路
  9. sqliteman安装时出现The following packages have unmet dependencies: libqtgui4 : Depends: libpng12-0错误
  10. 传输层常见的协议及端口
  11. java分发器 及(注解 + 反射机制)—————— 开开开山怪
  12. pandas数据清洗策略1
  13. Cg Programming In Unity Projection of Bumpy Surfaces
  14. UI设计培训主要学习哪些内容
  15. Java调用阿里云OSS下载文件
  16. iMeta | 南科大夏雨组纳米孔测序揭示微生物可减轻高海拔冻土温室气体排放
  17. rest php,prest
  18. PHP定时任务 - PHP自动定时循环执行任务实例代码
  19. Ubuntu20.04编译并运行imu_utils,并且标定IMU
  20. suspicious number

热门文章

  1. 怎样检测计算机硬件是否正常,怎么检查电脑硬件是否有问题
  2. Materials Studio画苯环
  3. Matlab R2017b 自动驾驶工具箱学习笔记(2)_Tutorials_Visual Perception Using Monocular Camera
  4. 3D目标检测——代码理解——Second代码:数据处理kitti_dataset.py的理解
  5. 二代测序的原理和简介
  6. 大厂必考深度学习算法面试题
  7. 【经典收藏】深度技术ghost官方原版XP系统sp3下载地址 ...
  8. LeetCode 276:栅栏涂色
  9. CTF之旅WEB篇(3)--ezunser PHP反序列化
  10. 靶机渗透----bulldog2