NLP学习之Bert文本分类

  • 行业识别——基于Bert
    • 项目介绍
      • 数据集:
      • 数据迭代器:
      • 项目结构:
    • 总入口:
    • 模型搭建和配置
      • 配置类: config
      • 模型搭建:model
    • 数据预处理:
      • 数据预处理
      • 数据迭代器构建
    • 构建训练流程
      • 训练
      • 验证
      • 测试
    • 预测

行业识别——基于Bert

项目介绍

数据集:

本项目使用的是THUCNews的一个子集,每条数据都是从新闻中抽取的标题,属于标题(短文本)分类。

文本长度在20到30之间。一共10个类别,每类2万条。数据以字为单位输入模型。

类别:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐。

数据集 数据量
训练集 18万
验证集 1万
测试集 1万

数据迭代器:

在进行训练的时候,读取数据有两种方式。

一种是提前把数据预处理好,保存为文件,训练的时候读取文件进行训练。

另一种是构建数据迭代器,预处理和训练同时进行。

优点:当数据量大的时候,一次只会加载1个batch的数据到显存中,有效防止了显存溢出。

项目结构:

│ predict.py 预测代码
│ run.py 总入口
│ train_eval.py 训练、验证、测试代码
│ utils.py 数据预处理

├─bert_pretrain
│ bert_config.json 超参数配置文件
│ pytorch_model.bin 预训练参数文件
│ vocab.txt 词表文件

├─models
│ bert.py 模型定义及超参数定义

└─THUCNews
├─data
│ class.txt 类别
│ dev.txt 验证集
│ test.txt 测试集
│ train.txt 验证集

└─saved_dict
bert.ckpt 训练模型保存

总入口:

parser = argparse.ArgumentParser(description="chinese text classification")
parser.add_argument('--model',type=str,required=True,help="choose model")
args = parser.parse_args()if __name__ == "__main__":dataset = 'THUCNews' #数据集model_name = args.model #模型名字x = import_module('models.' + model_name) #根据模型名字,获取models包下的文件config = x.Config(dataset) #模型配置类#设置随机种子,np,cpu,gpu,固定卷积层算法np.random.seed(1)torch.manual_seed(1)torch.cuda.manual_seed(1)torch.backends.cudnn.deterministic = Truestat_time = time.time()print("Loading data...")#数据预处理train_data,dev_data,test_data = build_dataset(config)#构建训练集、验证集、测试集迭代器train_iter = build_iterator(train_data,config)dev_iter = build_iterator(dev_data,config)test_iter = build_iterator(test_data,config)time_dif = get_time_dif(stat_time)print("Time usage:",time_dif)#构建模型对象,to_devicemodel = x.Model(config).to(config.device)train(config,model,train_iter,dev_iter,test_iter)

模型搭建和配置

配置类: config

class Config(object):def  __init__(self,dataset):self.model_name = 'bert'#训练集、验证集、测试集self.train_path = dataset + 'data/train.txt'self.dev_path = dataset + 'data/dev.txt'self.test_path = dataset + 'data/test.txt'#类别self.class_list = [x.strip() for x in open(dataset + 'data/class.txt').readlines()]#模型保存位置self.save_path = dataset + 'saved_dict' + self.model_name + '.ckpt'#设置训练使用cpu、gpuself.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#设置xx个batch没有改变则提前结束self.require_improvement = 1000#类别数目self.num_classes = len(self.class_list)#迭代次数self.epoches = 3#设置batch_sizeself.batch_size = 128#设置句子长度self.pad_size = 32#设置学习率self.learning_rate = 5e-5#预训练模型相关文件:1.模型文件.bin 2.配置文件.json 3.词表文件vocab.txtself.bert_path = './bert_pretrain'#序列划分工具self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)#隐藏层数量self.hidden_size = 768

模型搭建:model

class model(nn.Module):def __init__(self,config):super(Module,self).__init__#加载模型self.bert = BertModel.from_pretrained(config.bert_path)#微调for param in self.bert.parameters():param.requires_grad = True #训练时进行梯度更新#输出,自定义全连接层self.fc = nn.Linear(config.hidden_size,config.num_classes)def forward(self,x):context = x[0]mask = x[2]_,pooled = self.bert(context,mask,output_all_encoded_layers=False)#是否将bert中每层(12)都输出,false只输出最后一层(128,768)out = self.fc(pooled)return out

数据预处理:

数据预处理

PAD, CLS = '[PAD]', '[CLS]'  #pad:占位符,input长度相同   #cls:放在句子首位,用于分类任务def build_dataset(config):def load_dataset(path,pad_size=32):  #根据pad_size进行补全或者截断contents = []with open(path,'r',encoding="utf-8") as f:for line in tqdm(f):lin = line.strip()#去除首尾空格和换行if not lin:continuecontent, label = lin.split("\t")#根据tab键进行分割token = config.tokenizer.tokenize(content) #分字,bert内置的token = [CLS] + token #头部加入seq_len = len(token)mask = []  #区分填充部分和非填充部分token_ids = config.tokenize.convert_tokens_to_ids(token) #基于词表文件,将token转换为索引#长截短补if pad_size:if len(token) < pad_size:mask = [1] * len(token) + [0] * (pad_size - len(token))token_ids += [0]*(pad_size - len(token))else:mask = [1] * pad_sizetoken_ids = token_ids[:pad_size]seq_len = pad_sizecontents.append(token_ids,int(label),seq_len,mask)return contentstrain = load_dataset(config.train_path,config.pad_size)dev = load_dataset(config.dev_path,config.pad_size)test = load_dataset(config.test_path,config.pad_size)return train, dev,test

数据迭代器构建

class DatasetIterater(object):def __init__(self, batches, batch_size, device):self.batch_size = batch_sizeself.batches = batchesself.n_batches = len(batches)//batch_sizeself.residue = Falseif len(batches) % self.n_batches != 0: #不是整数batchself.residue = Trueself.device = deviceself.index = 0def  _to_tensor(self, datas):#将索引,标签,长度,mask转换为tensor类型x = torch.LongTensor(_[0] for _ in datas).to(self.device)y = torch.LongTensor(_[1] for _ in datas).to(self.device)seq_len = torch.LongTensor(_[2] for _ in datas).to(self.device)mask = torch.LongTensor(_[3] for _ in datas).to(self.device)return (x,seq_len,mask),ydef __next__(self):if self.residue and self.index == self.n_batches:batches = self.batches[self.index*self.index:len(self.batches)]self.index += 1batches = self._to_tensor(batches)return batcheselif self.index>self.n_batches:self.index = 0raise StopIterationelse:batches = self.batch_size[self.index*self.batch_size:(self.index+1)*self.batch_size]self.index += 1batches = self._to_tensor(batches)return batchesdef __iter__(self):return selfdef __len__(self):if self.residue:return self.n_batches + 1else:return self.n_batches
def build_iterator(dataset, config):iter = DatasetIterater(dataset,config.batch_size,config.device)return iter

构建训练流程

训练

def train(config,model,train_iter,dev_iter,test_iter):start_time =time.time()#开启训练模式model.train()#参数param_optimizer = list(model.named_parameters()) #无需更新的参数no_decay = ['bias','LayerNorm.bias','LayerNorm.weight']#设置哪些参数需要更新,哪些不需要optimizer_grouped_parameters = [{'params':[p for n,p in param_optimizer if not any(nd in n for nd in no_decay)],'weight_decay':0.01},{'params':[p for n,p in param_optimizer if any(nd in n for nd in no_decay)],'weight_decay':0.0}]#优化器搭建optimizer = BertAdam(optimizer_grouped_parameters,lr=config.learning_rate,warmup=0.05,t_total=len(train_iter)*config.num_epochs)#记录总batchtotal_batch = 0#记录验证集的损失值dev_best_loss = float('inf')#记录上次改变的batch数last_improve = 0#训练结束标识flag = Falsemodel.train()for epoch in range(config.num_epochs):print('Epoch [{}/{}]'.format(epoch+1,config.num_epochs))for i,(trains,labels) in enumerate(train_iter):#前向传播,获取输出outputs = model(trains)#梯度置零,否则会梯度累加model.zero_grad()#交叉熵损失loss = F.cross_entropy(outputs,labels)#计算梯度loss.backward()#反向传播更新参数optimizer.step()#每一百个batch进行输出结果if total_batch % 100 == 0:#数据迁移到cpu上进行预测true = labels.data.cpu()predict = torch.max(outputs.data,1)[1].cpu()#分类指标的文本报告:1.精确率 2.召回率 3.F1 scoretrain_acc = metrics.accuracy_score(true,predict)#验证集准确率和损失dev_acc,dev_loss = evaluate(config,model,dev_iter)#损失值降低,保存模型if dev_loss < dev_best_loss:dev_best_loss =dev_losstorch.save(model.state_dict(),config.save_path)improve = '*'last_improve = total_batchelse:improve = ''time_dif = get_time_dif(start_time)msg = "Iter: {0:>6}, Train loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}"print(msg.format(total_batch,loss.item(),train_acc,dev_loss,dev_acc,time_dif,improve))model.train()total_batch += 1#早停策略if total_batch - last_improve > config.require_improvement:print("Early stopping")flag = Truebreakif flag:breaktest(config,model,test_iter)

验证

def evaluate(config,model,data_iter,test=False):model.eval()loss_tatal = 0predict_tatal = np.array([],dtype=int)label_tatal = np.array([],dtype=int)with  torch.no_grad:for texts,labels in data_iter:output = model(texts)loss = F.cross_entropy(output,labels)loss_tatal += losslabels = labels.data.cpu().numpy()predict = torch.max(output.data,1)[1].cpu().numpy()label_tatal = np.append(label_tatal,labels)predict_tatal = np.append(predict_tatal,predict)acc = metrics.accuracy_score(label_tatal,predict_tatal)if test:report = metrics.classification_report(predict_tatal,label_tatal,target_names=config.class_list,digits=4)confution = metrics.confusion_matrix(predict_tatal,label_tatal)return acc,loss/len(data_iter),report,confutionreturn acc,loss/len(data_iter)

测试

def test(model,config,test_iter):model.load_state_dict(torch.load(config.save_path))model.eval()start_time = time.time()test_acc,test_loss,test_report,test_confution = evaluate(config,model,test_iter,test=True)msg = "Test loss:{0:>5.2},Test acc:{1:>6.2%}"print(msg.format(test_acc,test_loss))print("test_report")print(test_report)print("confution")print(test_confution)time_dif =get_time_dif(start_time)print("use time",time_dif)

预测

import torch
from importlib import import_module
import os
key = {0: '金融',1: '房产',2: '股票',3: '教育',4: '科技',5: '社会',6: '政治',7: '体育',8: '游戏',9: '娱乐'
}
cru = os.path.dirname(__file__)
path = os.path.join(cru,'THUCNews')
model_name = 'bert'
x = import_module('bert_demo.models.' + model_name)
config = x.Config(path)
model = x.Model(config).to("cpu")
model.load_state_dict(torch.load(config.save_path, map_location='cpu'))def build_predict_text(text):token = config.tokenizer.tokenize(text)token = ['[CLS]'] + tokenseq_len = len(token)mask = []token_ids = config.tokenizer.convert_tokens_to_ids(token)pad_size = config.pad_sizeif pad_size:if len(token) < pad_size:mask = [1] * len(token_ids) + ([0] * (pad_size - len(token)))token_ids += ([0] * (pad_size - len(token)))else:mask = [1] * pad_sizetoken_ids = token_ids[:pad_size]seq_len = pad_sizeids = torch.LongTensor([token_ids])seq_len = torch.LongTensor([seq_len])mask = torch.LongTensor([mask])return ids, seq_len, maskdef predict(text):data = build_predict_text(text)with torch.no_grad():outputs = model(data)num = torch.argmax(outputs)return key[int(num)]if __name__ == '__main__':while True:print(predict("福建省政务云平台基础设施运维服务项25555年招标公告"))

基于Bert文本分类进行行业识别相关推荐

  1. 【项目调研+论文阅读】基于BERT的中文命名实体识别方法[J] | day6

    <基于BERT的中文命名实体识别方法>王子牛 2019-<计算机科学> 文章目录 一.相关工作 二.具体步骤 1.Bi-LSTM 2.CRF结构 三.相关实验 1.数据集 2. ...

  2. pytorch bert文本分类_一起读Bert文本分类代码 (pytorch篇 四)

    Bert是去年google发布的新模型,打破了11项纪录,关于模型基础部分就不在这篇文章里多说了.这次想和大家一起读的是huggingface的pytorch-pretrained-BERT代码exa ...

  3. Transformer课程 第7课Gavin大咖 BERT文本分类-BERT Fine-Tuning

    Transformer课程 第7课Gavin大咖 BERT文本分类-BERT Fine-Tuning Part III - BERT Fine-Tuning 4. Train Our Classifi ...

  4. 6.自然语言处理学习笔记:Multi-head-self-attention 和Transformer基础知识 和BERT文本分类原理

    Multi-head-self-attention: 可以更细致的去发现局部信息. Transformer:   BERT文本分类原理:  

  5. 手把手教你搭建Bert文本分类模型,快点看过来吧!

    1 赛题名称 基于文本挖掘的企业隐患排查质量分析模型 2 赛题背景 企业自主填报安全生产隐患,对于将风险消除在事故萌芽阶段具有重要意义.企业在填报隐患时,往往存在不认真填报的情况,"虚报.假 ...

  6. BERT文本分类实战

    一.简介 在开始使用之前,我们先简单介绍一下到底什么是BERT,大家也可以去BERT的github上进行详细的了解.在CV问题中,目前已经有了很多成熟的预训练模型供大家使用,我们只需要修改结尾的FC层 ...

  7. 实体识别(4) -基于Bert进行商品标题实体识别[很详细]

    基于Bert进行实体识别任务微调 致Great,ChallengeHub公众号,微信:1185918903,备注NLP技术交流 和鲸主页:https://www.heywhale.com/home/u ...

  8. 基于BERT的英文query实体识别模型

    导读 本文主要从三个方面来介绍下本人2021年的上半年的工作,优化跨境电商的搜索相关性中的一块工作:搜索query实体理解,因为只有做好query实体理解这第一步,才能继续做很多后面相关性的事情.比如 ...

  9. 基于bert的分类笔记

    文章目录 一.基于prompt的文本分类 二.什么是样本不均衡问题 三.样本不均衡会导致什么问题 三.如何解决样本不均衡问题 四.基于bert的文本分类模型是咋做的 五.bert模型中的[CLS].[ ...

最新文章

  1. 转换前台javascript传递过来的时间字符串到.net的DateTime
  2. ML基石_11_HazardOfOverfitting
  3. SSH-KeyGen 的用法 【转载】
  4. 告别2019,写给2020:干好技术,要把握好时光里的每一步
  5. mysql如何定位到数据_如何快速定位当前数据库消耗CPU最高的sql语句?
  6. 将所有文件从目录复制到Python中的另一个目录
  7. linux syslog日志
  8. GNU make manual 翻译( 一百五十五)
  9. 【MySQL】MySQL 8 PROCEDURE ANALYSE命令使用
  10. MuiPlayer视频播放组件入门
  11. mysql开启远程可连接
  12. 图解HTTPS协议加密解密全过程
  13. python100例图案_python100例 21-30
  14. 走迷宫(三):在XX限制条件下,是否走得出。
  15. Calendar类方法——编写万年历的两种方式
  16. 配置TURN服务器实现NAT穿透
  17. C++ STL库(6)
  18. 支付宝小程序获取手机号授权
  19. “杀京东”京东价格监控软件项目开发日志一
  20. 合思·易快报,奔向费控「自动驾驶」无人区

热门文章

  1. 怎么样才能锻炼好口才
  2. BufferQueue 学习总结(内附动态图)
  3. 掌握这十个Linux命令,秒变Linux老手
  4. python爬虫难点_Python爬虫技巧
  5. ViewPager的使用方法
  6. Python操作SQLServer
  7. 解决VS调试web项目启动谷歌浏览器“无标题”、“已崩溃”问题
  8. 安卓SO层开发 -- 编译指定平台的SO文件
  9. 选择小红书素人笔记推广有什么好处?
  10. RS232 小板测试