[MHA] Attention Mask (with back/forward trace) / Causal Mask (with backtrace)
Contents
- 1. Attention Mask or Causal Mask
- 2. Causal Mask (with n_backtrace)
- 3. Attention Mask with backtrace and forwardtrace
- 4. Customized Mask
An attention mask can be added to multi-head attention to restrict which part of the input each position may attend to. For example:
- Causal mask: each position may only attend to itself and earlier positions, never later ones. As a matrix, the causal mask is a lower triangle: the lower part is 1 and the upper part is 0.
- Causal mask with n_backtrace: each position may look back at most n_backtrace frames. In the matrix, the leftmost part of the lower triangle is cut off, leaving a diagonal band of 1s of width n_backtrace.
- Back/forward N frames: starting from the causal mask with n_backtrace, the band of 1s is extended forward (i.e. to the right) in the same way, by a band of width n_forwardtrace.
- Similarly, a custom mask can be defined to suit other requirements.
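The three mask shapes described above can be sketched in NumPy (a minimal illustration with small, hypothetical sizes; the article's TF layers below build the same patterns with tensor ops):

```python
import numpy as np

def causal_mask(n):
    # Lower triangle: position i attends to positions j <= i.
    return np.tril(np.ones((n, n), dtype=bool))

def causal_backtrace_mask(n, n_backtrace):
    # Diagonal band: i - n_backtrace <= j <= i.
    row = np.arange(n)[:, None]
    col = np.arange(n)[None, :]
    return (col <= row) & (col >= row - n_backtrace)

def backtrace_forwardtrace_mask(n, n_backtrace, n_forwardtrace):
    # Two-sided band: i - n_backtrace <= j <= i + n_forwardtrace.
    row = np.arange(n)[:, None]
    col = np.arange(n)[None, :]
    return (col >= row - n_backtrace) & (col <= row + n_forwardtrace)

print(causal_mask(4).astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
print(causal_backtrace_mask(4, 1).astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [0 1 1 0]
#  [0 0 1 1]]
print(backtrace_forwardtrace_mask(4, 1, 1).astype(int))
# [[1 1 0 0]
#  [1 1 1 0]
#  [0 1 1 1]
#  [0 0 1 1]]
```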
Ref:
TFA's MHA implementation: https://github.com/tensorflow/addons/blob/v0.15.0/tensorflow_addons/layers/multihead_attention.py#L23-L298
1. Attention Mask or Causal Mask
The causal argument selects whether an ordinary attention mask or a causal mask is generated:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from tensorflow.keras.layers import Layer, Masking
import tensorflow as tf


class AttentionMask(Layer):
    """Computes attention mask."""

    def __init__(self, causal, mask_value=-1e9):
        """
        Argument/s:
            causal - causal attention mask flag.
            mask_value - value used to mask components that aren't to be
                attended to (typically -1e9).
        """
        super(AttentionMask, self).__init__()
        self.causal = causal
        self.mask_value = mask_value
        if not isinstance(mask_value, float):
            raise ValueError("Mask value must be a float.")

    def call(self, inp):
        """Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask.

        Returns:
            Attention mask.
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1),
                                    tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        causal_mask = self.lower_triangular_mask([1, max_seq_len, max_seq_len]) if self.causal else None
        ################
        logical_mask = self.merge_masks(causal_mask, seq_mask)
        unmasked = tf.zeros([batch_size, max_seq_len, max_seq_len])
        masked = tf.fill([batch_size, max_seq_len, max_seq_len], self.mask_value)
        att_mask = tf.where(logical_mask, unmasked, masked)
        seq_mask = tf.cast(seq_mask, tf.float32)
        return att_mask, seq_mask

    def lower_triangular_mask(self, shape):
        """Creates a lower-triangular boolean mask over the last 2 dimensions.

        Argument/s:
            shape - shape of mask.

        Returns:
            causal mask.
        """
        row_index = tf.math.cumsum(tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        return tf.math.greater_equal(row_index, col_index)

    def merge_masks(self, x, y):
        """Merges a sequence mask and a causal mask to make an attention mask.

        Argument/s:
            x - mask.
            y - mask.

        Returns:
            Attention mask.
        """
        if x is None:
            return y
        if y is None:
            return x
        return tf.math.logical_and(x, y)
Test:
if __name__ == '__main__':
    input = tf.ones([64, 526, 40])
    attention_mask = AttentionMask(causal=0)(input)
    causal_mask = AttentionMask(causal=1)(input)
    print('done')
Results (mask visualizations not included): the attention mask and the causal mask.
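The lower_triangular_mask method above builds row and column indices via cumulative sums and compares them; a NumPy sketch of the same idea follows, including the conversion from a boolean mask to the additive 0 / -1e9 form produced by tf.where (so that masked logits vanish after the softmax):

```python
import numpy as np

shape = (1, 4, 4)
row_index = np.cumsum(np.ones(shape, dtype=np.int32), axis=-2)  # i+1 down each column
col_index = np.cumsum(np.ones(shape, dtype=np.int32), axis=-1)  # j+1 along each row
logical_mask = row_index >= col_index                           # lower triangle

# Boolean -> additive mask: 0 where attention is allowed, -1e9 elsewhere.
att_mask = np.where(logical_mask, 0.0, -1e9)
print(att_mask[0])
```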
2. Causal Mask (with n_backtrace)
That is, a causal mask with n_backtrace, inheriting from the AttentionMask above:
from tensorflow.keras.layers import Masking
import tensorflow as tf
from AttentionMask import AttentionMask


class AttentionMask_Causal_Backtrace(AttentionMask):
    """Computes attention mask appropriate for tf.keras.layers.MultiHeadAttention."""

    def __init__(self, causal, n_backtrace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            n_backtrace - (int) number of backtrace frames.
        """
        super().__init__(causal)
        self.causal = causal
        self.n_backtrace = n_backtrace

    def call(self, inp):
        """Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask.

        Returns:
            Attention mask.
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1),
                                    tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        causal_mask = self.lower_triangular_mask([batch_size, max_seq_len, max_seq_len]) if self.causal else None
        bt_mask = self.backtrace_mask([1, max_seq_len, max_seq_len]) \
            if self.causal and self.n_backtrace else None
        ################
        logical_mask = self.merge_masks(causal_mask, seq_mask)
        logical_mask = self.merge_masks(logical_mask, bt_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    def backtrace_mask(self, shape):
        """Creates a boolean backtrace band mask over the last 2 dimensions.

        Argument/s:
            shape - shape of mask.

        Returns:
            backtrace mask.
        """
        row_index = tf.math.cumsum(tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        return tf.math.less_equal(row_index, col_index + self.n_backtrace)
Test:
if __name__ == '__main__':
    input = tf.ones([64, 526, 40])
    causal_mask = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=None)(input)
    causal_mask_backtrace = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=50)(input)
    print('done')
Results (mask visualizations not included): causal_mask and causal_mask_backtrace.
Test case 2:
causal_mask_backtrace = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=5)(input)
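For intuition, the band produced by the backtrace mask can be checked with a small NumPy sketch (hypothetical sizes, same inequality as backtrace_mask combined with the causal condition): each frame attends to itself plus at most n_backtrace past frames.

```python
import numpy as np

n, n_back = 10, 5
row = np.arange(n)[:, None]
col = np.arange(n)[None, :]
mask = (col <= row) & (col >= row - n_back)

# Attended frames per row grow until they saturate at n_back + 1.
print(mask.sum(axis=1))   # [1 2 3 4 5 6 6 6 6 6]
```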
3. Attention Mask with backtrace and forwardtrace
from tensorflow.keras.layers import Masking
import tensorflow as tf
from AttentionMask import AttentionMask


class AttentionMask_Backtrace_Forwardtrace(AttentionMask):
    """Computes attention mask appropriate for tf.keras.layers.MultiHeadAttention."""

    def __init__(self, causal, n_backtrace=None, n_forwardtrace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            n_backtrace - (int) number of backtrace frames.
            n_forwardtrace - (int) number of forwardtrace frames.
        """
        super().__init__(causal)
        self.causal = causal
        self.n_backtrace = n_backtrace
        self.n_forwardtrace = n_forwardtrace

    def call(self, inp):
        """Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask.

        Returns:
            Attention mask.
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1),
                                    tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        bt_ft_mask = self.backtrace_forwardtrace_mask([1, max_seq_len, max_seq_len]) \
            if self.n_backtrace and self.n_forwardtrace else None
        ################
        logical_mask = self.merge_masks(bt_ft_mask, seq_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    def backtrace_forwardtrace_mask(self, shape):
        """Creates a boolean back/forward band mask over the last 2 dimensions.

        Argument/s:
            shape - shape of mask.

        Returns:
            back/forward band mask.
        """
        row_index = tf.math.cumsum(tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        bt_mask = tf.math.less_equal(row_index, col_index + self.n_backtrace)
        ft_mask = tf.math.greater_equal(row_index + self.n_forwardtrace, col_index)
        bt_ft_mask = self.merge_masks(bt_mask, ft_mask)
        return bt_ft_mask
Test:
if __name__ == '__main__':
    input = tf.ones([64, 526, 40])
    bt_ft_mask = AttentionMask_Backtrace_Forwardtrace(causal=0, n_backtrace=2, n_forwardtrace=5)(input)
    print('done')
Results (mask visualization not included): bt_ft_mask.
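The band for the test configuration (n_backtrace=2, n_forwardtrace=5) can be sketched in NumPy (small hypothetical sequence length): interior rows attend to n_backtrace + n_forwardtrace + 1 frames, and rows near the edges are clipped.

```python
import numpy as np

n, n_back, n_fwd = 8, 2, 5
row = np.arange(n)[:, None]
col = np.arange(n)[None, :]
band = (col >= row - n_back) & (col <= row + n_fwd)

# Attended frames per row, clipped at the sequence boundaries.
print(band.sum(axis=1))   # [6 7 8 7 6 5 4 3]
```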
4. Customized Mask
from tensorflow.keras.layers import Masking
import tensorflow as tf
from AttentionMask import AttentionMask


class AttentionMask_Customization(AttentionMask):
    """Computes attention mask appropriate for tf.keras.layers.MultiHeadAttention."""

    def __init__(self, causal, trace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            trace - (int) side length of the ones block.
        """
        super().__init__(causal)
        self.causal = causal
        self.trace = trace

    def call(self, inp):
        """Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask.

        Returns:
            Attention mask.
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1),
                                    tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        customized_mask = self.customized_mask(batch_size, max_seq_len, self.trace)
        ################
        logical_mask = self.merge_masks(customized_mask, seq_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    @tf.function
    def customized_mask(self, batchsize, max_length, trace):
        # A trace x trace block of ones, padded into the bottom-right corner.
        mask = tf.ones(shape=[batchsize, trace, trace], dtype=tf.int32, name="row")
        shape_pad = int(max_length - trace)
        mask = tf.pad(mask, paddings=[[0, 0], [shape_pad, 0], [shape_pad, 0]])
        mask = tf.cast(mask, dtype=bool)
        return mask
Test:
if __name__ == '__main__':
    input = tf.ones([64, 526, 40])
    customized_mask = AttentionMask_Customization(causal=1, trace=5)(input)
    print('done')
Results (mask visualization not included).
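A NumPy sketch of the customized mask above (small hypothetical sizes): the tf.pad call pushes the trace x trace block of ones into the bottom-right corner of the mask.

```python
import numpy as np

def corner_block_mask(max_length, trace):
    # Equivalent to padding a trace x trace block of ones with
    # paddings [[max_length - trace, 0], [max_length - trace, 0]].
    mask = np.zeros((max_length, max_length), dtype=bool)
    mask[max_length - trace:, max_length - trace:] = True
    return mask

print(corner_block_mask(5, 2).astype(int))
# [[0 0 0 0 0]
#  [0 0 0 0 0]
#  [0 0 0 0 0]
#  [0 0 0 1 1]
#  [0 0 0 1 1]]
```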