使用神经网络完成新闻分类
文章目录
- 使用神经网络完成新闻分类
- (1)问题描述:
- (2)模型原理:
- (3)实现过程:
- (4)实验小结:
使用神经网络完成新闻分类
(1)问题描述:
该数据集用于文本分类,包括大约20000个左右的新闻文档,均匀分为20个不同主题的新闻组集合,其中:
**训练集:**包括11314个新闻文档及其主题分类标签。训练数据在文件train目录下,训练新闻文档在train_texts.dat文件中,训练新闻文档标签在train_labels.txt文档中,编号为0~19,表示该文档分属的主题标号。
**测试集:**包括7532个新闻文档,标签并未给出。测试集文件在test目录下,测试集新闻文档在test_texts.dat文件中。
(2)模型原理:
LSTM
长短期记忆网络(LSTM,Long Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的,所有的RNN都具有一种重复神经网络模块的链式形式。在标准RNN中,这个重复的结构模块只有一个非常简单的结构,例如一个tanh层。
参考:Tensorflow实战:LSTM原理及实现(详解)_m0_37917271的博客-CSDN博客
做以下简要总结:
如下图所示:
从图中可以看到,LSTM 提出了三个门(gate)的概念:input gate,forget gate,output gate。其实可以这样来理解,input gate 决定了对输入的数据做哪些处理,forget gate 决定了哪些知识被过滤掉,无需再继续传递,而 output gate 决定了哪些知识需要传递到下一个时间序列。
计算依据下列公式:
it=σ(Wiixt+bii+Whih(t−1)+bhi)ft=σ(Wifxt+bif+Whfh(t−1)+bhf)gt=tanh(Wigxt+big+Whgh(t−1)+bhg)ot=σ(Wioxt+bio+Whoh(t−1)+bho)ct=ft∗c(t−1)+it∗gtht=ot∗tanh(ct)\begin{array}{c} i_{t}=\sigma\left(W_{i i} x_{t}+b_{i i}+W_{h i} h_{(t-1)}+b_{h i}\right) \\ f_{t}=\sigma\left(W_{i f} x_{t}+b_{i f}+W_{h f} h_{(t-1)}+b_{h f}\right) \\ g_{t}=\tanh \left(W_{i g} x_{t}+b_{i g}+W_{h g} h_{(t-1)}+b_{h g}\right) \\ o_{t}=\sigma\left(W_{i o} x_{t}+b_{i o}+W_{h o} h_{(t-1)}+b_{h o}\right) \\ c_{t}=f_{t} * c_{(t-1)}+i_{t} * g_{t} \\ h_{t}=o_{t} * \tanh \left(c_{t}\right) \end{array} it=σ(Wiixt+bii+Whih(t−1)+bhi)ft=σ(Wifxt+bif+Whfh(t−1)+bhf)gt=tanh(Wigxt+big+Whgh(t−1)+bhg)ot=σ(Wioxt+bio+Whoh(t−1)+bho)ct=ft∗c(t−1)+it∗gtht=ot∗tanh(ct)
其中:
- iti_{t}it 是处理 input 的 input gate, 外面一个 sigmoid 函数处理,其中的输入是当前输入 xtx_{t}xt 和前一个 时间状态的输出 h(t−1)h_{(t-1)}h(t−1) 。所有的 bbb 都是偏置项。
- ftf_{t}ft 则是 forget gate 的操作,同样对当前输入 xtx_{t}xt 和前一个状态的输出 h(t−1)h_{(t-1)}h(t−1) 进行处理。
- gtg_{t}gt 也是 input gate 中的操作, 同样的输入,只是外面换成了 tanh\tanhtanh 函数。
- oto_{t}ot 是前面图中 output gate 中左下角的操作,操作方式和前面 iti_{t}it, 以及 ftf_{t}ft 一样。
- ctc_{t}ct 则是输出之一, 对 forget gate 的输出, input gate 的输出进行相加,然后作为当前时间序列的 一个隐状态向下传递。
- hth_{t}ht 同样是输出之一,对 前面的 ctc_{t}ct 做一个 tanh\mathrm{tanh}tanh 操作,然后和前面得到的 oto_{t}ot 进行相乘, hth_{t}ht 既向 下一个状态传递,也作为当前状态的输出。
(3)实现过程:
读入数据集:
读取文件train_texts.dat和test_texts.dat方式如下,以train_texts.dat为例,test_texts.dat读取方式相同;标签文件为正常txt文件,读取方式按照读取txt文件即可。
# 读入数据 file_name = 'train/train_texts.dat' with open(file_name, 'rb') as f:train_texts = pickle.load(f) file_name = 'test/test_texts.dat' with open(file_name, 'rb') as f:test_texts = pickle.load(f)train_labals = [] fl = open('train/train_labels.txt') for line in fl.readlines():train_labals.append(line)
特征提取:
因为每篇新闻都是由英文字符表示而成,因此需要首先提取每篇文档的特征,把每篇文档抽取为特征向量,这里我们选择提取文档的TF-IDF特征,即词频(TF)-逆文本频率(IDF)。
提取文档的TF-IDF特征可以通过sklearn. feature_extraction.text中的TfidfVectorizer来完成,具体实现代码如下:
# TFIDF向量化 vectorizer = TfidfVectorizer(max_features=10000) train_vector = vectorizer.fit_transform(train_texts) print(train_vector.shape) test_vector = vectorizer.transform(test_texts) print(test_vector.shape)
标签向量化:
将标签向量化有两种方法:可以将标签列表转换为整数张量,或者使用one-hot 编码。
one-hot 编码是分类数据广泛使用的一种格式,也叫分类编码(categorical encoding)。这里我们采用该种方法,标签的one-hot编码就是将每个标签表示为全零向量, 只有标签索引对应的元素为 1。代码实现如下:
from keras.utils import to_categorical one_hot_train_labels = to_categorical(train_labals)
拆分训练集和测试集:
测试集比例选择为0.2:
#拆分测试集与训练集 X_train, X_test, y_train, y_test = train_test_split(train_vector, one_hot_train_labels, test_size=0.2, random_state=0) x_test = X_test.toarray() partial_x_train = X_train.toarray()
构建网络:
这里先尝试最简单的3个全连接层的网络。
输出类别的数量为20个。输出空间的维度较大。 对于用过的 Dense 层的堆叠,每层只能访问上一层输出的信息。如果某一层丢失了与 分类问题相关的一些信息,那么这些信息无法被后面的层找回,也就是说,每一层都可能成为信息瓶颈。因此,设计网络如下所示:
# 定义Sequential类 model = models.Sequential() # 全连接层,128个节点 model.add(layers.Dense(128, activation='relu', input_shape=(10000,))) # 全连接层,64个节点 model.add(layers.Dense(64, activation='relu')) # 全连接层,得到输出 model.add(layers.Dense(20, activation='softmax')) # loss model.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])
网络的最后一层是大小为20 的 Dense 层。这意味着,对于每个输入样本,网络都会输出一个 20 维向量。这个向量的每个元素(即每个维度)代表不同的输出类别。
我们选择的损失函数是 categorical_crossentropy(分类交叉熵)。它用于衡量两个概率分布之间的距离,这里两个概率分布分别是网络输出的概率分布和标签的真实分 布。通过将这两个分布的距离最小化,训练网络可使输出结果尽可能接近真实标签。
验证:
现在开始训练网络,共 20 个轮次。
history = model.fit(partial_x_train,y_train,epochs=20,batch_size=512,validation_data=(x_test, y_test))
截取第20次的结果:
Epoch 20/20512/9051 [>.............................] - ETA: 0s - loss: 0.0032 - accuracy: 0.9980 1536/9051 [====>.........................] - ETA: 0s - loss: 0.0022 - accuracy: 0.9993 2560/9051 [=======>......................] - ETA: 0s - loss: 0.0028 - accuracy: 0.9992 3072/9051 [=========>....................] - ETA: 0s - loss: 0.0025 - accuracy: 0.9993 4096/9051 [============>.................] - ETA: 0s - loss: 0.0022 - accuracy: 0.9995 5120/9051 [===============>..............] - ETA: 0s - loss: 0.0024 - accuracy: 0.9992 6656/9051 [=====================>........] - ETA: 0s - loss: 0.0023 - accuracy: 0.9994 7680/9051 [========================>.....] - ETA: 0s - loss: 0.0021 - accuracy: 0.9995 8704/9051 [===========================>..] - ETA: 0s - loss: 0.0021 - accuracy: 0.9995 9051/9051 [==============================] - 1s 79us/sample - loss: 0.0022 - accuracy: 0.9994 - val_loss: 0.3894 - val_accuracy: 0.8975
loss曲线:
绘制loss曲线:
loss = history.history['loss'] val_loss = history.history['val_loss'] epochs = range(1, len(loss) + 1) plt.plot(epochs, loss, 'bo', label='Training loss') plt.plot(epochs, val_loss, 'b', label='Validation loss') plt.title('Training and validation loss') plt.xlabel('Epochs') plt.ylabel('Loss') plt.legend() plt.show()
精度曲线:
绘制acc曲线:
plt.clf() acc = history.history['accuracy'] val_acc = history.history['val_accuracy'] plt.plot(epochs, acc, 'bo', label='Training acc') plt.plot(epochs, val_acc, 'b', label='Validation acc') plt.title('Training and validation accuracy') plt.xlabel('Epochs') plt.ylabel('Accuracy') plt.legend() plt.show()
重新训练:
分析可知,网络在训练 10 轮左右开始出现过拟合趋势。
因此,我们重新训练一个新网络,共 10 个轮次,然后在测试集上评估模型。
history = model.fit(partial_x_train,y_train,epochs=10,batch_size=512,validation_data=(x_test, y_test))results = model.evaluate(x_test, y_test)
结果输出:
将结果写入txt:
x_input = test_vector.toarray() predictions = model.predict(x_input) out_put = np.argmax(predictions,axis=1) np.savetxt('result.txt', out_put, fmt='%d', delimiter='\n')
(4)实验小结:
如果要对 N 个类别的数据点进行分类,网络的最后一层应该是大小为 N 的 Dense 层。对于单标签、多分类问题,网络的最后一层应该使用 softmax 激活,这样可以输出在 N个输出类别上的概率分布。
损失函数使用分类交叉熵,它可以将网络输出的概率分布与目标的真实分布之间的距离最小化。
处理多分类问题的标签有比较经典的有两种方法。
- 通过分类编码(也叫 one-hot 编码)对标签进行编码,然后使用 categorical_crossentropy 作为损失函数。
- 标签编码为整数,然后使用 sparse_categorical_crossentropy 损失函数。
使用神经网络完成新闻分类相关推荐
- [Pytorch系列-61]:循环神经网络 - 中文新闻文本分类详解-3-CNN网络训练与评估代码详解
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...
- 人工智能--基于循环神经网络的新闻话题分类
学习目标: 理解循环神经网络RNN的基本原理. 掌握利用循环神经网络进行文本分类的方法. 学习内容: 利用循环神经网络进行新闻话题分类的代码,设置Embedding的trainable=True,并调 ...
- [Pytorch系列-60]:循环神经网络 - 中文新闻文本分类详解-2-LSTM网络训练与评估代码详解
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...
- AI大牛周明打造的轻量“孟子模型”开源!靠10亿参数冲上CLUE榜第三,可用于新闻分类、文案生成...
明敏 发自 凹非寺 量子位 报道 | 公众号 QbitAI 只用10亿参数就杀进中文自然语言理解CLUE榜单前三的孟子模型,现在开源了! 其打造团队澜舟科技-创新工场最新宣布,基于孟子通用模型,他们将 ...
- 基于神经网络的文本分类(基于Pytorch实现)
<Convolutional Neural Networks for Sentence Classification> 作者:Yoon Kim 单位:New York University ...
- 基于卷积神经网络的句子分类模型【经典卷积分类附源码链接】
https://www.toutiao.com/a6680124799831769603/ 基于卷积神经网络的句子分类模型 题目: Convolutional Neural Networks for ...
- 论文阅读:Convolutional Neural Networks for Sentence Classification 卷积神经网络的句子分类
Convolutional Neural Networks for Sentence Classification 卷积神经网络的句子分类 目录 Convolutional Neural Networ ...
- 新闻分类(python深度学习——多分类问题)
记:新闻分类问题时多分类问题,与电影评论分类很类似又有一些差别,电影评论只有两个分类,而新闻分类有46个分类,所以在空间维度上有所增加,多分类问题的损失函数与二分类问题选择不同,最后一层使用的激活函数 ...
- 什么?你学深度学习还不会新闻分类:多分类问题?我来手把手教你
目录 Classifying newswires: a multi-class classification example 新闻分类:多分类问题 The Reuters dataset 路透社数据集 ...
最新文章
- UE5虚幻引擎5中的实时特效学习 Introduction to real time FX in Unreal Engine 5
- ajax 阻止默认提交,jQuery验证插件:在对ajax调用servlet时,submitHandler不会阻止默认提交-返回false无效...
- iOS UILabel加载html点击图片查看大图 附demo
- php7 返回值,7.6.4 函数返回值
- .sh是什么语言_为什么《山海经·中次二经》中,把“西王母”叫做“马腹”?...
- HDU-1796 How many integers can you find 容斥定理
- 大数据技术的理解误区
- html原生js请求
- Docker镜像使用详解
- abab的四字成语_以abab的四字成语
- 2008年最吸引眼球的10只股票
- 可能改变世界的13个“终结”(上)
- Android各版本分布
- Remove Double Negative(去除双重否定)
- chrome浏览器被360流氓捆绑,如何解决?
- 1000行代码实现定制形象送虎年祝福
- 如何把一个qmake的Ubuntu手机应用打包为一个snap应用
- PowerShell 学习笔记:压缩、解压缩文件
- FI--SAP财务成本知识库
- android here地图,兼容所有安卓设备 Here地图已放开限制
热门文章
- Tecnomatix Plant Simulation 14 学习之路(二)
- SpringBoot框架,使用Log4j2+Lombok引入日志的方法
- 【精】LintCode领扣算法问题答案:入门
- 从登陆界面学习TextInputLayout
- 中芯微随身WIFI破解实体SIM卡槽(不拆机,无需切卡密码)
- 关于PBR流程的小概念
- 解决java编译错误(程序包javax.servlet不存在javax.servlet.*)
- 智慧医院智慧医疗解决方案
- 数据仓库——在“啤酒与尿布”中挖掘
- PMP项目管理 新考纲概述