项目简介

最近在学习pytorch和Bert,所以做了一个这样完全新手向的入门项目来练习。

由于之前在网上学习发现现存的教程比较少,所以记录一下自己的学习过程,加深印象,也希望能帮到别的学习者吧,能涨粉丝就更好了。

项目模块化之类的做的并不好,感觉机器学习项目那么小没啥分包的必要,分了一堆包只会降低代码的可读性,反而不利于新手学习(其实是懒得分)

数据集和任务介绍

数据集就是谷歌官方在bert教程中提供得一个GLUE数据集,名字叫MRPC
下载地址:https://github.com/google-research/bert

任务目标就是判断两句输入是否是同义句子

项目结构


很简单的结构,下面会尽量分模块详细说明。
如果不想看模块,最后会直接把三个文件的内容都发出来供参考。

数据预处理模块 MRPCDataset

from torch.utils.data import Dataset
from itertools import isliceclass MRPCDataset(Dataset):def __init__(self):file = open('data/msr_paraphrase_train.txt', 'r', encoding="utf-8")data = []for line in islice(file, 1, None):content = line.split(sep='\t')data.append(content)self.data = datadef __getitem__(self, index):#把标签从char转成int,不然做损失的时候会报错if self.data[index][0] == '0':label = 0else:label = 1#这里要手动拼接句子,中间加一个拼接tokensentences = self.data[index][3] + '[SEP]' + self.data[index][4]return sentences, labeldef __len__(self):return len(self.data)

Dataset就是pytorch封装的用来加载数据的接口,将其继承后并重写__getitem__方法就可以实现一个自定义的读取器。

似乎不用dataset类也能自己实现读取器,但是感觉官方的应该会有加载速度的优化,所以还是选了官方的。

模型搭建部分

bert模型

from transformers import BertTokenizer, BertModel
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model.to(device)

因为谷歌官方用的是tensorflow的模型,这里使用的是hugging face提供的接口进行的加载的(民间SOTA)。

还挺好用的,两行就把预训练的模型加载进来了。

全连接分类层

import torchclass FCModel(torch.nn.Module):def __init__(self):super(FCModel, self).__init__()self.fc = torch.nn.Linear(in_features=768, out_features=1)def forward(self, input):score = self.fc(input)result = torch.sigmoid(score)return result

这里是把模型封装成类并继承了Module,这样比较方便后面的换GPU、传参数等操作。其实模型很简单,就是一个全连接+sigmoid

优化方法模块

#定义优化器&损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
bert_optimizer = torch.optim.Adam(bert_model.parameters(), lr=0.00001)
crit = torch.nn.BCELoss()

这里就是定义优化器和损失函数,但是有一个巨大的坑。
全连接层和bert模型的学习率是不一样的,因为bert只需要微调,所以学习率应该很低才对。这也是为什么要把全连接层独立封装出来。

开始不懂这一点,调了一天也没收敛,后来多亏大佬指点。

训练模块

#定义训练方法
def train():#记录统计信息epoch_loss, epoch_acc = 0., 0.total_len = 0#分batch进行训练for i, data in enumerate(train_loader):bert_model.train()model.train()#获取训练数据sentence, label = datalabel = label.cuda()#数据喂给模型,并把模型拼起来encoding = tokenizer(sentence, return_tensors='pt', padding=True, truncation=True)bert_output = bert_model(**encoding.to(device))pooler_output = bert_output.pooler_outputpredict = model(pooler_output).squeeze()#计算损失和准确度loss = crit(predict, label.float())acc = binary_accuracy(predict, label) #自定义方法#gdoptimizer.zero_grad() #把梯度重置为零bert_optimizer.zero_grad()loss.backward() #求导optimizer.step() #更新模型bert_optimizer.step()epoch_loss += loss * len(label)epoch_acc += acc * len(label)total_len += len(label)print("batch %d loss:%f accuracy:%f" % (i, loss, acc))return epoch_loss/total_len, epoch_acc/total_len

这里就都是一些常规的操作:喂数据、记录训练信息、求导、更新…
因为只是截取部分源码,所以会出现一些没声明过的变量,比如model啥的,最好配合完整代码看。

运行模块

#开始训练
Num_Epoch = 3
index = 0
for epoch in range(Num_Epoch):epoch_loss, epoch_acc = train()index += 1print("EPOCH %d loss:%f accuracy:%f" % (index, epoch_loss, epoch_acc))

这里也没什么好说的,就是定一下训练轮次。

完整项目应该还有验证模块,懒得写了

和训练模块大同小异

训练结果

因为在图书馆,跑代码机器会发出很大的噪音影响别的同学学习,所以就不跑一遍再截图了。(其实也是懒得跑)

我之前自己跑3轮训练,准确率能到达88%左右,提供参考。
(谷歌官方教程里也说这个数据集差不多也就跑90%,不用太较真)

完整项目代码

train.py

import torch
from torch.utils.data import DataLoader
from FCModel import FCModel
from MRPCDataset import MRPCDataset
from transformers import BertTokenizer, BertModel#载入数据预处理模块
mrpcDataset = MRPCDataset()
train_loader = DataLoader(dataset=mrpcDataset, batch_size=32, shuffle=True)
print("数据载入完成")#设置运行设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("设备配置完成")#加载bert模型
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model.to(device)
print("bert层模型创建完成")#创建模型对象
model = FCModel()
model = model.to(device)
print("全连接层模型创建完成")#定义优化器&损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
bert_optimizer = torch.optim.Adam(bert_model.parameters(), lr=0.001)
crit = torch.nn.BCELoss()#计算准确率的公式
def binary_accuracy(predict, label):rounded_predict = torch.round(predict)correct = (rounded_predict == label).float()accuracy = correct.sum() / len(correct)return accuracy#定义训练方法
def train():#记录统计信息epoch_loss, epoch_acc = 0., 0.total_len = 0#分batch进行训练for i, data in enumerate(train_loader):bert_model.train()model.train()sentence, label = datalabel = label.cuda()encoding = tokenizer(sentence, return_tensors='pt', padding=True, truncation=True)bert_output = bert_model(**encoding.to(device))pooler_output = bert_output.pooler_outputpredict = model(pooler_output).squeeze()loss = crit(predict, label.float())acc = binary_accuracy(predict, label)#gdoptimizer.zero_grad() #把梯度重置为零bert_optimizer.zero_grad()loss.backward() #求导optimizer.step() #更新模型bert_optimizer.step()epoch_loss += loss * len(label)epoch_acc += acc * len(label)total_len += len(label)print("batch %d loss:%f accuracy:%f" % (i, loss, acc))return epoch_loss/total_len, epoch_acc/total_len#开始训练
Num_Epoch = 3
index = 0
for epoch in range(Num_Epoch):epoch_loss, epoch_acc = train()index += 1print("EPOCH %d loss:%f accuracy:%f" % (index, epoch_loss, epoch_acc))

MRPCDataset.py

from torch.utils.data import Dataset
from itertools import isliceclass MRPCDataset(Dataset):def __init__(self):file = open('data/msr_paraphrase_train.txt', 'r', encoding="utf-8")data = []for line in islice(file, 1, None):content = line.split(sep='\t')data.append(content)self.data = datadef __getitem__(self, index):if self.data[index][0] == '0':label = 0else:label = 1sentences = self.data[index][3] + '[SEP]' + self.data[index][4]return sentences, labeldef __len__(self):return len(self.data)

FCModel.py

import torchclass FCModel(torch.nn.Module):def __init__(self):super(FCModel, self).__init__()self.fc = torch.nn.Linear(in_features=768, out_features=1)def forward(self, input):score = self.fc(input)result = torch.sigmoid(score)return result

[学习日志]使用pytorch 和 bert 实现一个简单的文本分类任务相关推荐

  1. AI深度学习入门与实战21 文本分类:用 Bert 做出一个优秀的文本分类模型

    在上一讲,我们一同了解了文本分类(NLP)问题中的词向量表示,以及简单的基于 CNN 的文本分类算法 TextCNN.结合之前咱们学习的 TensorFlow 或者其他框架,相信你已经可以构建出一个属 ...

  2. 深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类

    文章目录 一.前期工作 1. 设置GPU 2. 导入预处理词库类 二.导入预处理词库类 三.参数设定 四.创建模型 五.训练模型函数 六.测试模型函数 七.训练模型与预测 今天给大家带来一个简单的中文 ...

  3. 使用Bert预训练模型进行中文文本分类(基于pytorch)

    前言 最近在做一个关于图书系统的项目,需要先对图书进行分类,想到Bert模型是有中文文本分类功能的,于是打算使用Bert模型进行预训练和实现下游文本分类任务 数据预处理 2.1 输入介绍 在选择数据集 ...

  4. ROS学习笔记十:用C++编写一个简单的服务和客户端

    ROS学习笔记十:用C++编写一个简单的服务和客户端 这一节主要介绍如何使用C++编写一个简单的服务和客户端节点. 编写服务节点 由于在前面的练习中,已经向beginner_tutorials软件包中 ...

  5. 何使用BERT模型实现中文的文本分类

    原文网址:https://blog.csdn.net/Real_Brilliant/article/details/84880528 如何使用BERT模型实现中文的文本分类 前言 Pytorch re ...

  6. 【文本分类】基于BERT预训练模型的灾害推文分类方法、基于BERT和RNN的新闻文本分类对比

    ·阅读摘要: 两篇论文,第一篇发表于<图学学报>,<图学学报>是核心期刊:第二篇发表于<北京印刷学院学报>,<北京印刷学院学报>没有任何标签. ·参考文 ...

  7. C++ 容器的综合应用的一个简单实例——文本查询程序

    [0. 需求] 最近在粗略学习<C++ Primer 4th>的容器内容,关联容器的章节末尾有个很不错的实例. 通过实现一个简单的文本查询程序,希望能够对C++的容器学习有更深的理解. 由 ...

  8. 【Qt5开发及实例】16、实现一个简单的文本编辑器(over)

    实现一个简单的文本编辑器 其他具体的代码基础看前面:http://blog.csdn.net/cutter_point/article/details/42839071 1.功能 这个程序又添加了文本 ...

  9. 基于 BERT 实现的情感分析(文本分类)----概念与应用

    文章目录 基于 BERT 的情感分析(文本分类) 基本概念理解 简便的编码方式: One-Hot 编码 突破: Word2Vec编码方式 新的开始: Attention 与 Transformer 模 ...

最新文章

  1. linux查找nginx目录,Linux下查看nginx安装目录
  2. 构造函数的理解(构造函数与 init 方法)
  3. python range函数怎么表示无限_Python for循环与range函数的使用详解
  4. python的cubes怎么使用_如何使用python中的opengl?
  5. 十三、JSP9大隐视对象中四个作用域的大小与作用范围
  6. OpenShift 4 - 向OpenShift添加新的SSH Key
  7. 9.11两点间距离(1636050091)
  8. 好久不写日志了,现在开始,好好写了。。
  9. wp8.1 全球化解决办法
  10. oracle语法和sql的区别吗,ORACLE和SQL语法区别归纳
  11. 快速傅立叶变换_FFT
  12. UESTC_神秘绑架案 CDOJ 881
  13. Junit +cucumber 运行报错 initiallizationError
  14. linux下编译安装
  15. MyBatis3详细教程-从入门到精通
  16. 数学建模算法与应用(一)线性规划
  17. 关于改进建议几个方面的有效实践
  18. composer如何进行安装和使用
  19. Python数据分析与应用(一)
  20. mysql sql 0填充_sql - MySQL - 如何用“0”填充前面的邮政编码?

热门文章

  1. 栈寄存器R0-R15
  2. 具有测绘资质的“八大GPS地图提供商”
  3. 一个去中心化的免费电子书共享网站 JS解码URL和编码URL
  4. 水浒传人物介绍微信小程序源码
  5. 头歌实训之python字典入门
  6. inno setup打包软件学习
  7. JMeter压力测试,五年Java开发者小米、阿里面经
  8. 新手小白亚马逊注册最全教程在此
  9. 软件工程复试常见问题 | 第一篇家常唠嗑篇
  10. Linux桌面,建立文件夹快捷方式