Chinese News Headline Classification with Paddle 2.0

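  • Chinese news headline classification: Paddle 2.0 baseline (unofficial)
      • Tuning tips
      • Dataset address
    • Task description
      • Data description
      • Submitting answers
    • Code approach
      • Unzipping the dataset
    • Data processing
      • Reading the data (dictionary, datasets)
      • Data initialization
      • Inspecting the data
      • Padding the data
      • Wrapping the data
      • Network definition
      • Model training
      • Reading the inference data
      • Running inference
      • About the author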

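Project note: this project is a walkthrough of an assignment from Professor Hung-yi Lee's (李宏毅) PaddlePaddle-authorized course.
Course: link
AI Studio project: link
Dataset: link

This project is for reference only. It offers ideas and an approach, not a standard answer. Think twice before copying it!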

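Chinese news headline classification: Paddle 2.0 baseline (unofficial)

Unofficial, made by 三岁! (It may be rough, but it is made with care.)

Tuning tips

The baseline score of this project is not very high; you will need to tune it yourself to improve the results.
Suggestions:

  • Change the model. It is currently a simple linear model; try a more complex one
    that is better suited to NLP tasks (I am not sure myself which one works best).
  • Adjust the learning rate to search for the best result.
  • Continue training an existing model to get a better result.
  • ……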
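
As one way to act on the learning-rate suggestion, the rate can be lowered as training progresses instead of staying fixed at 0.001. The block below is a plain-Python sketch of step decay, not code from the original project: `step_decay`, `step_size` and `gamma` are names and defaults assumed here for illustration. Paddle provides ready-made schedulers under `paddle.optimizer.lr` that can be passed to the optimizer as `learning_rate`.

```python
def step_decay(base_lr, epoch, step_size=1, gamma=0.5):
    """Return the learning rate for `epoch` under step decay.

    The rate is multiplied by `gamma` once every `step_size` epochs.
    All names and default values here are illustrative assumptions.
    """
    return base_lr * gamma ** (epoch // step_size)


# Learning rates for the first four epochs, starting from the baseline's 0.001
schedule = [step_decay(0.001, epoch) for epoch in range(4)]
print(schedule)
```

Whether decay helps here is something to check against the validation accuracy; it is only one of several things worth trying.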

Dataset address

https://aistudio.baidu.com/aistudio/datasetdetail/75812

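Task description

Text classification based on the THUCNews dataset. THUCNews was built by filtering historical data from the Sina News RSS subscription channels for 2005 to 2011 and contains 740,000 news documents. Participants must use an algorithm to decide which category a news item belongs to from the content of its headline.

Data description

THUCNews was built by filtering historical data from the Sina News RSS subscription channels for 2005 to 2011. It contains 740,000 news documents (2.19 GB), all in UTF-8 plain text. Starting from the original Sina News category scheme, the data was reorganized into 14 candidate categories: 财经 (finance), 彩票 (lottery), 房产 (real estate), 股票 (stocks), 家居 (home), 教育 (education), 科技 (technology), 社会 (society), 时尚 (fashion), 时政 (politics), 体育 (sports), 星座 (horoscope), 游戏 (games), 娱乐 (entertainment).

The training set has already been extracted in the format "label ID + \t + label + \t + original headline", so the classification task can be run directly on the headlines. Participants are encouraged to come up with their own solutions.

Training set format: label ID + \t + label + \t + original headline
Test set format: original headline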
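
To make the training-set format concrete, here is a small sketch that splits one raw line into its three tab-separated fields. The function name `parse_train_line` and the example line are assumptions made for illustration; the project itself works from pre-encoded ID files rather than from raw lines like this.

```python
def parse_train_line(line):
    """Split one raw training line into (label_id, label, title)."""
    # maxsplit=2 keeps any tab inside the headline itself intact
    label_id, label, title = line.rstrip("\n").split("\t", 2)
    return int(label_id), label, title


# Illustrative line in the documented "label ID \t label \t headline" format
example = "0\t财经\t上证50ETF净申购突增\n"
print(parse_train_line(example))
```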

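Submitting answers

For the exam submission, submit the model code, the project version, and the result file. The result file is a TXT file named result.txt, and its fields must be written in the specified format.

1. Each output row must correspond one-to-one with a row of the original test set; the order must not change.

2. Check that the output has exactly 83599 rows; otherwise the score is invalid.

3. Name the output file result.txt, with one category per line. Sample: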

···

游戏

财经

时政

股票

家居

科技

社会

房产

教育

星座

科技

股票

游戏

财经

时政

股票

家居

科技

社会

房产

教育

···
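
Since a wrong row count invalidates the score, it may be worth checking the predictions before uploading. The sketch below is an assumed helper, not part of the original project: `check_result_lines` and `EXPECTED_ROWS` are hypothetical names, and the check only covers the row count and the label spelling.

```python
LABELS = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技",
          "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]
EXPECTED_ROWS = 83599  # row count required by the submission rules


def check_result_lines(lines, expected_rows=EXPECTED_ROWS):
    """Return a list of problems found in the prediction lines (empty if none)."""
    problems = []
    if len(lines) != expected_rows:
        problems.append(f"expected {expected_rows} rows, found {len(lines)}")
    allowed = set(LABELS)
    for number, line in enumerate(lines, start=1):
        if line.strip() not in allowed:
            problems.append(f"row {number}: unknown label {line.strip()!r}")
    return problems


# Usage on a tiny in-memory example (a real check would read result.txt)
print(check_result_lines(["游戏", "财经", "时政"], expected_rows=3))  # prints []
```

For the real file, the lines could be obtained with `open('result.txt', encoding='utf-8').read().splitlines()`.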

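Code approach

From the task statement, this is a classic NLP task.
Following the usual workflow for NLP tasks, we need the following steps:

  • Process the data and convert it into word vectors
  • Build the model
  • Train on the data
  • Load the model and run inference on the data to get the results

Without further ado, let's begin!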

Unzipping the dataset

! pip install -U paddlepaddle==2.0.1
! unzip -oq /home/aistudio/data/data75812/新闻文本标签分类.zip
import os

import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn as nn

print(paddle.__version__)  # show the installed Paddle version
# CPU/GPU selection: pass the target device to paddle.set_device().
# device = paddle.set_device('gpu')
2.0.1

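Data processing

First we decide how the word vectors will be represented.
We start from a dictionary (it has already been built here, so we simply load it; how it was built will be posted in the comments).
We map the dataset through the dictionary to produce a purely numeric encoding of every headline.
Once we have the encoding, we print a sample to check that it is correct.
If the data is fine, we pad it: a special token ID fills each sequence so that all of them have the same length.
Finally we check the data length.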
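
The dictionary-building step is not shown in this post, so the following is only a guess at what it could look like: a character-level dictionary with an `<unk>` entry, plus an encoder that maps a headline to IDs. `build_char_dict` and `encode` are hypothetical names, and the ID assignment (sorted characters, `<unk>` last) is an assumption rather than the layout of the project's dict.txt.

```python
def build_char_dict(titles):
    """Map each distinct character to an integer ID; '<unk>' gets the last ID."""
    chars = sorted({ch for title in titles for ch in title})
    char_dict = {ch: idx for idx, ch in enumerate(chars)}
    char_dict["<unk>"] = len(char_dict)
    return char_dict


def encode(title, char_dict):
    """Turn a title into a list of IDs, using '<unk>' for unseen characters."""
    unk = char_dict["<unk>"]
    return [char_dict.get(ch, unk) for ch in title]


titles = ["上证50", "净申购"]
char_dict = build_char_dict(titles)
print(encode("上证增", char_dict))  # "增" is not in the dictionary, so it maps to <unk>
```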

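Reading the data (dictionary, datasets)

import ast

# Load the dictionary
def get_dict_len(d_path):
    with open(d_path, 'r', encoding='utf-8') as f:
        # assumes the first line of the file is a plain dict literal
        line = ast.literal_eval(f.readlines()[0])
    return line

word_dict = get_dict_len('新闻文本标签分类/dict.txt')

# Load the training and validation sets
# NOTE: `set` shadows the built-in and is shared by every call to dataset(),
# so the second call appends to the same list. train_dataset and val_dataset
# therefore end up as the same combined list, which is why the shapes printed
# further down are identical.
set = []
def dataset(datapath):  # dataset reader
    with open(datapath) as f:
        for i in f.readlines():
            data = []
            dataset = i[:i.rfind('\t')].split(',')  # token IDs of the headline
            dataset = np.array(dataset)
            data.append(dataset)
            label = np.array(i[i.rfind('\t')+1:-1])  # label ID
            data.append(label)
            set.append(data)
    return set

train_dataset = dataset('新闻文本标签分类/Train_IDs.txt')
val_dataset = dataset('新闻文本标签分类/Val_IDs.txt')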

Data initialization


Define the values we will need

# Initial setup
vocab_size = len(word_dict) + 1  # dictionary size plus 1
print(vocab_size)
emb_size = 256  # embedding dimension
seq_len = 30  # sequence length that every headline is padded to
batch_size = 32  # batch size
epochs = 2  # number of training epochs
pad_id = word_dict['<unk>']  # ID used as the padding value

nu = ["财经","彩票","房产","股票","家居","教育","科技","社会","时尚","时政","体育","星座","游戏","娱乐"]  # label list, indexed by label ID

# Convert a list of IDs back to text
def ids_to_str(ids):
    # print(ids)
    words = []
    for k in ids:
        w = list(word_dict)[int(k)]
        words.append(w if isinstance(w, str) else w.decode('ASCII'))
    return " ".join(words)
5308

Inspecting the data


Check that the data is correct and fix any anomalies promptly

# Inspect one sample
for i in train_dataset:
    sent = i[0]
    label = int(i[1])
    print('sentence list id is:', sent)  # encoded headline
    print('sentence label id is:', label)  # label ID
    print('--------------------------')  # separator
    print('sentence list is: ', ids_to_str(sent))  # decoded headline
    print('sentence label is: ', nu[label])  # decoded label
    break
sentence list id is: ['2976' '385' '2050' '3757' '1147' '3296' '1585' '688' '1180' '2608'
 '4280' '1887']
sentence label id is: 0
--------------------------
sentence list is:  上 证 5 0 E T F 净 申 购 突 增
sentence label is:  财经

Padding the data


Pad every sequence to the same length

# Pad the data and inspect it
def create_padded_dataset(dataset):
    padded_sents = []
    labels = []
    for batch_id, data in enumerate(dataset):  # iterate over samples
        sent, label = data[0], data[1]  # split token IDs and label
        padded_sent = np.concatenate([sent[:seq_len], [pad_id] * (seq_len - len(sent))]).astype('int32')  # truncate to seq_len, then pad
        # print(padded_sent)
        padded_sents.append(padded_sent)  # store the sequence
        labels.append(label)  # store the label
    # print(padded_sents)
    return np.array(padded_sents), np.array(labels).astype('int64')  # convert to arrays and return

# Build the padded train and val arrays
train_sents, train_labels = create_padded_dataset(train_dataset)  # training set
val_sents, val_labels = create_padded_dataset(val_dataset)  # validation set
train_labels = train_labels.reshape(-1, 1)  # reshape the labels to one column, one row per sample
val_labels = val_labels.reshape(-1, 1)
# Check the array shapes
print(train_sents.shape)
print(train_labels.shape)
print(val_sents.shape)
print(val_labels.shape)
(832475, 30)
(832475, 1)
(832475, 30)
(832475, 1)
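
The padding rule can be seen in isolation with plain lists. This sketch mirrors the logic of `create_padded_dataset` (cut at `seq_len`, then fill with `pad_id`), but `pad_or_truncate` is a name introduced here for illustration and the values are toy data.

```python
def pad_or_truncate(ids, seq_len, pad_id):
    """Return exactly `seq_len` IDs: cut long inputs, right-pad short ones."""
    clipped = list(ids[:seq_len])
    # For inputs already at full length the multiplier is 0, so nothing is added
    return clipped + [pad_id] * (seq_len - len(clipped))


print(pad_or_truncate([7, 8, 9], seq_len=5, pad_id=0))           # [7, 8, 9, 0, 0]
print(pad_or_truncate([1, 2, 3, 4, 5, 6], seq_len=5, pad_id=0))  # [1, 2, 3, 4, 5]
```

Headlines longer than 30 tokens lose their tail under this rule, so a larger `seq_len` trades memory for fewer truncated headlines.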

Wrapping the data


Subclass paddle.io.Dataset to wrap the data into a format that can be used for training

# Wrap the arrays by subclassing paddle.io.Dataset
class IMDBDataset(paddle.io.Dataset):
    '''Wraps the padded sentences and labels as a paddle.io.Dataset.'''
    def __init__(self, sents, labels):
        # store the data
        self.sents = sents
        self.labels = labels

    def __getitem__(self, index):
        # return one sample
        data = self.sents[index]
        label = self.labels[index]
        return data, label

    def __len__(self):
        # number of samples
        return len(self.sents)

# Instantiate the datasets
train_dataset = IMDBDataset(train_sents, train_labels)
val_dataset = IMDBDataset(val_sents, val_labels)

# Wrap them in data loaders
train_loader = paddle.io.DataLoader(train_dataset, return_list=True, shuffle=True, batch_size=batch_size, drop_last=True)
val_loader = paddle.io.DataLoader(val_dataset, return_list=True, shuffle=True, batch_size=batch_size, drop_last=True)
# Inspect the contents and size of one batch from each loader
for i in train_loader:
    print(i)
    break
for j in val_loader:
    print(j)
    break
[Tensor(shape=[32, 30], dtype=int32, place=CPUPlace, stop_gradient=True,
       [[4041, 4370, 3449, ..., 5306, 5306, 5306],
        [1440, 3740, 1169, ..., 5306, 5306, 5306],
        ...,
        [2830, 2567, 1472, ..., 5306, 5306, 5306]]), Tensor(shape=[32, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[3 ], [6 ], [3 ], ..., [10], [13], [10]])]
[Tensor(shape=[32, 30], dtype=int32, place=CPUPlace, stop_gradient=True,
       [[2607, 5278, 1979, ..., 5306, 5306, 5306],
        [2050, 2751, 3403, ..., 5306, 5306, 5306],
        ...,
        [4289, 2166, 698 , ..., 5306, 5306, 5306]]), Tensor(shape=[32, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[6 ], [8 ], [4 ], ..., [5 ], [10], [8 ]])]
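
For readers new to the dataset/loader split, the idea can be sketched without Paddle: a dataset only needs `__getitem__` and `__len__`, and a loader groups samples into batches. This is a conceptual stand-in, not how `paddle.io.DataLoader` is implemented; `ListDataset` and `iter_batches` are hypothetical names, and shuffling and tensor conversion are left out.

```python
class ListDataset:
    """Minimal map-style dataset: indexable and sized."""

    def __init__(self, sents, labels):
        self.sents = sents
        self.labels = labels

    def __getitem__(self, index):
        return self.sents[index], self.labels[index]

    def __len__(self):
        return len(self.sents)


def iter_batches(dataset, batch_size, drop_last=True):
    """Yield lists of samples, skipping a short final batch if drop_last."""
    batch = []
    for index in range(len(dataset)):
        batch.append(dataset[index])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch


data = ListDataset([[1, 2], [3, 4], [5, 6]], [0, 1, 0])
print([len(b) for b in iter_batches(data, batch_size=2)])                   # [2]
print([len(b) for b in iter_batches(data, batch_size=2, drop_last=False)])  # [2, 1]
```

With `drop_last=True`, as in the loaders above, samples that do not fill a complete batch are skipped in that pass.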

Network definition


Define the network used for training; this part is one of the keys to a better score

# Define the network
class MyNet(paddle.nn.Layer):
    def __init__(self):
        super(MyNet, self).__init__()
        self.emb = paddle.nn.Embedding(vocab_size, emb_size)  # embedding layer: a learned 2-D lookup table (vocab_size x emb_size)
        self.fc = paddle.nn.Linear(in_features=emb_size, out_features=96)  # linear layer
        self.fc1 = paddle.nn.Linear(in_features=96, out_features=14)  # classifier over the 14 categories
        self.dropout = paddle.nn.Dropout(0.5)  # dropout for regularization

    def forward(self, x):
        x = self.emb(x)
        x = paddle.mean(x, axis=1)  # average the embeddings over the sequence
        x = self.dropout(x)
        x = self.fc(x)
        x = self.dropout(x)
        x = self.fc1(x)
        return x

# Plotting
def draw_process(title, color, iters, data, label):
    plt.title(title, fontsize=24)  # title
    plt.xlabel("iter", fontsize=20)  # x-axis label
    plt.ylabel(label, fontsize=20)  # y-axis label
    plt.plot(iters, data, color=color, label=label)  # draw the curve
    plt.legend()
    plt.grid()
    plt.show()
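
One detail of `MyNet` worth knowing when tuning: `paddle.mean(x, axis=1)` averages over all 30 positions, padding included. A possible variation is to average only over the real tokens. The plain-Python sketch below shows that idea on toy vectors; `masked_mean` is a hypothetical helper, it has not been evaluated on this task, and because this project pads with the `<unk>` ID, masking by `pad_id` would also drop genuinely unknown characters.

```python
def masked_mean(vectors, ids, pad_id):
    """Average the vectors whose token ID is not `pad_id`.

    vectors: list of equal-length float lists, one per token position
    ids:     list of token IDs aligned with `vectors`
    """
    kept = [vec for vec, tok in zip(vectors, ids) if tok != pad_id]
    if not kept:  # all padding: fall back to a zero vector
        return [0.0] * len(vectors[0])
    return [sum(column) / len(kept) for column in zip(*kept)]


vectors = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
print(masked_mean(vectors, ids=[5, 6, 0], pad_id=0))  # [2.0, 3.0]
```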

Model training


The key training step. You can adjust the learning rate, the optimizer and so on, which may make a surprising difference

# Train the model
def train(model):
    model.train()
    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())  # optimizer and learning rate
    # initial values
    steps = 0
    Iters, total_loss, total_acc = [], [], []
    for epoch in range(epochs):  # epoch loop
        for batch_id, data in enumerate(train_loader):  # batch loop
            steps += 1
            sent = data[0]  # token IDs
            label = data[1]  # labels
            logits = model(sent)  # forward pass
            loss = paddle.nn.functional.cross_entropy(logits, label)  # loss
            acc = paddle.metric.accuracy(logits, label)  # accuracy
            if batch_id % 500 == 0:  # log every 500 batches
                Iters.append(steps)  # record the step count
                total_loss.append(loss.numpy()[0])  # record the loss
                total_acc.append(acc.numpy()[0])  # record the accuracy
                print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))  # print the result
            # update the parameters
            loss.backward()
            opt.step()
            opt.clear_grad()

        # evaluate once per epoch
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(val_loader):  # iterate over validation batches
            sent = data[0]  # token IDs
            label = data[1]  # labels
            logits = model(sent)  # forward pass
            loss = paddle.nn.functional.cross_entropy(logits, label)  # loss
            acc = paddle.metric.accuracy(logits, label)  # accuracy
            accuracies.append(acc.numpy())  # collect the values
            losses.append(loss.numpy())
        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)  # average loss and accuracy
        print("[validation] accuracy: {}, loss: {}".format(avg_acc, avg_loss))  # print the result
        model.train()
        paddle.save(model.state_dict(), str(epoch) + "_model_final.pdparams")  # save the weights for this epoch
    draw_process("training loss", "red", Iters, total_loss, "training loss")  # plot the loss curve
    draw_process("training acc", "green", Iters, total_acc, "training acc")  # plot the accuracy curve

model = MyNet()  # instantiate the model
train(model)  # start training
epoch: 0, batch_id: 0, loss is: [2.6477456]
epoch: 0, batch_id: 500, loss is: [1.8056118]
epoch: 0, batch_id: 1000, loss is: [1.1092072]
epoch: 0, batch_id: 1500, loss is: [1.0716103]
epoch: 0, batch_id: 2000, loss is: [0.6794955]
epoch: 0, batch_id: 2500, loss is: [0.54738545]
epoch: 0, batch_id: 3000, loss is: [0.9065808]
epoch: 0, batch_id: 3500, loss is: [0.63474274]
epoch: 0, batch_id: 4000, loss is: [0.68158776]
epoch: 0, batch_id: 4500, loss is: [1.0516238]
epoch: 0, batch_id: 5000, loss is: [0.9118046]
epoch: 0, batch_id: 5500, loss is: [0.65075576]
epoch: 0, batch_id: 6000, loss is: [0.5605841]
epoch: 0, batch_id: 6500, loss is: [0.56175774]
epoch: 0, batch_id: 7000, loss is: [0.95122683]
epoch: 0, batch_id: 7500, loss is: [0.38649452]
epoch: 0, batch_id: 8000, loss is: [0.2205698]
epoch: 0, batch_id: 8500, loss is: [0.40474647]
epoch: 0, batch_id: 9000, loss is: [0.5931748]
epoch: 0, batch_id: 9500, loss is: [0.3922717]
epoch: 0, batch_id: 10000, loss is: [0.6130478]
epoch: 0, batch_id: 10500, loss is: [0.5300909]
epoch: 0, batch_id: 11000, loss is: [0.6114788]
epoch: 0, batch_id: 11500, loss is: [0.24966809]
epoch: 0, batch_id: 12000, loss is: [0.45669073]
epoch: 0, batch_id: 12500, loss is: [0.29746443]
epoch: 0, batch_id: 13000, loss is: [0.6775298]
epoch: 0, batch_id: 13500, loss is: [0.8836371]
epoch: 0, batch_id: 14000, loss is: [0.27501673]
epoch: 0, batch_id: 14500, loss is: [0.46843478]
epoch: 0, batch_id: 15000, loss is: [0.49367175]
epoch: 0, batch_id: 15500, loss is: [0.500063]
epoch: 0, batch_id: 16000, loss is: [0.31290954]
epoch: 0, batch_id: 16500, loss is: [0.30774388]
epoch: 0, batch_id: 17000, loss is: [0.21738727]
epoch: 0, batch_id: 17500, loss is: [0.2860858]
epoch: 0, batch_id: 18000, loss is: [0.2766972]
epoch: 0, batch_id: 18500, loss is: [0.36017033]
epoch: 0, batch_id: 19000, loss is: [0.43986273]
epoch: 0, batch_id: 19500, loss is: [0.4210134]
epoch: 0, batch_id: 20000, loss is: [0.579644]
epoch: 0, batch_id: 20500, loss is: [0.23016676]
epoch: 0, batch_id: 21000, loss is: [0.21913218]
epoch: 0, batch_id: 21500, loss is: [0.18669227]
epoch: 0, batch_id: 22000, loss is: [0.31480896]
epoch: 0, batch_id: 22500, loss is: [0.37621552]
epoch: 0, batch_id: 23000, loss is: [0.54980826]
epoch: 0, batch_id: 23500, loss is: [0.6016808]
epoch: 0, batch_id: 24000, loss is: [0.25056183]
epoch: 0, batch_id: 24500, loss is: [0.2916811]
epoch: 0, batch_id: 25000, loss is: [0.33430776]
epoch: 0, batch_id: 25500, loss is: [0.74600095]
epoch: 0, batch_id: 26000, loss is: [0.35165167]
[validation] accuracy: 0.884321928024292, loss: 0.3713749647140503
epoch: 1, batch_id: 0, loss is: [0.47405708]
epoch: 1, batch_id: 500, loss is: [0.4443894]
epoch: 1, batch_id: 1000, loss is: [0.35416052]
epoch: 1, batch_id: 1500, loss is: [0.3004715]
epoch: 1, batch_id: 2000, loss is: [0.59477925]
epoch: 1, batch_id: 2500, loss is: [0.5639044]
epoch: 1, batch_id: 3000, loss is: [0.40286714]
epoch: 1, batch_id: 3500, loss is: [0.5387965]
epoch: 1, batch_id: 4000, loss is: [0.11766122]
epoch: 1, batch_id: 4500, loss is: [0.68849707]
epoch: 1, batch_id: 5000, loss is: [0.83928466]
epoch: 1, batch_id: 5500, loss is: [0.2867105]
epoch: 1, batch_id: 6000, loss is: [0.20924558]
epoch: 1, batch_id: 6500, loss is: [0.5582311]
epoch: 1, batch_id: 7000, loss is: [0.63174886]
epoch: 1, batch_id: 7500, loss is: [0.318484]
epoch: 1, batch_id: 8000, loss is: [0.5406461]
epoch: 1, batch_id: 8500, loss is: [0.4790561]
epoch: 1, batch_id: 9000, loss is: [0.52266514]
epoch: 1, batch_id: 9500, loss is: [0.51126254]
epoch: 1, batch_id: 10000, loss is: [0.27308795]
epoch: 1, batch_id: 10500, loss is: [0.22041513]
epoch: 1, batch_id: 11000, loss is: [0.32234907]
epoch: 1, batch_id: 11500, loss is: [0.6857507]
epoch: 1, batch_id: 12000, loss is: [0.40997463]
epoch: 1, batch_id: 12500, loss is: [0.53966033]
epoch: 1, batch_id: 13000, loss is: [0.2620927]
epoch: 1, batch_id: 13500, loss is: [0.21417136]
epoch: 1, batch_id: 14000, loss is: [0.5232475]
epoch: 1, batch_id: 14500, loss is: [0.37579858]
epoch: 1, batch_id: 15000, loss is: [0.3611152]
epoch: 1, batch_id: 15500, loss is: [0.336707]
epoch: 1, batch_id: 16000, loss is: [0.2795578]
epoch: 1, batch_id: 16500, loss is: [0.54298353]
epoch: 1, batch_id: 17000, loss is: [0.26425135]
epoch: 1, batch_id: 17500, loss is: [0.52595145]
epoch: 1, batch_id: 18000, loss is: [0.24938256]
epoch: 1, batch_id: 18500, loss is: [0.30653632]
epoch: 1, batch_id: 19000, loss is: [0.58400965]
epoch: 1, batch_id: 19500, loss is: [0.18243803]
epoch: 1, batch_id: 20000, loss is: [0.28917578]
epoch: 1, batch_id: 20500, loss is: [1.0765818]
epoch: 1, batch_id: 21000, loss is: [0.32550114]
epoch: 1, batch_id: 21500, loss is: [0.16792971]
epoch: 1, batch_id: 22000, loss is: [0.65214527]
epoch: 1, batch_id: 22500, loss is: [0.58119446]
epoch: 1, batch_id: 23000, loss is: [0.43643892]
epoch: 1, batch_id: 23500, loss is: [0.47376677]
epoch: 1, batch_id: 24000, loss is: [0.3279624]
epoch: 1, batch_id: 24500, loss is: [0.50899947]
epoch: 1, batch_id: 25000, loss is: [0.61989105]
epoch: 1, batch_id: 25500, loss is: [0.42433214]
epoch: 1, batch_id: 26000, loss is: [0.26673254]
[validation] accuracy: 0.8882260322570801, loss: 0.35311153531074524

[Figure: training loss curve (output_26_2.png)]

[Figure: training accuracy curve (output_26_3.png)]
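
A quick sanity check on the log: the first recorded loss (2.6477 at epoch 0, batch 0) is close to ln 14 ≈ 2.639, which is the cross-entropy of a classifier that spreads its probability evenly over the 14 categories. That is what one would expect from an untrained model. The sketch below reproduces the number in plain Python; `cross_entropy` is written here for illustration and is not Paddle's implementation.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one sample, computed in a numerically stable way."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum_exp - logits[label]


num_classes = 14
# Equal logits give a uniform prediction, so the loss is -ln(1/14) = ln(14)
uniform_loss = cross_entropy([0.0] * num_classes, label=3)
print(round(uniform_loss, 4))  # 2.6391
```

A starting loss far above this value would hint at a problem in the labels or in the initialization.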

Reading the inference data

# Read the competition data
set = []
def dataset(datapath):
    with open(datapath) as f:  # open the file
        for i in f.readlines():  # read line by line
            dataset = np.array(i.split(','))  # split into token IDs
            set.append(dataset)  # store the sample
    return set

# Pad the competition data
def create_padded_dataset(dataset):
    padded_sents = []
    labels = []
    for batch_id, data in enumerate(dataset):  # iterate over samples
        # print(data)
        sent = data  # token IDs
        padded_sent = np.concatenate([sent[:seq_len], [pad_id] * (seq_len - len(sent))]).astype('int32')  # truncate and pad
        # print(padded_sent)
        padded_sents.append(padded_sent)  # store the sequence
    # print(padded_sents)
    return np.array(padded_sents)  # convert to an array and return

test_data = dataset('新闻文本标签分类/Test_IDs.txt')  # read the data
# print()
# Pad the test data
test_data = create_padded_dataset(test_data)  # pad the data
# Inspect the padded data
print(test_data)
[[4057 1902 1475 ... 5306 5306 5306]
 [2805 5242 3593 ... 5306 5306 5306]
 [1836 3222 4641 ... 5306 5306 5306]
 ...
 [4838 1202 1490 ... 5306 5306 5306]
 [ 805 3757 3757 ... 5306 5306 5306]
 [2805 5242 3593 ... 5306 5306 5306]]

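Running inference


Here you can pick the best-performing saved model and use it for prediction

nu = ["财经","彩票","房产","股票","家居","教育","科技","社会","时尚","时政","体育","星座","游戏","娱乐"]  # label list
# Load the model
model_state_dict = paddle.load('0_model_final.pdparams')  # load the saved weights
model = MyNet()  # build the network
model.set_state_dict(model_state_dict)
model.eval()
# print(type(test_data[0]))
count = 0  # number of predictions written
with open('./result.txt', 'w', encoding='utf-8') as f_train:  # create the result file
    for batch_id, data in enumerate(test_data):  # iterate over headlines
        # Shape (1, seq_len): a batch holding one whole headline, so that the
        # mean over axis 1 in MyNet covers all of its tokens.
        results = model(paddle.to_tensor(data.reshape(1, seq_len)))  # run inference
        idx = np.argmax(results.numpy()[0])  # index of the highest score
        labels = nu[idx]  # map the index to its label
        f_train.write(labels + "\n")  # write the prediction
        count += 1
        if count % 500 == 0:  # progress report
            print(count)
print(count)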
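
The last step of inference, turning 14 scores into a category name, can be tried out on its own. The sketch below uses plain Python lists in place of the model output; `logits_to_label` is a hypothetical helper and the scores are made up.

```python
LABELS = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技",
          "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]


def logits_to_label(logits):
    """Return the label whose score is highest (first one wins on ties)."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return LABELS[best]


scores = [0.1] * 14
scores[12] = 2.5  # made-up scores with the highest value at index 12
print(logits_to_label(scores))  # 游戏
```

Because only the position of the largest score matters, no softmax is needed before taking the argmax.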

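The results are not guaranteed to be good, but the code runs end to end. If you have other needs, contact me (leave a comment or @ me in the group) and I will keep improving the project.
Thank you all for your support!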

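About the author

Author: 三岁
Background: self-taught in Python, now active in the Paddle community. I hope to start from the basics and learn Paddle together with everyone.
CSDN: https://blog.csdn.net/weixin_45623093/article/list/3
I have reached Diamond level on AI Studio and lit up 7 badges. Follow me and I will follow back: https://aistudio.baidu.com/aistudio/personalcenter/thirdview/284366

The self-proclaimed worst coder in the PaddlePaddle community. Let's work hard together!
Remember: anything made by 三岁 is bound to be a quality product (from the shameless series)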
