深度学习多分类案例:新闻文本分类

公众号:机器学习杂货店
作者:Peter
编辑:Peter

大家好,我是Peter~

这里是机器学习杂货店 Machine Learning Grocery~

之前介绍过一个单分类的问题。当每个数据点可以划分到多个类别、多个标签下,这就是属于多分类问题了。

本文介绍一个基于深度学习的多分类实战案例:新闻文本分类,最终是有46个不同的类别

数据集

路透社数据集

广泛使用的文本分类数据集:46个不同的主题,即输出有46个类别。某些样本的主题更多,但是训练集中的每个主题至少有10个样本

加载数据集

也是内置的数据集

In [1]:

from keras.datasets import reuters

In [2]:

# 限制前10000个最常见的单词(train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words=10000)

In [3]:

# 查看数据
len(train_data)

Out[3]:

8982

In [4]:

len(test_data)

Out[4]:

2246

In [5]:

train_data.shape

Out[5]:

(8982,)

In [6]:

type(train_data)

Out[6]:

numpy.ndarray

索引解码为单词

In [7]:

word_index = reuters.get_word_index()reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])

In [8]:

decoded_newswire = ' '.join([reverse_word_index.get(i - 3, '?') for i in train_data[0]])
decoded_newswire

Out[8]:

'? ? ? said as a result of its december acquisition of space co it expects earnings per share in 1987 of 1 15 to 1 30 dlrs per share up from 70 cts in 1986 the company said pretax net should rise to nine to 10 mln dlrs from six mln dlrs in 1986 and rental operation revenues to 19 to 22 mln dlrs from 12 5 mln dlrs it said cash flow per share this year should be 2 50 to three dlrs reuter 3'

索引减去3是因为:012分别是为“padding填充”,“start of sequence(序列开始)”,"unknown(未知)"分别保留的索引

样本标签对应的是0-45范围内的整数:

In [9]:

train_labels[10]

Out[9]:

3

数据向量化

In [10]:

import numpy as npdef vectorize_sequences(sequences, dimension=10000):results = np.zeros((len(sequences), dimension)) # 创建全0矩阵for i, sequence in enumerate(sequences):  results[i, sequence] = 1.  # 指定位置填充1return results# 训练数据和测试数据向量化
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

标签向量化-onehot

主要是有两种方法:

  • 将标签列表转成整数张量
  • one-hot编码,分类编码的一种

In [11]:

import numpy as npdef to_one_hot(labels, dimension=46):results = np.zeros((len(labels), dimension)) # 创建全0矩阵for i, label in enumerate(labels):  results[i, label] = 1.  # 指定位置填充1return results# 训练标签和测试标签向量化
one_hot_train_label = to_one_hot(train_labels)
one_hot_test_label = to_one_hot(test_labels)

In [12]:

np.zeros((4,5))

Out[12]:

array([[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]])

Keras内置方法实现one-hot

In [13]:

from keras.utils.np_utils import to_categoricalone_hot_train_labels = to_categorical(train_labels)
one_hot_test_labels = to_categorical(test_labels)

建模

模型定义(修改)

In [14]:

import tensorflow as tf  # add
from keras import models
from keras import layers model = models.Sequential()# 原文model.add(layers.Dense(16, activation="relu", input_shape=(10000, )))
# 统一修改3处内容:layers.Dense 变成 tf.keras.layers.Dense
model.add(tf.keras.layers.Dense(64,activation="relu",input_shape=(10000,)))
model.add(tf.keras.layers.Dense(64,activation="relu"))
model.add(tf.keras.layers.Dense(46,activation="softmax"))

注意两点:

  1. 网络的最后一层是大小为46的Dense层。意味着,对于每个输入样本,网络都会输出一个46维的向量,这个向量的每个元素代表不同的输出类型
  2. 最后一个使用的是softmax激活:网络将输出在46个不同类别上的概率分布,output[i]是样本属于第i个类别的概率,46个概率的总和是1

模型编译

多分类问题最好使用categorical_crossentropy作为损失函数。它用于衡量两个概率分布之间的距离:网络输出的概率分布和标签的真实概率分布

目标:这两个概率分布的距离最小化

In [15]:

model.compile(optimizer="rmsprop",loss="categorical_crossentropy",metrics=["accuracy"])

验证模型-提取验证集

In [16]:

# 取出1000个样本作为验证集x_val = x_train[:1000]
partial_x_train = x_train[1000:]y_val = one_hot_train_labels[:1000]
partial_y_train = one_hot_train_labels[1000:]

训练网络

开始20个轮次epochs训练网络

In [17]:

history = model.fit(partial_x_train,partial_y_train,epochs=20,batch_size=512,validation_data =(x_val, y_val))

绘图

损失

In [20]:

# 损失绘图
import matplotlib.pyplot as plthistory_dict = history.history
loss_values = history_dict["loss"]
val_loss_values = history_dict["val_loss"]epochs = range(1,len(loss_values) + 1)  # 训练
plt.plot(epochs,  #  横坐标loss_values,  # 纵坐标"r",  # 颜色和形状,默认是实线label="Training_Loss"  # 标签名)
# 验证
plt.plot(epochs,val_loss_values,"b",label="Validation_Loss")plt.title("Training and Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()plt.show()

精度(修改)

多分类问题中得到的是acc的全拼,而不是缩写

In [19]:

# 精度绘图
import matplotlib.pyplot as plthistory_dict = history.history
acc_values = history_dict["accuracy"]  # 修改:原文是acc ---> accuracy
val_acc_values = history_dict["val_accuracy"]   # val_acc ---> val_accuracyepochs = range(1,len(loss_values) + 1)plt.plot(epochs,acc_values,"r",label="Training_ACC")plt.plot(epochs,val_acc_values,"b",label="Validation_ACC")plt.title("Training and Validation ACC")
plt.xlabel("Epochs")
plt.ylabel("acc")
plt.legend()plt.show()

重新训练

网路在训练9轮后开始过拟合,重新训练一个新网络:共9个轮次

In [25]:

import tensorflow as tf  # add
from keras import models
from keras import layers model = models.Sequential()
model.add(tf.keras.layers.Dense(64,activation="relu",input_shape=(10000, )))model.add(tf.keras.layers.Dense(64,activation="relu"))model.add(tf.keras.layers.Dense(46,activation="softmax"))model.compile(optimizer="rmsprop",loss="categorical_crossentropy",metrics=["accuracy"])model.fit(partial_x_train,partial_y_train,epochs=9,batch_size=512,validation_data=(x_val, y_val))results = model.evaluate(x_test, one_hot_test_labels)
results

这个模型的精度到达了78.6%。如果是随机的基准是多少呢?

In [26]:

import copy test_labels_copy = copy.copy(test_labels)
np.random.shuffle(test_labels_copy)

In [28]:

hist_array = np.array(test_labels) == np.array(test_labels_copy)

In [29]:

hist_array  # T或者F的元素

Out[29]:

array([False, False, False, ..., False,  True, False])

In [30]:

float(np.sum(hist_array)) / len(test_labels)

Out[30]:

0.18744434550311664

测试集验证

使用的是predict函数

In [31]:

predictions = model.predict(x_test)
predictions

Out[31]:

array([[1.35964816e-04, 5.63261092e-05, 1.82070780e-05, ...,1.30567923e-06, 1.98109021e-07, 6.60357784e-07],[4.93899640e-03, 4.66015842e-03, 1.14109996e-03, ...,1.06564505e-04, 3.91121466e-05, 9.38424782e-04],[1.59738748e-03, 7.79436469e-01, 1.73478038e-03, ...,1.25069331e-04, 1.86533769e-04, 1.70501284e-04],...,[6.38374695e-05, 1.36295348e-04, 4.94179221e-05, ...,2.28828794e-05, 2.12488590e-06, 5.73998386e-06],[3.28431278e-03, 6.02599606e-02, 3.11069656e-03, ...,6.78190409e-05, 8.90388037e-05, 2.43265240e-04],[1.19820972e-04, 6.49809241e-01, 1.08765960e-02, ...,7.17515213e-05, 3.85396503e-04, 3.51801835e-04]], dtype=float32)

1、predictions中每个元素都是46维的向量:

In [32]:

# 每个元素都是46维的向量predictions[0].shape

Out[32]:

(46,)

In [33]:

predictions[1].shape

Out[33]:

(46,)

In [34]:

predictions[50].shape

Out[34]:

(46,)

2、所有元素的和为1:

In [35]:

# 所有元素的和为1
sum(predictions[0])

Out[35]:

1.0000001240543241

3、最大元素就是预测的类别,也就是概率最大的类别:

In [36]:

np.argmax(predictions[0])

Out[36]:

3

In [37]:

np.argmax(predictions[4])

Out[37]:

13

换种方式处理标签和损失

In [38]:

# 方式1:转换为整数张量y_train = np.array(train_labels)
y_test = np.array(test_labels)

使用的损失函数categorical_crossentropy,标签遵循分类编码。

如果是整数标签,使用sparse_categorical_crossentropy:

In [39]:

model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",  # 损失函数metrics=["accuracy"])

中间层维度足够大的重要性

最终输出是46维的,因此中间层的隐藏单个数不应该比46小太多。如果小太多,将会造成信息的丢失:

In [40]:

import tensorflow as tf
from keras import models
from keras import layers model = models.Sequential()
model.add(tf.keras.layers.Dense(64,activation="relu",input_shape=(10000, )))model.add(tf.keras.layers.Dense(4,  # 中间层从64---->4activation="relu"))model.add(tf.keras.layers.Dense(46,activation="softmax"))model.compile(optimizer="rmsprop",loss="categorical_crossentropy",metrics=["accuracy"])model.fit(partial_x_train,partial_y_train,epochs=20,batch_size=512,validation_data=(x_val, y_val))

我们发现最终上升到了69.2%

进一步实验

  1. 尝试使用更多或者更少的隐藏单元,比如32或者128等
  2. 改变隐藏层个数,目前是2个;可以改成1个或者3个

小结

  1. 如果是对N个类别进行分类,最后一层应该是大小为N的Dense层
  2. 单标签多分类问题,网络的最后一层使用softmax激活,输出在N个输出类别上的概率分布
  3. 损失函数几乎都是分类交叉熵categorical_crossentropy。它将网络输出的概率分布和目标真实分布之间的距离最小化
  4. 避免使用太小的中间层,以免在网络中造成信息瓶颈。
  5. 处理多分类的标签方法:
    • 分类编码:one-hot编码,然后使用categorical_crossentropy
    • 将标签编码为整数,然后使用sparse_categorical_crossentropy

深度学习实战案例:新闻文本分类相关推荐

  1. 机器学习_深度学习毕设题目汇总——文本分类

    下面是该类的一些题目:| 题目 | |–| |基于主题特征的多标签文本分类方法研究| |融合全局和局部特征的文本分类方法研究| |BiGRU-CapsNet文本分类模型研究| |基于Attentio ...

  2. 深度学习实战案例:电影评论二分类

    第一个深度学习实战案例:电影评论分类 公众号:机器学习杂货店 作者:Peter 编辑:Peter 大家好,我是Peter~ 这里是机器学习杂货店 Machine Learning Grocery~ 本 ...

  3. bert使用做文本分类_使用BERT进行深度学习的多类文本分类

    bert使用做文本分类 Most of the researchers submit their research papers to academic conference because its ...

  4. 【第 07 章 基于主成分分析的人脸二维码识别MATLAB深度学习实战案例】

    基于主成分分析的人脸二维码识别MATLAB深度学习实战案例 人脸库 全套文件资料目录下载链接–>传送门 本文全文源码下载[链接–>传送门] 如下分析: 主文件 function varar ...

  5. 【深度学习前沿应用】文本分类Fine-Tunning

    [深度学习前沿应用]文本分类Fine-Tunning 作者简介:在校大学生一枚,华为云享专家,阿里云星级博主,腾云先锋(TDP)成员,云曦智划项目总负责人,全国高等学校计算机教学与产业实践资源建设专家 ...

  6. 第 09 章 基于特征匹配的英文印刷字符识别 MATLAB深度学习实战案例

    基于特征匹配的英文印刷字符识别 MATLAB深度学习实战 话不多讲,直接开撸代码 MainForm函数 function MainForm global bw; global bl; global b ...

  7. 深度学习实战案例:基于LSTM的四种方法进行电影评论情感分类预测(附完整代码)

    序列分类是一个预测建模问题,你有一些输入序列,任务是预测序列的类别. 这个问题很困难,因为序列的长度可能不同,包含非常大的输入符号词汇表,并且可能需要模型学习输入序列中符号之间的长期上下文或依赖关系. ...

  8. NLP实战-中文新闻文本分类

    目录 1.思路 2.基于paddle的ERINE模型进行迁移学习训练 3.分步实现 3.1 获取数据 (1)数据解压 (2)将文本转成变量,这里为了好计算,我只选了新闻标题做文本分类 3.2 中文分词 ...

  9. 【新闻文本分类】(task4)使用gensim训练word2vec

    学习总结 (1)学习训练Word2Vec 词向量,为后面task搭建 TextCNN 模型.BILSTM 模型训练预测作准备.Word2vec 的研究中提出的模型结构.目标函数.负采样方法.负采样中的 ...

  10. 深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类

    文章目录 一.前期工作 1. 设置GPU 2. 导入预处理词库类 二.导入预处理词库类 三.参数设定 四.创建模型 五.训练模型函数 六.测试模型函数 七.训练模型与预测 今天给大家带来一个简单的中文 ...

最新文章

  1. 小型星形网络结构设计示例
  2. 1076 Forwards on Weibo
  3. 全网最全的Postman接口自动化测试(小鸟成大鸟级攻略)
  4. 5.分布式数据库HBase第1部分
  5. ONVIF网络摄像头(IPC)客户端开发—RTSP RTCP RTP加载H264视频流
  6. Web框架之Django_03 路由层了解(路有层 无名分组、有名分组、反向解析、路由分发 视图层 JsonResponse,FBV、CBV、文件上传)
  7. linux ubuntu ssh,Linux(Ubuntu)安装ssh服务
  8. 知识图谱 图数据库 推理_图数据库的知识表示与推理
  9. java列表框_Java图形用户界面之列表框
  10. 漫谈边缘计算(二):各怀心事的玩家
  11. 盒子模型(悬挂式布局)
  12. 以软件开发周期来说明不同的测试的使用情况
  13. 通俗地告诉你:为什么Dijkstra算法是正确的?
  14. office+visio2016版本一同安装说明
  15. 六西格玛dfss_六西格玛设计DFSS.pdf
  16. python学习(五)--打印错误信息
  17. 分享一个挺不错的Git视频教程
  18. oracle表空间配额(quota)与UNLIMITED TABLESPACE系统权限
  19. BES(恒玄) 平台 复杂按键 实现
  20. linux 五种 IO 模型

热门文章

  1. 基于树莓派的智能图像识别垃圾分类系统
  2. 有了5G手机和套餐,如何正确使用5G网络?
  3. L298N、电机、单片机的线路连接(51、stm32程序)
  4. 计算机管理丢失computer文件,Win7弹框提示找不到Computer Management.lnk文件怎么办?...
  5. OriginPro 绘制柱状图(特别是用于对比实验时)
  6. 百度 嵌入式Linux软件研发工程师面试记录
  7. 生成mysql.sock_mysql.sock不在了,怎么手工创建一个,并设置相应属性
  8. 发票自动处理识别和分类
  9. 效度不达标的处理方式
  10. Mac如何重装系统?