基于Kears的Reuters新闻分类
Reuters数据集下载速度慢,可以在我的repo库中找到下载,下载后放到~/.keras/datasets/目录下,即可正常运行。
构建神经网络将路透社新闻分类,一共有46个类别。因为有多个类别,属于多分类问题,而每条数据只属于一个类别,所以是单标签多分类问题;如果每条数据可以被分到多个类别中,那问题则属于多标签多分类问题。
完整代码 欢迎Fork、Star
路透社数据集
Reuters数据集发布在1986年,一系列短新闻及对应话题的数据集;是文本分类问题最常用的小数据集。和IMDB、MNIST数据集类似,Reuters数据集也可以通过Keras直接下载。
加载数据集
from keras.datasets import reuters(train_data,train_labels), (test_data, test_labels) = reuters.load_data(num_words=10000)
有8982条训练集,2246条测试集。
每个样本表示成整数列表。
>>> train_data[10]
[1, 245, 273, 207, 156, 53, 74, 160, 26, 14, 46, 296, 26, 39, 74, 2979,
3554, 14, 46, 4689, 4329, 86, 61, 3499, 4795, 14, 61, 451, 4329, 17, 12]
也可以将整数列表转换成原始数据[英文句子]
word_index = reuters.get_word_index()# 单词--下标 对应字典
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])# 下标-单词对应字典decoded_newswire = ' '.join([reverse_word_index.get(i - 3, '?') for i in
train_data[0]]) #偏移3个:0,1,2保留下标,分别表示:“padding,” “start of sequence,” and “unknown.”
准备数据
整数数据向量化,与IMDB数据集处理方法相同。
import numpy as npdef vectorize_sequences(sequences, dimension=10000):results = np.zeros((len(sequences), dimension))for i, sequence in enumerate(sequences):results[i, sequence] = 1.return resultsx_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
标签的向量化有两种方法:将标签列表转换成整数张量;使用one-hot编码。One-hot编码方式是类别数据常用的一种数据格式,也称为categorical encoding。
def to_one_hot(labels, dimension=46):# 46个类别results = np.zeros((len(labels), dimension))for i, label in enumerate(labels):results[i, label] = 1.return resultsone_hot_train_labels = to_one_hot(train_labels)
one_hot_test_labels = to_one_hot(test_labels)
Keras中有一个内置的One-hot编码转换函数:
from keras.utils.np_utils import to_categoricalone_hot_train_labels = to_categorical(train_labels)
one_hot_test_labels = to_categorical(test_labels)
模型搭建
使用Dense线性连接堆栈结构,每层网络只能处理上层网络的输出结果。如果网络层丢失了一些关于分类问题的信息,那么下一层网络并不能恢复这些信息:每个网络层潜在地成为一个信息处理瓶颈。
网络定义
from keras import models
from keras import layersmodel = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(46, activation='softmax'))
关于这个网络架构有两点需要注意:
- 最后一层网络神经元数目为46.意味着每个输入样本最终变成46维的向量。输出向量的每个数表示不同的类别;
- 最后一层网络使用softmax激活函数–网络会输出一个46类的概率分布。每个输入最终都会产生一个46维的向量,每个数表示属于该类别的概率,46个数加起来等于1.
最好的损失函数为categorical_crossentropy—衡量两个概率分布之间的距离:网络的输出向量和标签的真实分布向量。通过最小化两个分布之间的距离,训练网络模型,使得输出向量尽可能与真实分布相似。
model.compile(optimizer='rmsprop',loss='categorical_crossentropy', metrics=['accuracy'])
模型验证
在训练数据中分出1000条样本做为验证集。
x_val = x_train[:1000]
partial_x_train = x_train[1000:]y_val = one_hot_train_labels[:1000]
partial_y_train = one_hot_train_labels[1000:]
训练20个epochs
history = model.fit(partial_x_train,partial_y_train,epochs=20,batch_size=512,validation_data=(x_val, y_val))
训练集和验证集的损失值变化
训练集和验证集的准确率变化
模型在第9次epochs之后开始过拟合。我们将epochs设置为5重新训练,同时在测试集上测试。
model = models.Sequential()model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(46, activation='softmax'))model.compile(optimizer='rmsprop',loss='categorical_crossentropy',
metrics=['accuracy'])model.fit(partial_x_train,partial_y_train,epochs=9,batch_size=512,
validation_data=(x_val, y_val))results = model.evaluate(x_test, one_hot_test_labels)
# [0.9565213431445807, 0.79697239536954589] 损失值,准确率
准确率达到80%.比随机猜测好。
预测新数据
使用predict函数,产生一个46维的概率分布。在测试数据上进行预测:
predictions = model.predict(x_test)
在预测结果中概率最大的类别就是预测类:
np.argmax(predictions[0])#第一条新闻的预测类 4
另一种标签、损失函数处理方式
直接将列表转换成numpy数组
y_train = np.array(train_labels)
y_test = np.array(test_labels)
需要改变的是损失函数的选择。categorical_crossentropy损失函数期望标签数据使用categorical encoding编码方式。整数标签,应该使用sparse_categorical_crossentropy损失函数:
model.compile(optimizer='rmsprop',loss='sparse_categorical_crossentropy',metrics=['acc'])
新的损失函数在数学表示上与categorical_crossentropy损失函数相同,只是接口不同。
有充分大规模中间层的重要性
因为最终分为46类,中间层的神经元数目不应该小于46个。如果中间层数目小于46,有4个,将会产生信息瓶颈。
model = models.Sequential()model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(4, activation='relu'))
model.add(layers.Dense(46, activation='softmax'))
model.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])
model.fit(partial_x_train,partial_y_train,epochs=20,batch_size=128,validation_data=(x_val, y_val))
最终训练结果最高为71%,降低了8个百分点。主要原始是模型试图将大量的信息压缩到低纬度空间中表示,丢失了大量重要的信息。
小结
- N分类问题,网络最后Dense层神经元数目为N;
- 单标签多分类问题中,最后一层的激活函数为softmax,产生一个包含N类的概率分布;
- categorical crossentropy是处理单标签多分类问题最常用的损失函数;
- 在多分类问题中有两种标签处理方式:
- 使用categorical encoding(one-hot)编码,将标签one-hot化,同时使用categorical_crossentropy作为损失函数;
- 编码成整数向量,使用sparse_categorical_crossentropy作为损失函数;
- 如果分类数目过大,应该避免网络中间层数目过小(比分类数目小–信息压缩),产生信息瓶颈。
基于Kears的Reuters新闻分类相关推荐
- 基于PaddleNLP的真假新闻分类(二)Skep模型
一.基于PaddleNLP的美国大选的新闻真假分类(二)基于SKEP模型的分类任务 0.解释 本来这个都烂尾了,看到有人问二在哪儿?只好说还没公开,自己挖的坑,含泪也要填.下次标题再也不屑一.二了,真 ...
- 【深度学习kears+tensorflow】新闻分类:多分类问题
目录 Classifying newswires: a multi-class classification example 新闻分类:多分类问题 The Reuters dataset 路透社数据集 ...
- 基于LSTM模型实现新闻分类
1.简述LSTM模型 LSTM是长短期记忆神经网络,根据论文检索数据大部分应用于分类.机器翻译.情感识别等场景,在文本中,主要使用tensorflow及keras,搭建LSTM模型实现新闻分类案例.( ...
- Pytorch实战——基于RNN的新闻分类
目录 一.项目介绍 二.基于RNN的新闻分类 Step1 加载数据集 Step2 分词和构建词汇表 Step3 构建数据加载器 dataloader Step4 定义神经网络模型 Step5 定义模型 ...
- 基于 LSTM-Attention 的中文新闻文本分类
1.摘 要 经典的 LSTM 分类模型,一种是利用 LSTM 最后时刻的输出作为高一级的表示,而另一种是将所有时刻的LSTM 输出求平均作为高一级的表示.这两种表示都存在一定的缺陷,第一种缺失了前面的 ...
- 【Pytorch基础教程36】基于Ernie预训练模型和Bert的新闻分类
文章目录 一.新闻分类任务 1.1 中文数据集 1.2 数据特点 1.3 跑起代码 二. 预训练语言模型ERNIE 2.1 ERNIE模型结构 2.2 bert模型结构 三.项目代码 1. bert模 ...
- 新闻文本分类--任务5 基于深度学习的文本分类2
Task5 基于深度学习的文本分类2 在上一章节,我们通过FastText快速实现了基于深度学习的文本分类模型,但是这个模型并不是最优的.在本章我们将继续深入. 基于深度学习的文本分类 本章将继续学习 ...
- 什么?你学深度学习还不会新闻分类:多分类问题?我来手把手教你
目录 Classifying newswires: a multi-class classification example 新闻分类:多分类问题 The Reuters dataset 路透社数据集 ...
- AI大牛周明打造的轻量“孟子模型”开源!靠10亿参数冲上CLUE榜第三,可用于新闻分类、文案生成...
明敏 发自 凹非寺 量子位 报道 | 公众号 QbitAI 只用10亿参数就杀进中文自然语言理解CLUE榜单前三的孟子模型,现在开源了! 其打造团队澜舟科技-创新工场最新宣布,基于孟子通用模型,他们将 ...
- 文本基线怎样去掉_ICML 2020 | 基于类别描述的文本分类模型
论文标题: Description Based Text Classification with Reinforcement Learning 论文作者: Duo Chai, Wei Wu, Qing ...
最新文章
- oracle12c racpdb,Oracle 12C R2的CDB与PDB简单管理操作
- (建议收藏)万字长文,带你一文吃透 Linux 提权
- C语言 enum和typedef enum的区别
- 【✊基础不牢,地动山摇のC语言中static关键字✊】
- 【手把手教你Maven】构建过程
- 深夜遭粉丝质问!4天掌握新东方26年教育精华的方法,你怎么不早说?
- java传参数的方法_java中方法的参数传递机制
- CSS 单词折行 word-wrap属性
- 西北乱跑娃 --- python繁体字简体字互转第三方库
- nurbs曲线拟合程序_基于NURBS曲线拟合的shx字体优化
- Python识别同构数
- Java 常用技术栈 相关概念总结, 更新中...
- 链接mysql 504_phpMyAdmin错误代码:504 MySQL查询
- 内外兼修:程序员的成长之路+软技能 代码之外的生存指南
- 安防市场视频监控比重大 并有新商机
- 高校开学,小心钓鱼邮件趁火打劫
- win10电脑连接蓝牙请检查PIN并重新连接
- 6.2、C++的内联函数、函数重载、局部变量和全局变量
- 用C语言编写简易计算器
- microbit测试题
热门文章
- 为什么DataGridView不出现滚动条?它的ScrollBars属性我设置为Both了
- ubuntu安装Google输入法
- 深度学习损失函数 分类损失回归损失
- 用python语言实现人工智能猴子摘香蕉的问题_人工智能猴子香蕉问题
- C++第14周项目1 - 动物怎么叫
- Matlab 括号用法
- JAVA第一次授课心得_关于第一次java课的感想
- 【统计知识总结系列01】回归分析、抽样技术、方差分析以及非参数统计中的方差分解
- matlab中欧姆怎么表示,电阻的单位为欧姆,用符号()表示。 - 问答库
- 趋势客户端修改服务器地址,趋势杀毒软件服务器端更改ip