基于Bert文本分类进行行业识别
NLP学习之Bert文本分类
- 行业识别——基于Bert
- 项目介绍
- 数据集:
- 数据迭代器:
- 项目结构:
- 总入口:
- 模型搭建和配置
- 配置类: config
- 模型搭建:model
- 数据预处理:
- 数据预处理
- 数据迭代器构建
- 构建训练流程
- 训练
- 验证
- 测试
- 预测
行业识别——基于Bert
项目介绍
数据集:
本项目使用的是THUCNews的一个子集,每条数据都是从新闻中抽取的标题,属于标题(短文本)分类。
文本长度在20到30之间。一共10个类别,每类2万条。数据以字为单位输入模型。
类别:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐。
数据集 | 数据量 |
---|---|
训练集 | 18万 |
验证集 | 1万 |
测试集 | 1万 |
数据迭代器:
在进行训练的时候,读取数据有两种方式。
一种是提前把数据预处理好,保存为文件,训练的时候读取文件进行训练。
另一种是构建数据迭代器,预处理和训练同时进行。
优点:当数据量大的时候,一次只会加载1个batch的数据到显存中,有效防止了显存溢出。
项目结构:
│ predict.py 预测代码
│ run.py 总入口
│ train_eval.py 训练、验证、测试代码
│ utils.py 数据预处理
│
├─bert_pretrain
│ bert_config.json 超参数配置文件
│ pytorch_model.bin 预训练参数文件
│ vocab.txt 词表文件
│
├─models
│ bert.py 模型定义及超参数定义
│
└─THUCNews
├─data
│ class.txt 类别
│ dev.txt 验证集
│ test.txt 测试集
│ train.txt 验证集
│
└─saved_dict
bert.ckpt 训练模型保存
总入口:
parser = argparse.ArgumentParser(description="chinese text classification")
parser.add_argument('--model',type=str,required=True,help="choose model")
args = parser.parse_args()if __name__ == "__main__":dataset = 'THUCNews' #数据集model_name = args.model #模型名字x = import_module('models.' + model_name) #根据模型名字,获取models包下的文件config = x.Config(dataset) #模型配置类#设置随机种子,np,cpu,gpu,固定卷积层算法np.random.seed(1)torch.manual_seed(1)torch.cuda.manual_seed(1)torch.backends.cudnn.deterministic = Truestat_time = time.time()print("Loading data...")#数据预处理train_data,dev_data,test_data = build_dataset(config)#构建训练集、验证集、测试集迭代器train_iter = build_iterator(train_data,config)dev_iter = build_iterator(dev_data,config)test_iter = build_iterator(test_data,config)time_dif = get_time_dif(stat_time)print("Time usage:",time_dif)#构建模型对象,to_devicemodel = x.Model(config).to(config.device)train(config,model,train_iter,dev_iter,test_iter)
模型搭建和配置
配置类: config
class Config(object):def __init__(self,dataset):self.model_name = 'bert'#训练集、验证集、测试集self.train_path = dataset + 'data/train.txt'self.dev_path = dataset + 'data/dev.txt'self.test_path = dataset + 'data/test.txt'#类别self.class_list = [x.strip() for x in open(dataset + 'data/class.txt').readlines()]#模型保存位置self.save_path = dataset + 'saved_dict' + self.model_name + '.ckpt'#设置训练使用cpu、gpuself.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#设置xx个batch没有改变则提前结束self.require_improvement = 1000#类别数目self.num_classes = len(self.class_list)#迭代次数self.epoches = 3#设置batch_sizeself.batch_size = 128#设置句子长度self.pad_size = 32#设置学习率self.learning_rate = 5e-5#预训练模型相关文件:1.模型文件.bin 2.配置文件.json 3.词表文件vocab.txtself.bert_path = './bert_pretrain'#序列划分工具self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)#隐藏层数量self.hidden_size = 768
模型搭建:model
class model(nn.Module):def __init__(self,config):super(Module,self).__init__#加载模型self.bert = BertModel.from_pretrained(config.bert_path)#微调for param in self.bert.parameters():param.requires_grad = True #训练时进行梯度更新#输出,自定义全连接层self.fc = nn.Linear(config.hidden_size,config.num_classes)def forward(self,x):context = x[0]mask = x[2]_,pooled = self.bert(context,mask,output_all_encoded_layers=False)#是否将bert中每层(12)都输出,false只输出最后一层(128,768)out = self.fc(pooled)return out
数据预处理:
数据预处理
PAD, CLS = '[PAD]', '[CLS]' #pad:占位符,input长度相同 #cls:放在句子首位,用于分类任务def build_dataset(config):def load_dataset(path,pad_size=32): #根据pad_size进行补全或者截断contents = []with open(path,'r',encoding="utf-8") as f:for line in tqdm(f):lin = line.strip()#去除首尾空格和换行if not lin:continuecontent, label = lin.split("\t")#根据tab键进行分割token = config.tokenizer.tokenize(content) #分字,bert内置的token = [CLS] + token #头部加入seq_len = len(token)mask = [] #区分填充部分和非填充部分token_ids = config.tokenize.convert_tokens_to_ids(token) #基于词表文件,将token转换为索引#长截短补if pad_size:if len(token) < pad_size:mask = [1] * len(token) + [0] * (pad_size - len(token))token_ids += [0]*(pad_size - len(token))else:mask = [1] * pad_sizetoken_ids = token_ids[:pad_size]seq_len = pad_sizecontents.append(token_ids,int(label),seq_len,mask)return contentstrain = load_dataset(config.train_path,config.pad_size)dev = load_dataset(config.dev_path,config.pad_size)test = load_dataset(config.test_path,config.pad_size)return train, dev,test
数据迭代器构建
class DatasetIterater(object):def __init__(self, batches, batch_size, device):self.batch_size = batch_sizeself.batches = batchesself.n_batches = len(batches)//batch_sizeself.residue = Falseif len(batches) % self.n_batches != 0: #不是整数batchself.residue = Trueself.device = deviceself.index = 0def _to_tensor(self, datas):#将索引,标签,长度,mask转换为tensor类型x = torch.LongTensor(_[0] for _ in datas).to(self.device)y = torch.LongTensor(_[1] for _ in datas).to(self.device)seq_len = torch.LongTensor(_[2] for _ in datas).to(self.device)mask = torch.LongTensor(_[3] for _ in datas).to(self.device)return (x,seq_len,mask),ydef __next__(self):if self.residue and self.index == self.n_batches:batches = self.batches[self.index*self.index:len(self.batches)]self.index += 1batches = self._to_tensor(batches)return batcheselif self.index>self.n_batches:self.index = 0raise StopIterationelse:batches = self.batch_size[self.index*self.batch_size:(self.index+1)*self.batch_size]self.index += 1batches = self._to_tensor(batches)return batchesdef __iter__(self):return selfdef __len__(self):if self.residue:return self.n_batches + 1else:return self.n_batches
def build_iterator(dataset, config):iter = DatasetIterater(dataset,config.batch_size,config.device)return iter
构建训练流程
训练
def train(config,model,train_iter,dev_iter,test_iter):start_time =time.time()#开启训练模式model.train()#参数param_optimizer = list(model.named_parameters()) #无需更新的参数no_decay = ['bias','LayerNorm.bias','LayerNorm.weight']#设置哪些参数需要更新,哪些不需要optimizer_grouped_parameters = [{'params':[p for n,p in param_optimizer if not any(nd in n for nd in no_decay)],'weight_decay':0.01},{'params':[p for n,p in param_optimizer if any(nd in n for nd in no_decay)],'weight_decay':0.0}]#优化器搭建optimizer = BertAdam(optimizer_grouped_parameters,lr=config.learning_rate,warmup=0.05,t_total=len(train_iter)*config.num_epochs)#记录总batchtotal_batch = 0#记录验证集的损失值dev_best_loss = float('inf')#记录上次改变的batch数last_improve = 0#训练结束标识flag = Falsemodel.train()for epoch in range(config.num_epochs):print('Epoch [{}/{}]'.format(epoch+1,config.num_epochs))for i,(trains,labels) in enumerate(train_iter):#前向传播,获取输出outputs = model(trains)#梯度置零,否则会梯度累加model.zero_grad()#交叉熵损失loss = F.cross_entropy(outputs,labels)#计算梯度loss.backward()#反向传播更新参数optimizer.step()#每一百个batch进行输出结果if total_batch % 100 == 0:#数据迁移到cpu上进行预测true = labels.data.cpu()predict = torch.max(outputs.data,1)[1].cpu()#分类指标的文本报告:1.精确率 2.召回率 3.F1 scoretrain_acc = metrics.accuracy_score(true,predict)#验证集准确率和损失dev_acc,dev_loss = evaluate(config,model,dev_iter)#损失值降低,保存模型if dev_loss < dev_best_loss:dev_best_loss =dev_losstorch.save(model.state_dict(),config.save_path)improve = '*'last_improve = total_batchelse:improve = ''time_dif = get_time_dif(start_time)msg = "Iter: {0:>6}, Train loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}"print(msg.format(total_batch,loss.item(),train_acc,dev_loss,dev_acc,time_dif,improve))model.train()total_batch += 1#早停策略if total_batch - last_improve > config.require_improvement:print("Early stopping")flag = Truebreakif flag:breaktest(config,model,test_iter)
验证
def evaluate(config,model,data_iter,test=False):model.eval()loss_tatal = 0predict_tatal = np.array([],dtype=int)label_tatal = np.array([],dtype=int)with torch.no_grad:for texts,labels in data_iter:output = model(texts)loss = F.cross_entropy(output,labels)loss_tatal += losslabels = labels.data.cpu().numpy()predict = torch.max(output.data,1)[1].cpu().numpy()label_tatal = np.append(label_tatal,labels)predict_tatal = np.append(predict_tatal,predict)acc = metrics.accuracy_score(label_tatal,predict_tatal)if test:report = metrics.classification_report(predict_tatal,label_tatal,target_names=config.class_list,digits=4)confution = metrics.confusion_matrix(predict_tatal,label_tatal)return acc,loss/len(data_iter),report,confutionreturn acc,loss/len(data_iter)
测试
def test(model,config,test_iter):model.load_state_dict(torch.load(config.save_path))model.eval()start_time = time.time()test_acc,test_loss,test_report,test_confution = evaluate(config,model,test_iter,test=True)msg = "Test loss:{0:>5.2},Test acc:{1:>6.2%}"print(msg.format(test_acc,test_loss))print("test_report")print(test_report)print("confution")print(test_confution)time_dif =get_time_dif(start_time)print("use time",time_dif)
预测
import torch
from importlib import import_module
import os
key = {0: '金融',1: '房产',2: '股票',3: '教育',4: '科技',5: '社会',6: '政治',7: '体育',8: '游戏',9: '娱乐'
}
cru = os.path.dirname(__file__)
path = os.path.join(cru,'THUCNews')
model_name = 'bert'
x = import_module('bert_demo.models.' + model_name)
config = x.Config(path)
model = x.Model(config).to("cpu")
model.load_state_dict(torch.load(config.save_path, map_location='cpu'))def build_predict_text(text):token = config.tokenizer.tokenize(text)token = ['[CLS]'] + tokenseq_len = len(token)mask = []token_ids = config.tokenizer.convert_tokens_to_ids(token)pad_size = config.pad_sizeif pad_size:if len(token) < pad_size:mask = [1] * len(token_ids) + ([0] * (pad_size - len(token)))token_ids += ([0] * (pad_size - len(token)))else:mask = [1] * pad_sizetoken_ids = token_ids[:pad_size]seq_len = pad_sizeids = torch.LongTensor([token_ids])seq_len = torch.LongTensor([seq_len])mask = torch.LongTensor([mask])return ids, seq_len, maskdef predict(text):data = build_predict_text(text)with torch.no_grad():outputs = model(data)num = torch.argmax(outputs)return key[int(num)]if __name__ == '__main__':while True:print(predict("福建省政务云平台基础设施运维服务项25555年招标公告"))
基于Bert文本分类进行行业识别相关推荐
- 【项目调研+论文阅读】基于BERT的中文命名实体识别方法[J] | day6
<基于BERT的中文命名实体识别方法>王子牛 2019-<计算机科学> 文章目录 一.相关工作 二.具体步骤 1.Bi-LSTM 2.CRF结构 三.相关实验 1.数据集 2. ...
- pytorch bert文本分类_一起读Bert文本分类代码 (pytorch篇 四)
Bert是去年google发布的新模型,打破了11项纪录,关于模型基础部分就不在这篇文章里多说了.这次想和大家一起读的是huggingface的pytorch-pretrained-BERT代码exa ...
- Transformer课程 第7课Gavin大咖 BERT文本分类-BERT Fine-Tuning
Transformer课程 第7课Gavin大咖 BERT文本分类-BERT Fine-Tuning Part III - BERT Fine-Tuning 4. Train Our Classifi ...
- 6.自然语言处理学习笔记:Multi-head-self-attention 和Transformer基础知识 和BERT文本分类原理
Multi-head-self-attention: 可以更细致的去发现局部信息. Transformer: BERT文本分类原理:
- 手把手教你搭建Bert文本分类模型,快点看过来吧!
1 赛题名称 基于文本挖掘的企业隐患排查质量分析模型 2 赛题背景 企业自主填报安全生产隐患,对于将风险消除在事故萌芽阶段具有重要意义.企业在填报隐患时,往往存在不认真填报的情况,"虚报.假 ...
- BERT文本分类实战
一.简介 在开始使用之前,我们先简单介绍一下到底什么是BERT,大家也可以去BERT的github上进行详细的了解.在CV问题中,目前已经有了很多成熟的预训练模型供大家使用,我们只需要修改结尾的FC层 ...
- 实体识别(4) -基于Bert进行商品标题实体识别[很详细]
基于Bert进行实体识别任务微调 致Great,ChallengeHub公众号,微信:1185918903,备注NLP技术交流 和鲸主页:https://www.heywhale.com/home/u ...
- 基于BERT的英文query实体识别模型
导读 本文主要从三个方面来介绍下本人2021年的上半年的工作,优化跨境电商的搜索相关性中的一块工作:搜索query实体理解,因为只有做好query实体理解这第一步,才能继续做很多后面相关性的事情.比如 ...
- 基于bert的分类笔记
文章目录 一.基于prompt的文本分类 二.什么是样本不均衡问题 三.样本不均衡会导致什么问题 三.如何解决样本不均衡问题 四.基于bert的文本分类模型是咋做的 五.bert模型中的[CLS].[ ...
最新文章
- 转换前台javascript传递过来的时间字符串到.net的DateTime
- ML基石_11_HazardOfOverfitting
- SSH-KeyGen 的用法 【转载】
- 告别2019,写给2020:干好技术,要把握好时光里的每一步
- mysql如何定位到数据_如何快速定位当前数据库消耗CPU最高的sql语句?
- 将所有文件从目录复制到Python中的另一个目录
- linux syslog日志
- GNU make manual 翻译( 一百五十五)
- 【MySQL】MySQL 8 PROCEDURE ANALYSE命令使用
- MuiPlayer视频播放组件入门
- mysql开启远程可连接
- 图解HTTPS协议加密解密全过程
- python100例图案_python100例 21-30
- 走迷宫(三):在XX限制条件下,是否走得出。
- Calendar类方法——编写万年历的两种方式
- 配置TURN服务器实现NAT穿透
- C++ STL库(6)
- 支付宝小程序获取手机号授权
- “杀京东”京东价格监控软件项目开发日志一
- 合思·易快报,奔向费控「自动驾驶」无人区