Implementing the Attention Mechanism in TensorFlow
The article runs to 996 words and 10 figures; estimated reading time is 15 minutes.
How It Works
[Figure 1]
[Figure 2]
[Figure 3]
Further reading: https://distill.pub/2016/augmented-rnns/#attentional-interfaces
https://www.cnblogs.com/shixiangwan/p/7573589.html#top
http://baijiahao.baidu.com/s?id=1587926245504773589&wfr=spider&for=pc
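Paper Reading
Hierarchical Attention Networks for Document Classification (http://www.aclweb.org/anthology/N16-1174)
This paper describes how to classify text with an attention mechanism.
Suppose we have a collection of news documents, each belonging to one of three classes: military, sports, or entertainment. A document $D$ has $L$ sentences, where $s_i$ denotes the $i$-th sentence of $D$. Sentence $s_i$ contains $T_i$ words, and $w_{it}$ denotes the $t$-th word of the $i$-th sentence, with $t \in [1, T_i]$.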
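Word Encoder:
① Given a sentence $s_i$, for example "The superstar is walking in the street", we represent it as the word sequence $[w_{i1}, w_{i2}, \dots, w_{i7}]$ (the word "the" occurs twice, at positions 1 and 6). An embedding matrix $W_e$ maps each word to a vector: $x_{it} = W_e w_{it}$.
② A bidirectional GRU reads the whole sentence and produces a forward and a backward hidden vector for each word $w_{it}$:
$\overrightarrow{h}_{it} = \overrightarrow{\mathrm{GRU}}(x_{it}), \quad t \in [1, T_i]$
$\overleftarrow{h}_{it} = \overleftarrow{\mathrm{GRU}}(x_{it}), \quad t \in [T_i, 1]$
The final hidden vector is the concatenation of the forward and backward hidden vectors: $h_{it} = [\overrightarrow{h}_{it}, \overleftarrow{h}_{it}]$.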
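The word encoder described above can be sketched in plain NumPy. This is only an illustration of the embedding lookup and the forward/backward concatenation, not the TensorFlow code used later in the article: the helper names (`init_gru`, `gru_step`, `run_gru`, `encode_words`), the tiny dimensions, the random initialization, and the word ids are all assumptions made for the example.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def init_gru(input_dim, hidden_dim, rng):
    """Random parameters for one GRU direction (illustrative initialization)."""
    def mat(rows, cols):
        return rng.normal(scale=0.1, size=(rows, cols))

    params = {}
    for gate in ("z", "r", "h"):
        params["W" + gate] = mat(input_dim, hidden_dim)   # input weights
        params["U" + gate] = mat(hidden_dim, hidden_dim)  # recurrent weights
        params["b" + gate] = np.zeros(hidden_dim)         # bias
    return params


def gru_step(p, x, h_prev):
    """One GRU update for a single time step."""
    z = sigmoid(x @ p["Wz"] + h_prev @ p["Uz"] + p["bz"])              # update gate
    r = sigmoid(x @ p["Wr"] + h_prev @ p["Ur"] + p["br"])              # reset gate
    h_tilde = np.tanh(x @ p["Wh"] + (r * h_prev) @ p["Uh"] + p["bh"])  # candidate state
    return (1.0 - z) * h_prev + z * h_tilde


def run_gru(p, xs):
    """Run a GRU over xs of shape (T, E); returns hidden states of shape (T, H)."""
    h = np.zeros(p["bz"].shape[0])
    states = []
    for x in xs:
        h = gru_step(p, x, h)
        states.append(h)
    return np.stack(states)


def encode_words(word_ids, embeddings, fwd, bwd):
    xs = embeddings[word_ids]                      # (T, E): x_it = W_e w_it as a row lookup
    h_fwd = run_gru(fwd, xs)                       # reads t = 1..T
    h_bwd = run_gru(bwd, xs[::-1])[::-1]           # reads t = T..1, re-aligned to t = 1..T
    return np.concatenate([h_fwd, h_bwd], axis=1)  # (T, 2H): h_it = [forward, backward]


rng = np.random.default_rng(0)
vocab_size, embed_dim, hidden_dim = 8, 4, 3
embeddings = rng.normal(size=(vocab_size, embed_dim))
fwd = init_gru(embed_dim, hidden_dim, rng)
bwd = init_gru(embed_dim, hidden_dim, rng)

# "The superstar is walking in the street": "the" gets id 1 both times (ids are made up).
word_ids = np.array([1, 2, 3, 4, 5, 1, 6])
h = encode_words(word_ids, embeddings, fwd, bwd)
print(h.shape)  # (7, 6)
```

Each row of `h` corresponds to one word and has size `2 * hidden_dim`, which is why the classifier in train.py uses `HIDDEN_SIZE * 2` as its input width.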
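Word Attention:
Not all words in a sentence contribute equally to its meaning. In the sentence above, words such as "The" and "is" carry little information, so we use an attention mechanism to pick out the more important words and give them larger weights.
① Pass $h_{it}$ through a one-layer MLP to get its hidden representation: $u_{it} = \tanh(W_w h_{it} + b_w)$
② Score each $u_{it}$ against a learned word-level context vector $u_w$ and normalize with a softmax to get the weights: $\alpha_{it} = \dfrac{\exp(u_{it}^\top u_w)}{\sum_{t} \exp(u_{it}^\top u_w)}$
③ Compute the sentence vector:
Multiply each word's $h_{it}$ by its weight $\alpha_{it}$ and sum over the words to get the sentence vector: $s_i = \sum_{t} \alpha_{it} h_{it}$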
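The three attention steps above can be written as a short NumPy sketch for a single sentence. The function name `word_attention`, the dimensions, and the random inputs are assumptions for illustration. Hidden states are stored as rows, so the MLP is computed as `h @ W_w` with `W_w` of shape `(D, A)`.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max before exponentiating for numerical stability.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)


def word_attention(h, W_w, b_w, u_w):
    """h: (T, D) hidden states. Returns the sentence vector (D,) and the weights (T,)."""
    u = np.tanh(h @ W_w + b_w)   # step 1: (T, A) hidden representation u_it
    scores = u @ u_w             # (T,) similarity of each u_it with the context vector u_w
    alpha = softmax(scores)      # step 2: (T,) normalized weights alpha_it
    s = alpha @ h                # step 3: (D,) weighted sum of the h_it
    return s, alpha


rng = np.random.default_rng(0)
T, D, A = 7, 6, 5                          # words, hidden size, attention size (made up)
h = rng.normal(size=(T, D))                # stand-in for the Bi-GRU outputs
W_w = rng.normal(scale=0.1, size=(D, A))
b_w = np.zeros(A)
u_w = rng.normal(scale=0.1, size=A)

s, alpha = word_attention(h, W_w, b_w, u_w)
print(s.shape, alpha.shape)  # (6,) (7,)
```

The weights `alpha` are positive and sum to 1, so the sentence vector is a convex combination of the word hidden states.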
Code Implementation
The code below targets the TensorFlow 1.x API (`tf.placeholder`, `tf.contrib`).

attention.py

    import tensorflow as tf


    def attention(inputs, attention_size, time_major=False, return_alphas=False):
        if isinstance(inputs, tuple):
            # In case of Bi-RNN, concatenate the forward and the backward RNN outputs.
            inputs = tf.concat(inputs, 2)

        if time_major:
            # (T,B,D) => (B,T,D)
            inputs = tf.transpose(inputs, [1, 0, 2])

        hidden_size = inputs.shape[2].value  # D value - hidden size of the RNN layer

        # Trainable parameters
        w_omega = tf.Variable(tf.random_normal([hidden_size, attention_size], stddev=0.1))
        b_omega = tf.Variable(tf.random_normal([attention_size], stddev=0.1))
        u_omega = tf.Variable(tf.random_normal([attention_size], stddev=0.1))

        with tf.name_scope('v'):
            # Applying fully connected layer with non-linear activation to each of the B*T timestamps;
            # the shape of `v` is (B,T,D)*(D,A)=(B,T,A), where A=attention_size
            v = tf.tanh(tf.tensordot(inputs, w_omega, axes=1) + b_omega)

        # For each of the timestamps its vector of size A from `v` is reduced with `u` vector
        vu = tf.tensordot(v, u_omega, axes=1, name='vu')  # (B,T) shape
        alphas = tf.nn.softmax(vu, name='alphas')         # (B,T) shape

        # Output of (Bi-)RNN is reduced with attention vector; the result has (B,D) shape
        output = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), 1)

        if not return_alphas:
            return output
        else:
            return output, alphas

train.py

    from __future__ import print_function, division

    import numpy as np
    import tensorflow as tf
    from keras.datasets import imdb
    from tensorflow.contrib.rnn import GRUCell
    from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn as bi_rnn
    from tqdm import tqdm

    from attention import attention
    from utils import get_vocabulary_size, fit_in_vocabulary, zero_pad, batch_generator

    NUM_WORDS = 10000
    INDEX_FROM = 3
    SEQUENCE_LENGTH = 250
    EMBEDDING_DIM = 100
    HIDDEN_SIZE = 150
    ATTENTION_SIZE = 50
    KEEP_PROB = 0.8
    BATCH_SIZE = 256
    NUM_EPOCHS = 3  # Model easily overfits without pre-trained words embeddings, that's why train for a few epochs
    DELTA = 0.5
    MODEL_PATH = './model'

    # Load the data set
    (X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=NUM_WORDS, index_from=INDEX_FROM)

    # Sequences pre-processing
    vocabulary_size = get_vocabulary_size(X_train)
    X_test = fit_in_vocabulary(X_test, vocabulary_size)
    X_train = zero_pad(X_train, SEQUENCE_LENGTH)
    X_test = zero_pad(X_test, SEQUENCE_LENGTH)

    # Different placeholders
    with tf.name_scope('Inputs'):
        batch_ph = tf.placeholder(tf.int32, [None, SEQUENCE_LENGTH], name='batch_ph')
        target_ph = tf.placeholder(tf.float32, [None], name='target_ph')
        seq_len_ph = tf.placeholder(tf.int32, [None], name='seq_len_ph')
        keep_prob_ph = tf.placeholder(tf.float32, name='keep_prob_ph')

    # Embedding layer
    with tf.name_scope('Embedding_layer'):
        embeddings_var = tf.Variable(tf.random_uniform([vocabulary_size, EMBEDDING_DIM], -1.0, 1.0), trainable=True)
        tf.summary.histogram('embeddings_var', embeddings_var)
        batch_embedded = tf.nn.embedding_lookup(embeddings_var, batch_ph)

    # (Bi-)RNN layer(-s)
    rnn_outputs, _ = bi_rnn(GRUCell(HIDDEN_SIZE), GRUCell(HIDDEN_SIZE),
                            inputs=batch_embedded, sequence_length=seq_len_ph, dtype=tf.float32)
    tf.summary.histogram('RNN_outputs', rnn_outputs)

    # Attention layer
    with tf.name_scope('Attention_layer'):
        attention_output, alphas = attention(rnn_outputs, ATTENTION_SIZE, return_alphas=True)
        tf.summary.histogram('alphas', alphas)

    # Dropout
    drop = tf.nn.dropout(attention_output, keep_prob_ph)

    # Fully connected layer
    with tf.name_scope('Fully_connected_layer'):
        W = tf.Variable(tf.truncated_normal([HIDDEN_SIZE * 2, 1], stddev=0.1))  # Hidden size is multiplied by 2 for Bi-RNN
        b = tf.Variable(tf.constant(0., shape=[1]))
        y_hat = tf.nn.xw_plus_b(drop, W, b)
        y_hat = tf.squeeze(y_hat)
        tf.summary.histogram('W', W)

    with tf.name_scope('Metrics'):
        # Cross-entropy loss and optimizer initialization
        loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_hat, labels=target_ph))
        tf.summary.scalar('loss', loss)
        optimizer = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(loss)

        # Accuracy metric
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.round(tf.sigmoid(y_hat)), target_ph), tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    merged = tf.summary.merge_all()

    # Batch generators
    train_batch_generator = batch_generator(X_train, y_train, BATCH_SIZE)
    test_batch_generator = batch_generator(X_test, y_test, BATCH_SIZE)

    train_writer = tf.summary.FileWriter('./logdir/train', accuracy.graph)
    test_writer = tf.summary.FileWriter('./logdir/test', accuracy.graph)

    session_conf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

    saver = tf.train.Saver()

    if __name__ == "__main__":
        with tf.Session(config=session_conf) as sess:
            sess.run(tf.global_variables_initializer())
            print("Start learning...")
            for epoch in range(NUM_EPOCHS):
                loss_train = 0
                loss_test = 0
                accuracy_train = 0
                accuracy_test = 0

                print("epoch: {}\t".format(epoch), end="")

                # Training
                num_batches = X_train.shape[0] // BATCH_SIZE
                for b in tqdm(range(num_batches)):
                    x_batch, y_batch = next(train_batch_generator)
                    seq_len = np.array([list(x).index(0) + 1 for x in x_batch])  # actual lengths of sequences
                    loss_tr, acc, _, summary = sess.run([loss, accuracy, optimizer, merged],
                                                        feed_dict={batch_ph: x_batch,
                                                                   target_ph: y_batch,
                                                                   seq_len_ph: seq_len,
                                                                   keep_prob_ph: KEEP_PROB})
                    accuracy_train += acc
                    loss_train = loss_tr * DELTA + loss_train * (1 - DELTA)
                    train_writer.add_summary(summary, b + num_batches * epoch)
                accuracy_train /= num_batches

                # Testing
                num_batches = X_test.shape[0] // BATCH_SIZE
                for b in tqdm(range(num_batches)):
                    x_batch, y_batch = next(test_batch_generator)
                    seq_len = np.array([list(x).index(0) + 1 for x in x_batch])  # actual lengths of sequences
                    loss_test_batch, acc, summary = sess.run([loss, accuracy, merged],
                                                             feed_dict={batch_ph: x_batch,
                                                                        target_ph: y_batch,
                                                                        seq_len_ph: seq_len,
                                                                        keep_prob_ph: 1.0})
                    accuracy_test += acc
                    loss_test += loss_test_batch
                    test_writer.add_summary(summary, b + num_batches * epoch)
                accuracy_test /= num_batches
                loss_test /= num_batches

                print("loss: {:.3f}, val_loss: {:.3f}, acc: {:.3f}, val_acc: {:.3f}".format(
                    loss_train, loss_test, accuracy_train, accuracy_test
                ))
            train_writer.close()
            test_writer.close()
            saver.save(sess, MODEL_PATH)
            print("Run 'tensorboard --logdir=./logdir' to checkout tensorboard logs.")

utils.py

    from __future__ import print_function

    import numpy as np


    def zero_pad(X, seq_len):
        return np.array([x[:seq_len - 1] + [0] * max(seq_len - len(x), 1) for x in X])


    def get_vocabulary_size(X):
        return max([max(x) for x in X]) + 1  # plus the 0th word


    def fit_in_vocabulary(X, voc_size):
        return [[w for w in x if w < voc_size] for x in X]


    def batch_generator(X, y, batch_size):
        """Primitive batch generator"""
        size = X.shape[0]
        X_copy = X.copy()
        y_copy = y.copy()
        indices = np.arange(size)
        np.random.shuffle(indices)
        X_copy = X_copy[indices]
        y_copy = y_copy[indices]
        i = 0
        while True:
            if i + batch_size <= size:
                yield X_copy[i:i + batch_size], y_copy[i:i + batch_size]
                i += batch_size
            else:
                i = 0
                indices = np.arange(size)
                np.random.shuffle(indices)
                X_copy = X_copy[indices]
                y_copy = y_copy[indices]
                continue


    if __name__ == "__main__":
        # Test batch generator
        gen = batch_generator(np.array(['a', 'b', 'c', 'd']), np.array([1, 2, 3, 4]), 2)
        for _ in range(8):
            xx, yy = next(gen)
            print(xx, yy)
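To make the padding and length logic concrete, here is a small check of how `zero_pad` from utils.py interacts with the `seq_len` expression in train.py. The two toy sequences are made up, and `zero_pad` is copied here only so the snippet runs on its own.

```python
import numpy as np


def zero_pad(X, seq_len):
    # Same logic as utils.py: keep at most seq_len - 1 tokens, then pad with at least one 0.
    return np.array([x[:seq_len - 1] + [0] * max(seq_len - len(x), 1) for x in X])


reviews = [[5, 6, 7], [1, 2, 3, 4, 5, 6]]  # toy word-id sequences (made up)
padded = zero_pad(reviews, 5)
print(padded)
# [[5 6 7 0 0]
#  [1 2 3 4 0]]

# Same expression as in train.py: position of the first 0, plus one.
seq_len = np.array([list(x).index(0) + 1 for x in padded])
print(seq_len)  # [4 5]
```

Because every padded row ends with at least one 0, `list(x).index(0)` always finds a match. The reported length includes the first padding token, so a sequence truncated to `SEQUENCE_LENGTH - 1` words gets a length of exactly `SEQUENCE_LENGTH`.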
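Code repository: https://github.com/ilivans/tf-rnn-attention
Results
The model reaches 96% accuracy on the training set and 86% on the test set, a solid result. The 10-point gap is consistent with the overfitting noted in the `NUM_EPOCHS` comment in train.py.
Original article: https://www.jianshu.com/p/cc6407444a8c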
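For reference, the accuracy above is computed in train.py by rounding `sigmoid(y_hat)` and comparing it with the labels, and the loss is the sigmoid cross-entropy on the logits. The NumPy sketch below mirrors both on made-up logits and labels. It uses the numerically stable formula given in the TensorFlow documentation for `tf.nn.sigmoid_cross_entropy_with_logits`, and the function names here are only illustrative.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sigmoid_cross_entropy(logits, labels):
    # Stable form of -z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x)):
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


def accuracy(logits, labels):
    preds = np.round(sigmoid(logits))  # threshold the probability at 0.5
    return np.mean(preds == labels)


logits = np.array([2.0, -1.0, 0.5, -3.0])  # made-up model outputs (y_hat)
labels = np.array([1.0, 0.0, 0.0, 0.0])    # made-up targets

print(accuracy(logits, labels))  # 0.75
print(sigmoid_cross_entropy(logits, labels).mean())
```

Only the third example is misclassified (its logit is positive while the label is 0), which gives 3 correct out of 4.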