Table of Contents

  • 1. Attention Mask or Causal Mask
  • 2. Causal Mask (with n_backtrace)
  • 3. Attention Mask with backtrace and forwardtrace
  • 4. Customized Mask

An attention mask can be added in multi-head attention to restrict which parts of the input each position may attend to, e.g.:

  • Causal mask: each position may only attend to data at or before the current point, never after it. As a matrix, the causal mask is a lower triangle: the lower half is 1, the upper half is 0.
  • Causal mask with n_backtrace: each position may look back at most n_backtrace frames. As a matrix, the leftmost part of the lower triangle above is cut off, leaving a diagonal band of 1s of width n_backtrace.
  • N frames backward and forward: starting from the causal mask with n_backtrace above, the band of 1s is extended in the same way forward, i.e. to the right, by a width of n_forwardtrace.
  • Similarly, masks can be set up for any custom requirement.

A toy sketch of the first three patterns follows this list.
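
A minimal sketch of the three patterns on a toy 6-frame sequence (it uses tf.linalg.band_part purely for illustration; the classes below construct the masks differently):

import tensorflow as tf

T = 6
ones = tf.ones([T, T], dtype=tf.int32)
causal = tf.linalg.band_part(ones, -1, 0)    # lower triangle: self and all past frames
causal_bt = tf.linalg.band_part(ones, 2, 0)  # causal with n_backtrace=2: band below the diagonal
bt_ft = tf.linalg.band_part(ones, 2, 1)      # n_backtrace=2, n_forwardtrace=1: band around the diagonal
print(causal.numpy(), causal_bt.numpy(), bt_ft.numpy(), sep="\n\n")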

Reference:
TFA's MHA implementation: https://github.com/tensorflow/addons/blob/v0.15.0/tensorflow_addons/layers/multihead_attention.py#L23-L298

1. Attention Mask or Causal Mask

The causal flag selects whether a plain attention mask or a causal mask is generated:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from tensorflow.keras.layers import Layer, Masking
import tensorflow as tf


class AttentionMask(Layer):
    """Computes attention mask."""

    def __init__(self, causal, mask_value=-1e9):
        """
        Argument/s:
            causal - causal attention mask flag.
            mask_value - value used to mask components that aren't to be
                attended to (typically -1e9).
        """
        super(AttentionMask, self).__init__()
        self.causal = causal
        self.mask_value = mask_value
        if not isinstance(mask_value, float):
            raise ValueError("Mask value must be a float.")

    def call(self, inp):
        """Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask.

        Returns:
            Attention mask.
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1),
                                    tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        causal_mask = self.lower_triangular_mask(
            [1, max_seq_len, max_seq_len]) if self.causal else None
        ################
        logical_mask = self.merge_masks(causal_mask, seq_mask)
        unmasked = tf.zeros([batch_size, max_seq_len, max_seq_len])
        masked = tf.fill([batch_size, max_seq_len, max_seq_len], self.mask_value)
        att_mask = tf.where(logical_mask, unmasked, masked)
        seq_mask = tf.cast(seq_mask, tf.float32)
        return att_mask, seq_mask

    def lower_triangular_mask(self, shape):
        """Creates a lower-triangular boolean mask over the last 2 dimensions.

        Argument/s:
            shape - shape of mask.

        Returns:
            Causal mask.
        """
        row_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        return tf.math.greater_equal(row_index, col_index)

    def merge_masks(self, x, y):
        """Merges a sequence mask and a causal mask to make an attention mask.

        Argument/s:
            x - mask.
            y - mask.

        Returns:
            Attention mask.
        """
        if x is None:
            return y
        if y is None:
            return x
        return tf.math.logical_and(x, y)

Test:

if __name__ == '__main__':
    inputs = tf.ones([64, 526, 40])
    attention_mask = AttentionMask(causal=0)(inputs)  # returns (att_mask, seq_mask)
    causal_mask = AttentionMask(causal=1)(inputs)
    print('done')

Results (the figures in the original post are omitted here):

The attention mask (causal=0) is all zeros, i.e. nothing is masked, since the all-ones input contains no padding.

The causal mask (causal=1) has its upper triangle filled with mask_value (-1e9) and its lower triangle at 0.
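
The att_mask returned here is additive: 0 where attention is allowed, mask_value (-1e9) where it is blocked. A minimal sketch of how such a mask would typically be applied inside scaled dot-product attention (this helper is illustrative and not part of the referenced TFA code):

import tensorflow as tf

def scaled_dot_product_attention(q, k, v, att_mask):
    """q, k, v: [batch, seq_len, d_model]; att_mask: [batch, seq_len, seq_len]."""
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)
    scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(d_k)
    scores += att_mask                       # -1e9 entries vanish after the softmax
    weights = tf.nn.softmax(scores, axis=-1)
    return tf.matmul(weights, v)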

2. Causal Mask (with n_backtrace)

A causal mask with n_backtrace, inheriting from the AttentionMask above:

from tensorflow.keras.layers import Masking
import tensorflow as tf

from AttentionMask import AttentionMask


class AttentionMask_Causal_Backtrace(AttentionMask):
    """Computes attention mask appropriate for tf.keras.layers.MultiHeadAttention."""

    def __init__(self, causal, n_backtrace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            n_backtrace - (int) number of backtrace frames.
        """
        super().__init__(causal)
        self.causal = causal
        self.n_backtrace = n_backtrace

    def call(self, inp):
        """Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask.

        Returns:
            Attention mask.
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1),
                                    tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        causal_mask = self.lower_triangular_mask(
            [batch_size, max_seq_len, max_seq_len]) if self.causal else None
        bt_mask = (self.backtrace_mask([1, max_seq_len, max_seq_len])
                   if self.causal and self.n_backtrace else None)
        ################
        logical_mask = self.merge_masks(causal_mask, seq_mask)
        logical_mask = self.merge_masks(logical_mask, bt_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    def backtrace_mask(self, shape):
        """Creates a diagonal-band boolean mask over the last 2 dimensions.

        Argument/s:
            shape - shape of mask.

        Returns:
            Backtrace mask.
        """
        row_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        return tf.math.less_equal(row_index, col_index + self.n_backtrace)

Test:

if __name__ == '__main__':
    inputs = tf.ones([64, 526, 40])
    causal_mask = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=None)(inputs)
    causal_mask_backtrace = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=50)(inputs)
    print('done')
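
Unlike the base class, this subclass returns a 0/1 mask of shape [batch, 1, seq_len, seq_len], which matches the attention_mask argument of tf.keras.layers.MultiHeadAttention (1 means the position may be attended to). A usage sketch under that assumption:

import tensorflow as tf

inputs = tf.ones([64, 526, 40])
mask = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=50)(inputs)
mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=32)
out = mha(query=inputs, value=inputs, attention_mask=mask)  # out: [64, 526, 40]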

Results (figures omitted):

causal_mask is the plain lower-triangular 0/1 pattern.

causal_mask_backtrace keeps only a diagonal band of 1s: each frame attends to itself and to at most the preceding n_backtrace = 50 frames.

Test case 2, with a narrower band:

causal_mask_backtrace = AttentionMask_Causal_Backtrace(causal=1, n_backtrace=5)(inputs)

3. Attention Mask with backtrace and forwardtrace

This mask lets each frame attend to at most n_backtrace preceding frames and n_forwardtrace following frames, again inheriting from AttentionMask:

from tensorflow.keras.layers import Masking
import tensorflow as tf

from AttentionMask import AttentionMask


class AttentionMask_Backtrace_Forwardtrace(AttentionMask):
    """Computes attention mask appropriate for tf.keras.layers.MultiHeadAttention."""

    def __init__(self, causal, n_backtrace=None, n_forwardtrace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            n_backtrace - (int) number of backtrace frames.
            n_forwardtrace - (int) number of forwardtrace frames.
        """
        super().__init__(causal)
        self.causal = causal
        self.n_backtrace = n_backtrace
        self.n_forwardtrace = n_forwardtrace

    def call(self, inp):
        """Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask.

        Returns:
            Attention mask.
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1),
                                    tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        bt_ft_mask = (self.backtrace_forwardtrace_mask([1, max_seq_len, max_seq_len])
                      if self.n_backtrace and self.n_forwardtrace else None)
        ################
        logical_mask = self.merge_masks(bt_ft_mask, seq_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    def backtrace_forwardtrace_mask(self, shape):
        """Creates a diagonal-band boolean mask over the last 2 dimensions.

        Argument/s:
            shape - shape of mask.

        Returns:
            Backtrace/forwardtrace mask.
        """
        row_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="row"), axis=-2)
        col_index = tf.math.cumsum(
            tf.ones(shape=shape, dtype=tf.int32, name="col"), axis=-1)
        bt_mask = tf.math.less_equal(row_index, col_index + self.n_backtrace)
        ft_mask = tf.math.greater_equal(row_index + self.n_forwardtrace, col_index)
        bt_ft_mask = self.merge_masks(bt_mask, ft_mask)
        return bt_ft_mask

Test:

if __name__ == '__main__':
    inputs = tf.ones([64, 526, 40])
    bt_ft_mask = AttentionMask_Backtrace_Forwardtrace(causal=0, n_backtrace=2, n_forwardtrace=5)(inputs)
    print('done')

Results (figure omitted):

bt_ft_mask is a diagonal band of 1s: each frame attends to the preceding n_backtrace = 2 frames, itself, and the following n_forwardtrace = 5 frames.
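
This band pattern can be sanity-checked against tf.linalg.band_part (a quick check added here, not from the original post):

import tensorflow as tf

inputs = tf.ones([64, 526, 40])
bt_ft_mask = AttentionMask_Backtrace_Forwardtrace(causal=0, n_backtrace=2, n_forwardtrace=5)(inputs)
expected = tf.linalg.band_part(tf.ones([526, 526]), 2, 5)  # 2 sub-diagonals, 5 super-diagonals
tf.debugging.assert_equal(bt_ft_mask[0, 0], expected)      # identical band for every batch element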

4. Customized Mask

Along the same lines, masks can be fully customized. The class below leaves only the trailing trace × trace block of the attention matrix unmasked:
from tensorflow.keras.layers import Masking
import tensorflow as tf

from AttentionMask import AttentionMask


class AttentionMask_Customization(AttentionMask):
    """Computes attention mask appropriate for tf.keras.layers.MultiHeadAttention."""

    def __init__(self, causal, trace=None):
        """
        Argument/s:
            causal - causal attention mask flag.
            trace - (int) size of the trailing square block to leave unmasked.
        """
        super().__init__(causal)
        self.causal = causal
        self.trace = trace

    def call(self, inp):
        """Compute attention mask.

        Argument/s:
            inp - used to compute sequence mask.

        Returns:
            Attention mask.
        """
        batch_size = tf.shape(inp)[0]
        max_seq_len = tf.shape(inp)[1]
        flat_seq_mask = Masking(mask_value=0.0).compute_mask(inp)
        seq_mask = self.merge_masks(tf.expand_dims(flat_seq_mask, axis=1),
                                    tf.expand_dims(flat_seq_mask, axis=2))
        ### HERE !!! ###
        customized_mask = self.customized_mask(batch_size, max_seq_len, self.trace)
        ################
        logical_mask = self.merge_masks(customized_mask, seq_mask)
        att_mask = tf.cast(logical_mask, tf.float32)
        att_mask = tf.reshape(att_mask, [batch_size, 1, max_seq_len, max_seq_len])
        return att_mask

    @tf.function
    def customized_mask(self, batchsize, max_length, trace):
        # A block of ones of size [batchsize, trace, trace], padded on the
        # top/left so the ones end up in the bottom-right corner.
        mask = tf.ones(shape=[batchsize, trace, trace], dtype=tf.int32)
        shape_pad = max_length - trace  # keep symbolic; int() would fail under tf.function
        mask = tf.pad(mask, paddings=[[0, 0], [shape_pad, 0], [shape_pad, 0]])
        mask = tf.cast(mask, dtype=bool)
        return mask

Test:

if __name__ == '__main__':
    inputs = tf.ones([64, 526, 40])
    customized_mask = AttentionMask_Customization(causal=1, trace=5)(inputs)
    print('done')

Result (figure omitted): only the trailing trace × trace (here 5 × 5) block of each attention matrix is left unmasked.
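
This can be verified numerically: with trace=5, each of the 64 batch elements should contain exactly 5 × 5 unmasked entries (a quick check, not from the original post):

import tensorflow as tf

inputs = tf.ones([64, 526, 40])
customized_mask = AttentionMask_Customization(causal=1, trace=5)(inputs)
print(int(tf.reduce_sum(customized_mask)))  # 1600 = 64 * 5 * 5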
