文章目录

  • 1 数据准备
  • 2 数据预处理
  • 3 交叉验证&特征提取
  • 4 模型训练
  • 5 评估与总结

1 数据准备

数据集格式:

import numpy as np
import pandas as pd
import time
import jieba
import re
import string
import pickle
from tqdm import tqdm
from zhon.hanzi import punctuation
from collections import Counter
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.linear_model import LogisticRegression, RidgeClassifier
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from lightgbm import LGBMClassifier# Pandas设置
pd.set_option("display.max_columns", None)  # 设置显示完整的列
pd.set_option("display.max_rows", None)  # 设置显示完整的行
pd.set_option("display.expand_frame_repr", False)  # 设置不折叠数据
pd.set_option("display.max_colwidth", 100)  # 设置列的最大宽度# 中文停用词
STOPWORDS_ZH = '../data/stopwords_zh.txt'# 加载数据集
train_df = pd.read_json('../data/tnews/train.json', lines=True, nrows=20000)
train_df.info()
print(train_df.head(10))# 删除keywords列
del train_df['keywords']
print(train_df.tail())

2 数据预处理

思路:

  1. 去除无效字符(英文字符、数字、表情、中英文标点符号、空白)
  2. 中文分词(jieba)
  3. 去停用词
  4. 去低频词(阈值为20)
  5. 序列化列表
# 去除无用的字符
def clear_character(sentence):pattern1 = re.compile('[a-zA-Z0-9]')  # 英文字符和数字pattern2 = re.compile(u'[^\s1234567890::' + '\u4e00-\u9fa5]+')  # 表情和其他字符pattern3 = re.compile('[%s]+' % re.escape(punctuation + string.punctuation))  # 标点符号line1 = re.sub(pattern1, '', sentence)line2 = re.sub(pattern2, '', line1)line3 = re.sub(pattern3, '', line2)new_sentence = ''.join(line3.split())  # 去除空白return new_sentence# 预处理
def preprocessing(df, col_name):t1 = time.time()print('去除无用字符')df[col_name + '_processed'] = df[col_name].apply(clear_character)print(df[col_name + '_processed'].values)print('中文分词')cut_words = []for content in df[col_name + '_processed'].values:seg_list = jieba.lcut(content)cut_words.append(seg_list)print(cut_words[0])print('去停用词')with open(STOPWORDS_ZH, 'r', encoding='utf8') as f:stopwords = f.read().split(sep='\n')for seg_list in cut_words:for seg in seg_list:if seg in stopwords:seg_list.remove(seg)  # 删除print(cut_words[0])print('去低频词')min_threshold = 20word_list = []for seg_list in cut_words:word_list.extend(seg_list)counter = Counter(word_list)delete_list = []  # 要去除的停用词for k, v in counter.items():if v < min_threshold:delete_list.append(k)print(f'要去除掉低频词数量:{len(delete_list)}')for seg_list in tqdm(cut_words):for seg in seg_list:if seg in delete_list:seg_list.remove(seg)print(cut_words[0])print('序列化列表')with open('../data/cut_words.pkl', 'wb') as f:pickle.dump(cut_words, f)t2 = time.time()print(f'共耗时{t2 - t1}秒')

3 交叉验证&特征提取

思路:

  1. StratifiedKFold抽样(10份)
  2. 定义accuracy_list等,保存模型评价指标
  3. 特征提取(TfidfTransforme+CountVectorizer)
  4. 模型fit、predict、accuracy
  5. 打印模型accuracy等
# 交叉验证(skf+10)
def cross_validate(model, X, y):t1 = time.time()skf = StratifiedKFold(n_splits=10)accuracy_list, precision_list, recall_list = [], [], []for i, (train_id, test_id) in enumerate(skf.split(X, y)):print(f'{i + 1}/10')X_train, X_val, y_train, y_val = X[train_id], X[test_id], y[train_id], y[test_id]X_train = X_train.ravel()X_val = X_val.ravel()X_train = [" ".join(i) for i in X_train]X_val = [" ".join(i) for i in X_val]# 特征提取transformer = TfidfTransformer()vectorizer = CountVectorizer(analyzer='word')train_tfidf = transformer.fit_transform(vectorizer.fit_transform(X_train)).toarray()val_tfidf = transformer.transform(vectorizer.transform(X_val)).toarray()#         train_tfidf = train_tfidf.astype(np.float32)#         val_tfidf = val_tfidf.astype(np.float32)print(train_tfidf.shape)  # (48024, 26954)# 模型训练print('模型训练')model.fit(train_tfidf, y_train)y_predict = model.predict(val_tfidf)accuracy = accuracy_score(y_val, y_predict)precision = precision_score(y_val, y_predict, average='micro')recall = recall_score(y_val, y_predict, average='micro')#         f1 = f1_score(y_val, y_predict, average='micro')accuracy_list.append(accuracy)precision_list.append(precision)recall_list.append(recall)print(accuracy)print(precision)print(recall)t2 = time.time()print(f"Acc:{np.mean(accuracy_list)} | Pre:{np.mean(precision_list)} | Rec:{np.mean(recall_list)}")print(f'耗时{t2 - t1}秒')

4 模型训练

思路:

  1. 加载序列
  2. 构建模型列表
  3. 交叉验证
def machine_learning():# 加载序列with open('../data/cut_words.pkl', 'rb') as f:cut_words = pickle.load(f)train_df['sentence_processed'] = pd.Series(cut_words)X = train_df['sentence_processed']y = train_df['label']X = np.array(X).reshape(-1, 1)y = np.array(y).reshape(-1, 1)y = LabelEncoder().fit_transform(y)# 交叉验证# models = [XGBClassifier(gpu_id=0, tree_method='gpu_hist')]models = [LogisticRegression(max_iter=200), RidgeClassifier(), GaussianNB(), BernoulliNB(), MultinomialNB(), LGBMClassifier()]model_names = ['逻辑回归', '岭回归', '高斯模型', '伯努利模型', '多项式模型', 'LGB']for model, model_name in zip(models, model_names):print(f'Start {model_name}')cross_validate(model, X, y)if __name__ == '__main__':preprocessing(train_df, 'sentence')machine_learning()

5 评估与总结

(1) StratifiedKFold()的split()函数要求同时传入X和y,且对X和y有要求:(特征数,样本数),所以提前将X和y转化为ndarray格式,并reshape(-1, 1)

(2)CounterVectorizer()的fit_transform()函数最好传入[str1, str2, str3]格式的list,否则将报错
(3)model.fit()报错,百度是不能传稀疏矩阵之类的,将之转化为ndarray解决了
(4)样本量大,tfidf生成的特征也多,使用PCA降维后准确率反而降低,速度也降低了,应该是没有做特征选择
(5)以下分别是逻辑回归、岭回归分类、高斯模型、伯努利模型、多项式模型、LightGBM的控制台输出,数据有15个类,50000多条数据,选取了20000条数据进行训练,没有做进一步做特征工程,也没有调参,精度最高的LR回归ACC值只有43.6%

Acc:0.43665000000000004 | Pre:0.43665000000000004 | Rec:0.43665000000000004
耗时1418.719212770462秒Acc:0.4325 | Pre:0.4325 | Rec:0.4325
耗时435.4280045032501秒Acc:0.25575000000000003 | Pre:0.25575000000000003 | Rec:0.25575000000000003
耗时64.14511179924011秒Acc:0.30455 | Pre:0.30455 | Rec:0.30455
耗时31.419057846069336秒Acc:0.40785 | Pre:0.40785 | Rec:0.40785
耗时17.39114022254944秒Acc:0.3839 | Pre:0.3839 | Rec:0.3839
耗时42.139591455459595秒

(6)之前测试用10000条数据LR回归ACC值只有38%,而提高到20000有5%的提升,提高样本量应该能提升精度
(7)笔记本性能有限,后续考虑使用linux服务器进行训练,考虑使用pycuda或cupy加速,以及RNN、LSTM等神经网络模型

数据挖掘实战(6)——机器学习实现文本分类(今日头条tnews数据集)相关推荐

  1. 自然语言处理入门实战1:基于机器学习的文本分类

    基于机器学习的文本分类 配置文件 数据集 数据预处理 model 模型 主函数 预测 结果 参考代码 本文参考复旦大学自然语言处理入门练习,主要是实现用tensorflow实现基于logistic/s ...

  2. NLP(新闻文本分类)——基于机器学习的文本分类

    文本表示方法 在机器学习算法的训练过程中,假设给定NNN个样本,每个样本有MMM个特征,这样组成了N×MN×MN×M的样本矩阵,然后完成算法的训练和预测.同样的在计算机视觉中可以将图片的像素看作特征, ...

  3. 基于机器学习的文本分类

    基于机器学习的文本分类 机器学习模型 文本表示方法 Part1 One-hot Bag of Words N-gram TF-IDF 基于机器学习的文本分类 Count Vectors + Ridge ...

  4. 基于机器学习的文本分类!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:李露,西北工业大学,Datawhale优秀学习者 据不完全统计,网 ...

  5. Task03——零基础入门NLP - 基于机器学习的文本分类

    学习目标 学会TF-IDF使用原理 使用sklearn的机器学习模型完成文本分类 文本表示方法 one-hot bag of words N-grams TF-IDF 基于机器学习的文本分类代码

  6. 基于统计概率和机器学习的文本分类技术

    基于统计概率和机器学习的文本分类技术 -- 社区产品机器审核机制 一.现状 目前,所在公司社区类产品(论坛.博客.百科)每天都会接收到大量的垃圾.灌水信息,高峰期16小时内(晚6点以后到第二天9点前) ...

  7. 【NLP】基于机器学习的文本分类!

    作者:李露,西北工业大学,Datawhale优秀学习者 据不完全统计,网民们平均每人每周收到的垃圾邮件高达10封左右.垃圾邮件浪费网络资源的同时,还消耗了我们大量的时间.大家对此深恶痛绝,于是识别垃圾 ...

  8. NLP-Beginner:自然语言处理入门练习----task 1基于机器学习的文本分类

    任务一:基于机器学习的文本分类 任务传送门 项目是在github上的,数据集需要在kaggle上下载,稍微有些麻烦. wang盘:http://链接:https://pan.baidu.com/s/1 ...

  9. NLP-Task1:基于机器学习的文本分类

    NLP-Task1:基于机器学习的文本分类 实现基于logistic/softmax regression的文本分类 数据集:Classify the sentiment of sentences f ...

最新文章

  1. cygwin和mingw编译软件的疑问
  2. BottomNavigationView+ViewPager+Fragment仿微信底部导航栏
  3. Centos-检查文件系统并尝试修复-fsck
  4. JedisPool.getResource()方法卡死的解决办法
  5. memset()详解
  6. LAMP_PHP配置
  7. CSS元素隐藏原理和效果小结
  8. mysql servlet登录验证_使用Servlet和jdbc创建用户登录验证
  9. 阻塞和非阻塞(串口自环测试失败原因定位)
  10. NumberFormat和DecimalFormat
  11. 活着,要有温暖的感觉
  12. 2021-07-08解决大部分lanzous蓝奏云链接无法打开
  13. 首届Starcoin Move黑客松源码分析——Atlaspad
  14. 【组合数学】递推方程 ( 常系数线性齐次递推方程 | 常系数、线性、齐次 概念说明 | 常系数线性齐次递推方程公式解法 | 特征根 | 通解 | 特解 )
  15. 基于SpringBoot的毕业论文管理系统的设计与实现(开题报告)
  16. 什么是云服务器,云服务器有哪些优势和特点?
  17. Eclipse 字体、字号的设置、最佳字体推荐
  18. P2P的资金托管方式 参考
  19. 电脑如何长截屏截图_您的意见很重要-截屏技术调查
  20. W5500开发笔记 | 02 - 使用W5500 Socket API 建立TCP服务端、TCP客户端

热门文章

  1. 2017年AI技术盘点:关键进展与趋势
  2. 手机关机不拔电池也能被定位吗?
  3. windows 环境下node开发环境搭配问题
  4. 转载】泡MM与GOF的23种模式(看着挺有意思)
  5. linux下mv命令参数详解,linux下的mv命令使用详解
  6. Unity3D基本入门及功能介绍
  7. 机器学习:李航-统计学习方法-代码实现
  8. win7创建任务计划:自动关机命令
  9. 双色球大乐透开奖查询软件
  10. C++ STL 中大根堆,小根堆的应用。