一、数据集介绍
中文文本分类数据集
数据来源:
今日头条客户端
数据格式:

6554695793956094477_!_110_!_news_military_!_「欧洲第一陆军」法兰西帝国的欧陆霸权_!_查理八世,布列塔尼,卡佩王朝,佛兰德斯,法国
6554855520291783175_!_110_!_news_military_!_以色列为巷战而研发的重型装甲运兵车,美军也租一辆进行作战测试_!_装甲运兵车,重型步兵战车,步兵战车,以色列,雌虎,M113,T-55
6525155156756005383_!_116_!_news_game_!_植物大战僵尸 僵尸治愈之旅 原来僵尸也会玩螳螂黄雀之计_!_黄雀之计,植物大战僵尸
6525662251456659971_!_116_!_news_game_!_我的世界vs火柴人番外篇:小橙闯入了MC世界_!_我的世界,番外篇,火柴人
6554666063026455044_!_116_!_news_game_!_贪吃蛇大作战:主宰全场,绿队最后的反超,我还是最佳MVP_!_贪吃蛇
6539508345470976515_!_116_!_news_game_!_古普象棋:铁滑车对战双正马,单兵入花心,荆轲刺秦王_!_古普象棋,铁滑车,花心,荆轲刺秦王
6545948385956856323_!_116_!_news_game_!_CF生存特训:“MK5-2”决赛圈,速度奔袭,敌人根本反应不过来!_!_决赛圈,小粉
6550234982705529358_!_116_!_news_game_!_LOL东北大鹌鹑:后期拉克丝大招近乎0CD,高爆发高伤害,恐怖!_!_大招,拉克丝

每行为一条数据,以_!_分割的个字段,从前往后分别是 新闻ID,分类code(见下文),分类名称(见下文),新闻字符串(仅含标题),新闻关键词

分类code与名称:

100 民生 故事 news_story
101 文化 文化 news_culture
102 娱乐 娱乐 news_entertainment
103 体育 体育 news_sports
104 财经 财经 news_finance
106 房产 房产 news_house
107 汽车 汽车 news_car
108 教育 教育 news_edu
109 科技 科技 news_tech
110 军事 军事 news_military
112 旅游 旅游 news_travel
113 国际 国际 news_world
114 证券 股票 stock
115 农业 三农 news_agriculture
116 电竞 游戏 news_game

数据规模:
共382688条,分布于15个分类中。
实验结果:
以80%、10%、10%做分割。

处理后的数据样式:

有哪些偏冷门的歌曲推荐? 0
“整容狂人”的审美,恕欣赏不来 0
吴卓林:你父母固然有责任,但最大的责任还是在于你自己! 0
《天乩战之白蛇传说》赵雅芝和杨紫饰演母女会擦出什么样的火花? 0
他是最帅反派专业户,演《古惑仔》大火,今病魔缠身可怜无人识! 0
如果今年勇士夺冠,下赛季詹姆斯何去何从? 1
超级替补!科斯塔本赛季替补出场贡献7次助攻 1
骑士6天里发生了啥?从首轮抢七到次轮3-0猛龙 1
如果朗多进入转会市场,哪些球队适合他? 1
詹姆斯G3决杀,你怎么看? 1

导入所需要的包:

import  time
import torch
import numpy as np
import warnings
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertModel, BertConfig, BertTokenizer, AdamW, get_cosine_schedule_with_warmup
warnings.filterwarnings('ignore')

超参数配置:

bert_path = "bert_model/s"    # 该文件夹下存放三个文件('vocab.txt', 'pytorch_model.bin', 'config.json')
tokenizer = BertTokenizer.from_pretrained(bert_path)   # 初始化分词器input_ids, input_masks, input_types, = [], [], []  # input char ids, segment type ids,  attention mask
labels = []  # 标签
maxlen = 128
EPOCHS = 300
BATCH_SIZE = 128  # 如果会出现OOM问题,减小它

数据处理部分:

with open("new_text.txt", 'r', encoding='utf-8') as f:for i in f:title, y = i.replace('\n', '').split(' ')[0], i.replace('\n', '').split(' ')[1]# encode_plus会输出一个字典,分别为'input_ids', 'token_type_ids', 'attention_mask'对应的编码# 根据参数会短则补齐,长则切断encode_dict = tokenizer.encode_plus(text=title, max_length=maxlen,padding='max_length', truncation=True)input_ids.append(encode_dict['input_ids'])input_types.append(encode_dict['token_type_ids'])input_masks.append(encode_dict['attention_mask'])labels.append(int(y))input_ids, input_types, input_masks = np.array(input_ids), np.array(input_types), np.array(input_masks)
labels = np.array(labels)
print(input_ids.shape, input_types.shape, input_masks.shape, labels.shape)# 随机打乱索引
idxes = np.arange(input_ids.shape[0])
np.random.seed(2019)   # 固定种子
np.random.shuffle(idxes)
print(idxes.shape, idxes[:10])# 8:1:1 划分训练集、验证集、测试集
input_ids_train, input_ids_valid, input_ids_test = input_ids[idxes[:186959]], input_ids[idxes[186959:210329]], input_ids[idxes[210329:]]
input_masks_train, input_masks_valid, input_masks_test = input_masks[idxes[:186959]], input_masks[idxes[186959:210329]], input_masks[idxes[210329:]]
input_types_train, input_types_valid, input_types_test = input_types[idxes[:186959]], input_types[idxes[186959:210329]], input_types[idxes[210329:]]y_train, y_valid, y_test = labels[idxes[:186959]], labels[idxes[186959:210329]], labels[idxes[210329:]]print(input_ids_train.shape,y_train.shape,input_ids_valid.shape,y_valid.shape,input_ids_test.shape,y_test.shape
)# 训练集
train_data = TensorDataset(torch.LongTensor(input_ids_train),torch.LongTensor(input_masks_train),torch.LongTensor(input_types_train),torch.LongTensor(y_train))
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)# 验证集
valid_data = TensorDataset(torch.LongTensor(input_ids_valid),torch.LongTensor(input_masks_valid),torch.LongTensor(input_types_valid),torch.LongTensor(y_valid))
valid_sampler = SequentialSampler(valid_data)
valid_loader = DataLoader(valid_data, sampler=valid_sampler, batch_size=BATCH_SIZE)# 测试集
test_data = TensorDataset(torch.LongTensor(input_ids_test),torch.LongTensor(input_masks_test),torch.LongTensor(input_types_test))
test_sampler = SequentialSampler(test_data)
test_loader = DataLoader(test_data, sampler=test_sampler, batch_size=BATCH_SIZE)

定义Model:

# 定义model
class Bert_Model(nn.Module):def __init__(self, bert_path, classes=10):super(Bert_Model, self).__init__()self.config = BertConfig.from_pretrained(bert_path)  # 导入模型超参数self.bert = BertModel.from_pretrained(bert_path)  # 加载预训练模型权重self.fc = nn.Linear(self.config.hidden_size, classes)  # 直接分类def forward(self, input_ids, attention_mask=None, token_type_ids=None):outputs = self.bert(input_ids, attention_mask, token_type_ids)out_pool = outputs[1]  # 池化后的输出 logit = self.fc(out_pool)  return logit

优化器:

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)
model = Bert_Model(bert_path).to(DEVICE)optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4) #AdamW优化器
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=len(train_loader),num_training_steps=EPOCHS*len(train_loader))

训练与评估模型:

def train_and_eval(model, train_loader, valid_loader,optimizer, scheduler, device, epoch):best_acc = 0.0criterion = nn.CrossEntropyLoss()for i in range(epoch):"""训练模型"""start = time.time()model.train()print("****** Running training epoch {} ******".format(i + 1))train_loss_sum = 0.0for idx, (ids, att, tpe, y) in enumerate(train_loader):ids, att, tpe, y = ids.to(device), att.to(device), tpe.to(device), y.to(device)y_pred = model(ids, att, tpe)loss = criterion(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()  # 学习率变化train_loss_sum += loss.item()if (idx + 1) % (len(train_loader) // 5) == 0:  # 只打印五次结果print("Epoch {:04d} | Step {:04d}/{:04d} | Loss {:.4f} | Time {:.4f}".format(i + 1, idx + 1, len(train_loader), train_loss_sum / (idx + 1), time.time() - start))"""验证模型"""model.eval()acc = evaluate(model, valid_loader, device)  # 验证模型的性能## 保存最优模型if acc > best_acc:best_acc = acctorch.save(model.state_dict(), "best_model.pth")print("current acc is {:.4f}, best acc is {:.4f}".format(acc, best_acc))print("time costed = {}s \n".format(round(time.time() - start, 5)))

评估模型性能:

def evaluate(model, data_loader, device):model.eval()val_true, val_pred = [], []with torch.no_grad():for idx, (ids, att, tpe, y) in (enumerate(data_loader)):y_pred = model(ids.to(device), att.to(device), tpe.to(device))y_pred = torch.argmax(y_pred, dim=1).detach().cpu().numpy().tolist()val_pred.extend(y_pred)val_true.extend(y.squeeze().cpu().numpy().tolist())return accuracy_score(val_true, val_pred)  # 返回accuracy

预测:

def predict(model, data_loader, device):model.eval()val_pred = []with torch.no_grad():for idx, (ids, att, tpe) in tqdm(enumerate(data_loader)):y_pred = model(ids.to(device), att.to(device), tpe.to(device))y_pred = torch.argmax(y_pred, dim=1).detach().cpu().numpy().tolist()val_pred.extend(y_pred)return val_pred
# 训练和评估
train_and_eval(model, train_loader, valid_loader, optimizer, scheduler, DEVICE, EPOCHS)# 加载最优权重
model.load_state_dict(torch.load("best_model.pth"))
pred_test = predict(model, test_loader, DEVICE)
print("\n Test Accuracy = {} \n".format(accuracy_score(y_test, pred_test)))
print(classification_report(y_test, pred_test, digits=4))

训练过程:

预测结果:

数据集下载:
链接:https://pan.baidu.com/s/1JrYI6mEp0DFtDyYxDgHrow
提取码:p9yh

人工智能NLP自然语言之基础篇文本分类pytorch-transformers实现BERT文本分类bert相关推荐

  1. 自然语言处理——基础篇01

    自然语言处理--基础篇01 一.什么是自然语言处理? 二.自然语言处理的难点与特点? 三.语言模型 四.NLP的常见任务类型 1. 中文分词 2. 子词切分(Subword) 3. 句法分析 4. 语 ...

  2. [NLP] 自然语言处理基础任务介绍

    概述 本文主要介绍NLP相关基础任务是什么,而不是怎么做.自然语言处理的一大特点是任务种类纷繁复杂,有多种划分的方式.从处理顺序的角度,可以分为底层的基础任务以及上层的应用任务.基础任务输出的结果往往 ...

  3. 个人用户永久免费,可自动升级版Excel插件,使用VSTO开发,Excel催化剂功能第15波-接入AI人工智能NLP自然语言处理...

    上回提到现在是概念化时代,马云爸爸们天天演讲各样的概念,IT世界也在讲ABC时代(A-AI人工智能,B-BigData大数据,C-Cloud Computing云计算),在2017年,大把大佬们都大谈 ...

  4. 大数据 - 文本文件数据提取工具之一 基础篇常见文本格式

    基础篇如何正确的拆分常见的文本格式, 什么样的字符能做拆分符号,理论上所有的字符都可以作为拆分符号用来拼接多列数据, 在拆分列数据的时候,数据里面不能再有这个拆分符号一样的字符串,否则数据就无法分开了 ...

  5. 【自然语言处理基础技能(NLP)】jieba中文文本处理

    1.基本分词函数与用法 # jieba.cut 以及 jieba.cut_for_search 返回的结构都是一个可迭代的 generator,可以使用 for 循环来获得分词后得到的每一个词语(un ...

  6. NLP自然语言处理—文本分类入门

    前言 NLP作为机器学习三大热门领域之一,现在的发展也是越来越完备,从2012年神经网络崛起之后,自然语言领域就迎来了春天,特别是当预训练方法横空出世之后,NLP作为最先尝到预训练甜头的先锋,可以说是 ...

  7. 自然语言处理(NLP): 12 BERT文本分类

    文章目录 BERT介绍 BERT 论文阅读 BERT用做特征提取 BERT 源码分析 BERT升级版 RoBERTa:更强大的BERT ALBERT:参数更少的BERT DistilBERT:轻量版B ...

  8. 6.自然语言处理学习笔记:Multi-head-self-attention 和Transformer基础知识 和BERT文本分类原理

    Multi-head-self-attention: 可以更细致的去发现局部信息. Transformer:   BERT文本分类原理:  

  9. NLP学习(一)基础篇

    一. 前言 2016年3月9日至15日和2017年5月23日至27日,分别在韩国首尔和中国嘉兴乌镇,韩国围棋九段棋手李世石.中国围棋九段棋手柯洁与人工智能围棋程序"阿尔法围棋"(A ...

最新文章

  1. MOS2010开发基础和集几种开发模型
  2. 建立索引和主外约束_Mysql索引原理
  3. 2- 计算机的组成,VMware使用
  4. 安卓入门系列-05常见布局之RelaiveLayout(相对布局)
  5. 华为徐直军:以持续创新加快数字化发展
  6. 【debug】json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)
  7. 批量杀死MySQL连接
  8. 什么是中药药浴?中药药浴的操作方法和注意事项
  9. Nat. Med. :婴儿生命早期肠道病毒组和细菌组的动态
  10. 手把手的教你安装PyCharm --Pycharm安装详细教程(一)(非常详细,非常....)
  11. centos7 修改 max locked memory
  12. iOS7新特性的兼容性处理方法
  13. python3 联合概率,边缘概率,贝叶斯定理(含详细推导)
  14. 4-AMBA VIP 编程接口
  15. OpenGL学习——计算机图形学作业:简单的室内场景
  16. 删库不跑路大法,真的好
  17. Memblaze发布PBlaze 4系列PCIe SSD新品 全面拥抱 NVMe
  18. C语言之计算log2
  19. android studio秘钥库文件不存在,[原]Android Studio查询SHA1的方法
  20. Android自定义-⭐️画布认识⭐️

热门文章

  1. java阿基米德螺线_JavaScript图形实例:阿基米德螺线
  2. JavaScript进阶学习-webAPI(总结)
  3. Mybatis-Mapper.xml输入输出映射
  4. 产品运营数据分析的指标有哪些?
  5. Google Glass开发初步体验
  6. 【多式联运】基于帝国企鹅算法、遗传算法、粒子群算法求解多式联运路径优化问题附matlab代码
  7. 网上订房系统的设计实现
  8. 清华大学交叉信息研究院和计算机系,清华大学交叉信息研究院和量子信息中心揭牌成立...
  9. GoEasy在Spring Boot的简单使用
  10. 软件测试人的“满分”面试回答和简历是什么样子的?快来看看