学习目标:

  1. 理解注意力机制的基本原理。
  2. 掌握利用注意力机制进行文本分类的方法。

学习内容:

利用循环神经网络和注意力机制进行新闻话题分类的代码,并调整网络结构,看是否可以提高识别率。


学习过程:

经过调整网络结构,得出下表结果,

optimizer

lr

batch_ size

trainable

epochs

验证集识别率

Adam

0.001

128

False

10

0.6055

RMSprop

0.001

128

False

10

0.6035

用了两个不同的优化器,利用循环神经网络和注意力机制进行文本分类的效果:

Adam

RMSprop


源码:

# In[1]: 读取新闻话题分类数据
import pandas as pd
df = pd.read_json(r'D:\Cadabra_tools002\course_data\News_Category_Dataset.json', lines=True)
df.head()# In[2]: 预处理,合并 "WORLDPOST"和"THE WORLDPOST"两种类别
df.category = df.category.map(lambda x:"WORLDPOST" if x == "THE WORLDPOST" else x)
categories = df.groupby('category')
print("total categories: ", categories.ngroups)
#print(categories.size()) # In[3]: 将单词进行标号
from keras.preprocessing import sequence
from keras.preprocessing.text import Tokenizer# 将标题和正文合并
df['text'] = df.headline + " " + df.short_description# 将单词进行标号
tokenizer = Tokenizer()
tokenizer.fit_on_texts(df.text)
X = tokenizer.texts_to_sequences(df.text)
df['words'] = X
#print(X[:10])#记录每条数据的单词数
df['word_length'] = df.words.apply(lambda i: len(i))#清除单词数不足5个的数据条目
df = df[df.word_length >= 5]
df.word_length.describe()# 把每个数据条目超过50个单词的部分去掉,
# 不足50个单词的补0,使得所有的条目具有相同的单词数量
maxlen = 50
X = list(sequence.pad_sequences(df.words, maxlen=maxlen))# In[4]: 将类别进行编号# 得到两个字典,可以从类别得到编号,或者从编号得到类别
categories = df.groupby('category').size().index.tolist()
category_int = {}
int_category = {}
for i, k in enumerate(categories):category_int.update({k:i})  # 类别 -> 编号int_category.update({i:k})  # 编号 -> 类别df['c2id'] = df['category'].apply(lambda x: category_int[x])# In[5]: 随机选取训练样本
import numpy as np
import keras.utils as utils
from sklearn.model_selection import train_test_splitX = np.array(X)
Y = utils.to_categorical(list(df.c2id))# 将数据分成两部分,随机取80%用于训练,20%用于测试
seed = 29 # 随机种子
x_train, x_val, y_train, y_val = train_test_split(X, Y, test_size=0.2, random_state=seed)# In[6]: 加载预先训练好的单词向量
EMBEDDING_DIM = 100
embeddings_index = {}
f = open(r'D:\Cadabra_tools002\course_data\glove.6B.100d.txt',errors='ignore') # 每个单词用100个数字的向量表示
for line in f:values = line.split()word = values[0]coefs = np.asarray(values[1:], dtype='float32')embeddings_index[word] = coefs
f.close()print('Total %s word vectors.' %len(embeddings_index))  # 399913# In[7]: 构造Embedding层,并用预训练好的单词向量初始化,注意该层不用训练
from keras.initializers import Constant
from keras.layers import Embeddingword_index = tokenizer.word_index
embedding_matrix = np.zeros((len(word_index) + 1, EMBEDDING_DIM))
for word, i in word_index.items():embedding_vector = embeddings_index.get(word)#根据单词挑选出对应向量if embedding_vector is not None:embedding_matrix[i] = embedding_vector# Embedding层的输入的单词编号最大为 len(word_index)=86627
# 一个句子有 maxlen=50 个单词,每个单词编码成 EMBEDDING_DIM=100 维的向量
# Embedding层的输入大小为 (batch_size, maxlen)
# Embedding层的输出大小为 (batch_size, maxlen, EMBEDDING_DIM)
embedding_layer = Embedding(len(word_index)+1, EMBEDDING_DIM, embeddings_initializer=Constant(embedding_matrix),input_length = maxlen,trainable=True #可以改成True)# In[]: 自注意力层的定义
from keras.layers import Layer
import keras.backend as K
class Self_Attention(Layer):def __init__(self, output_dim, **kwargs):# out_shape = (batch_size, time_steps, output_dim)self.output_dim = output_dimsuper(Self_Attention, self).__init__(**kwargs)def build(self, input_shape):# 为该层创建一个可训练的权重, 3个二维的矩阵(3,lstm_units,output_dim)# input_shape = (batch_size, time_steps, lstm_units)self.kernel = self.add_weight(name='kernel',shape=(3,input_shape[2], self.output_dim),initializer='uniform',trainable=True)super(Self_Attention, self).build(input_shape)  # 一定要在最后调用它def call(self, x):WQ = K.dot(x, self.kernel[0])WK = K.dot(x, self.kernel[1])WV = K.dot(x, self.kernel[2])# print("WQ.shape",WQ.shape)# print("K.permute_dimensions(WK, [0, 2, 1]).shape",K.permute_dimensions(WK, [0, 2, 1]).shape)QK = K.batch_dot(WQ,K.permute_dimensions(WK, [0, 2, 1]))QK = QK / (64**0.5)QK = K.softmax(QK)# print("QK.shape",QK.shape)V = K.batch_dot(QK,WV)return Vdef compute_output_shape(self, input_shape):return (input_shape[0],input_shape[1],self.output_dim)# In[11]: LSTM+注意力机制
from keras.layers import Input,Dense,Permute,Flatten
from keras.layers import LSTM,multiply,add,Lambda
from keras.models import Model
import matplotlib.pyplot as plt
import keras.backend as Kinputs = Input(shape=(maxlen,))  # 一批句子作为输入,每个句子有maxlen=50个单词
inputs_embedding = embedding_layer(inputs) # 输出(batch_size, maxlen=50, EMBEDDING_DIM=100)# time_steps=maxlen=50
# INPUT_DIM=EMBEDDING_DIM=100
lstm_units = 32 #可改变单词向量的维度
# (batch_size, time_steps, INPUT_DIM) -> (batch_size, time_steps, lstm_units)
lstm_out = LSTM(lstm_units, return_sequences=True)(inputs_embedding)
#LSTM val_acc: 0.5705lstm_out = Self_Attention(64)(lstm_out) # LSTM+Self_atten+sum: val_acc: 0.6117## ATTENTION PART STARTS ------------------ #
## (batch_size, time_steps, lstm_units) -> (batch_size, lstm_units, time_steps)
#a = Permute((2, 1))(lstm_out)
#
## 对最后一维进行全连接,参数数量:time_steps*time_steps + time_steps
## 相当于获得每一个step中,每个lstm维度在所有step中的权重
## (batch_size, lstm_units, time_steps) -> (batch_size, lstm_units, time_steps)
#a = Dense(maxlen, activation='softmax')(a)
#
## (batch_size, lstm_units, time_steps) -> (batch_size, time_steps, lstm_units)
#a_probs = Permute((2, 1), name='attention_vec')(a)
#
## 权重和输入的对应元素相乘,注意力加权,lstm_out=lstm_out*a_probs
#lstm_out = multiply([lstm_out, a_probs], name='attention_mul')
## ATTENTION PART FINISHES ---------------- # LSTM+atten+sum:0.6028#lstm_out = Permute((2, 1))(lstm_out)
lstm_out = Lambda(lambda X: K.sum(X,axis=1))(lstm_out)
#lstm_out = GlobalAveragePooling1D()(lstm_out)
#LSTM+sum val_acc: 0.6046# (batch_size, time_steps, lstm_units) -> (batch_size, time_steps*lstm_units)
output = Dense(len(int_category), activation='softmax')(lstm_out)
model = Model([inputs], output)
model.summary()model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
history = model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val), batch_size=128)
# val_acc: 0.5920# 绘制训练过程中识别率和损失的变化
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)plt.title('Training and validation accuracy')
plt.plot(epochs, acc, 'red', label='Training acc')
plt.plot(epochs, val_acc, 'blue', label='Validation acc')
plt.legend()
plt.show()

源码下载


学习产出:

  1. 调整一下参数,但经过多次迭代后识别率还是有所下降,识别率上不了70%;

人工智能--基于注意力机制的新闻话题分类相关推荐

  1. 人工智能--基于循环神经网络的新闻话题分类

    学习目标: 理解循环神经网络RNN的基本原理. 掌握利用循环神经网络进行文本分类的方法. 学习内容: 利用循环神经网络进行新闻话题分类的代码,设置Embedding的trainable=True,并调 ...

  2. 【推荐论文】基于多视角学习和个性化注意力机制的新闻推荐(附论文下载链接)...

    编者按:个性化新闻推荐是新闻行业必然的发展方向,在其实现过程中面临着三个关键问题,即分析用户兴趣.根据新闻内容建模和新闻排序.本文将这三个问题划归为新闻信息与用户兴趣的多样性问题,并由此出发,提出了基 ...

  3. 基于多视角学习和个性化注意力机制的新闻推荐(附论文下载链接)

    编者按:个性化新闻推荐是新闻行业必然的发展方向,在其实现过程中面临着三个关键问题,即分析用户兴趣.根据新闻内容建模和新闻排序.本文将这三个问题划归为新闻信息与用户兴趣的多样性问题,并由此出发,提出了基 ...

  4. 新年美食鉴赏——基于注意力机制CBAM的美食101分类

    新年美食鉴赏--基于注意力机制CBAM的美食101分类 一.数据预处理 1.数据集介绍 2.读取标签 3.统一命名 4.整理图片路径 5.划分训练集与验证集 6.定义美食数据集 二.注意力机制 1.简 ...

  5. 基于 LSTM-Attention 的中文新闻文本分类

    1.摘 要 经典的 LSTM 分类模型,一种是利用 LSTM 最后时刻的输出作为高一级的表示,而另一种是将所有时刻的LSTM 输出求平均作为高一级的表示.这两种表示都存在一定的缺陷,第一种缺失了前面的 ...

  6. 基于注意力机制的图卷积网络预测药物-疾病关联

    BIB | 基于注意力机制的图卷积网络预测药物-疾病关联 智能生信 人工智能×生物医药 ​关注 科学求真 赢 10 万奖金 · 院士面对面 9 人赞同了该文章 今天给大家介绍华中农业大学章文教授团队在 ...

  7. ciaodvd数据集的简单介绍_基于注意力机制的规范化矩阵分解推荐算法

    随着互联网技术的发展以及智能手机的普及, 信息超载问题也亟待解决.推荐系统[作为解决信息超载问题的有效工具, 已被成功应用于各个领域, 包括电子商务.电影.音乐和基于位置的服务等[.推荐系统通过分析用 ...

  8. 《Effective Approaches to Attention-based Neural Machine Translation》—— 基于注意力机制的有效神经机器翻译方法

    目录 <Effective Approaches to Attention-based Neural Machine Translation> 一.论文结构总览 二.论文背景知识 2.1 ...

  9. 【华为云技术分享】序列特征的处理方法之一:基于注意力机制方法

    [摘要] 本文介绍了针对序列特征采用的处理方法之一:基于注意力机制方法,并总结了一下相似性度量的方法. ▌前言 之前两篇讲过稠密特征和多值类别特征加入CTR预估模型的常用处理方法,这篇介绍一下针对序列 ...

最新文章

  1. 云计算(2)it 是什么
  2. 专访SIGDIAL2020最佳论文一作高信龙一:成功都是一步步走出来的
  3. Java知多少(43)异常处理基础
  4. 智能家居 (2) ——设计模式的引入
  5. Factory Method (工厂模式)
  6. 2021年中国电动卡车马达市场趋势报告、技术动态创新及2027年市场预测
  7. 原生JS大揭秘—原型链
  8. jq 目录树ajax,javascript
  9. 小米线刷包需要解压么_小米10刷机教程,线刷升级更新官方系统包
  10. 诺兰的阶段模型(转载)
  11. 计算机电源管理设置,如何修改计算机中设置的显卡电源管理模式
  12. android版本高低有啥好处与不好,WP跟安卓比流畅 但为什么就不好用呢?
  13. 织梦网站如何上传服务器还原,网站转移教程:织梦系统数据库备份和还原的方法步骤...
  14. SpringBoot项目启动Disconnected from the target VM
  15. KONGA配置KONG添加http-log插件
  16. java中的输入操作
  17. 22款受欢迎的计算机取证工具
  18. MATLAB画带延时系统的伯德图
  19. 2022最详细,最新的 Win11/WIN10 安装CUDA11.2和cuDNN(必坑之作)完美教程
  20. java农夫过河问题_农夫过河问题——C语言程序设计(转)

热门文章

  1. ATE工程师的进阶之路(LabVIEW方向)
  2. RocketMQ 专家丁威:Kafka 和 RocketMQ 从性能角度对比
  3. JSP转换Servlet
  4. JS实现数组每次只显示5条数据,首尾相连显示
  5. 剑网3 最新服务器,《剑网3》上海高校服务器今日开放
  6. css让几个快对象同时居中,多个CSS 居中方案,你可能还不知道!
  7. java getuserinfo_Java URI getRawUserInfo()用法及代码示例
  8. 计算机基础知识和实践技能300分,2019年河北省高职单招考试十类和对口电子电工类、计算机类联考职业适应性测试(计算机基础知识和实践技能)考试大纲Excel2010基本知识...
  9. 智慧地球危害中国国家安全
  10. php连接tidb,TiDB源码学习笔记:启动TiDB