学习时间:2022.04.26~2022.04.30

文章目录

  • 7. 基于PyTorch的BERT应用实践
    • 7.1 工具选取
    • 7.2 文本预处理
    • 7.3 使用BERT模型
      • 7.3.1 数据输入及应用预处理
      • 7.3.2 提取词向量
      • 7.3.3 网络建模
      • 7.3.4 参数准备
      • 7.3.5 模型训练

7. 基于PyTorch的BERT应用实践

本节着重于将BERT模型应用到具体的实践当中,因此未来有很多可以优化的地方,比如自己重写dataset和dataloader方法,这样对于pytorch应该能有更好地、更灵活的运用。

7.1 工具选取

使用PyTorch框架应用Bert,一般是去Hugging Face (一个社区,有很多人会把训练好了的模型放在上面,可以支持PyTorch和Tenserflow)上,把Bert模型下载下来,然后再应用。

在应用具体模型之前,需要先安装一下Hugging Face提供的transformers库,因为它上面的模型都是基于这个transformers库来编写的,即使下载了模型也要通过这个库来使用。(这个是必要的)

pip install transformers

此外,数据处理部分,如果不想自己重写dataset和dataloader方法的话,Hugging Face也提供了datasets库,可以通过pip install datasets来进行安装和使用。(这个是可选的)

这两个库的使用方法,可以直接去官网查看使用文档,我觉得还是相对来说比较详细了(见Transformers和Datasets)。

此外,自己这次使用的数据集也是来自Kaggle:Quora Insincere Questions Classification | Kaggle。

7.2 文本预处理

这一部分主要参考了Kaggle上大神的代码,然后汇总了一下装进一个函数里了。以下简单举例:

  1. 将一些表情符号、html网址、email ids、urls全都清除:
def clean_data(data):punct_tag = re.compile(r'[^\w\s]')data = punct_tag.sub(r'', data)html_tag = re.compile(r'<.*?>')data = html_tag.sub(r'', data)url_clean = re.compile(r"https://\S+|www\.\S+")data = url_clean.sub(r'', data)emoji_clean = re.compile("["u"\U0001F600-\U0001F64F"  # emoticonsu"\U0001F300-\U0001F5FF"  # symbols & pictographsu"\U0001F680-\U0001F6FF"  # transport & map symbolsu"\U0001F1E0-\U0001F1FF"  # flags (iOS)u"\U00002702-\U000027B0"u"\U000024C2-\U0001F251""]+", flags=re.UNICODE)data = emoji_clean.sub(r'', data)url_clean = re.compile(r"https://\S+|www\.\S+")data = url_clean.sub(r'', data)return data
  1. 清除所有格:
def strip_possessives(text):text = text.replace("'s", '')text = text.replace('’s', '')text = text.replace("\'s", '')text = text.replace("\’s", '')return text
  1. 将数字替换成##:
def clean_numbers(x):x = re.sub("[0-9]{5,}", '#####', x)x = re.sub("[0-9]{4}", '####', x)x = re.sub("[0-9]{3}", '###', x)x = re.sub("[0-9]{2}", '##', x)return x

……

最后通过一个函数全部调用:

def texts_preprogress(df):# 应用之前所有的预处理步骤df = df.apply(lambda x: clean_data(x))df = df.apply(lambda x: expand_contractions(x))df = df.apply(lambda x: replace_typical_misspell(x))df = df.apply(lambda x: strip_possessives(x))df = df.apply(lambda x: replace_multi_exclamation_mark(x))df = df.apply(lambda x: clean_text(x))df = df.apply(lambda x: change_stopwords(x))df = df.apply(lambda x: clean_numbers(x))return df

7.3 使用BERT模型

整个使用分为5部分:数据输入及应用预处理、提取词向量、构建网络模型(嵌入BERT)、参数准备和模型训练。

个人本次使用的是bert-base-uncased模型。

7.3.1 数据输入及应用预处理

# 确定设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# 导入文件并进行文本预处理
train_df = pd.read_csv('train.csv', nrows=540)
train_df['question_text'] = texts_preprogress(train_df['question_text'])
eval_df = pd.read_csv('train.csv', nrows=675)[540:675]
eval_df['question_text'] = texts_preprogress(eval_df['question_text'])

7.3.2 提取词向量

⭐这里值得特别特别特别提一下的是:tokenizer函数中的padding参数。

padding=True等价于padding='max_length',只对于句子对任务起作用,会一样自动补全到batch中的最长长度;但是!!!对于单句的任务,两者并不等价!!!这也是我一开始设置了padding=True,然后指定max_length=72,最后输出的句子却依然长短不一,然后输不进网络的原因。

所以对于单句任务,一定要先指定padding='max_length',然后再设置max_length= ,这样才会真正补全。

# 导入分词器
tokenizer = AutoTokenizer.from_pretrained('D:/Py-project/models/huggingface/bert-base-uncased/')# 定义token函数
def tokenize_function(examples):return tokenizer(examples['question_text'], padding='max_length', max_length=72, truncation=True)# padding='max_length'填充批处理中较短的序列以匹配最长的序列(太坑了!!!!);truncation=True将序列截断为模型接受的最大长度# 定义处理标签及id列的函数
def batch_label(df):df = df.drop('qid', axis=1)dataset = Dataset.from_pandas(df)for k in dataset.column_names:if k == 'target':dataset = dataset.rename_column(k, 'labels')inputs = dataset.map(tokenize_function, batched=True)  # 分词并输出词向量;batched=True开启批次输入inputs.set_format(type='torch')  # 词向量转为tensorinputs = inputs.remove_columns('question_text')  # 删除文本列,因为模型不接受原始文本作为输入dataloader = DataLoader(inputs, shuffle=True, batch_size=128)  # 创建成DataLoaderreturn dataloadertrain_dl = batch_label(train_df)
eval_dl = batch_label(eval_df)

7.3.3 网络建模

其实Bert也有直接提供一个训练好的bertclassification模型,但只能直接调用,不太好嵌入到网络,所以这里只用了bert模型,然后嵌入到了网络里面。

class my_bert(nn.Module):def __init__(self):super(my_bert, self).__init__()# Bert模型需要嵌入到网络中self.bert = BertModel.from_pretrained("D:/Py-project/models/huggingface/bert-base-uncased")# 将Bert模型的参数设置为可以更新for param in self.bert.parameters():param.requires_grad = Trueself.linear = torch.nn.Linear(768, 2)self.dropout = torch.nn.Dropout(0.5)def forward(self, x):input_ids = x['input_ids'].to(device)token_type_ids = x['token_type_ids'].to(device)attention_mask = x['attention_mask'].to(device)output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)output = output['pooler_output']  # bert的输出有4个,这个任务仅需要1个output = self.linear(output)output = self.dropout(output)output = torch.sigmoid(output)return output

7.3.4 参数准备

# 设置随机数种子,保证结果可复现
seed = 42
if device == 'cuda':torch.cuda.manual_seed(seed)
else:torch.manual_seed(seed)# 实例化模型
model = my_bert()
model.to(device)# 设置参数
lr = 2e-5
epoch = 2
show_step = 1
optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

7.3.5 模型训练

这里只涉及到一个小问题:如果数据集太大了,基本都需要转成dataloader后再分批输入网络进行训练和验证。

for i in range(epoch):model.train()losses = []accuracy = []start_time = time.time()for batch in tqdm(train_dl, desc=f'第{i+1}/{epoch}次迭代进度', ncols=100):pred = model(batch)  # 正向传播label = batch['labels'].to(device)loss = criterion(pred, label)  # 计算损失函数# 存入准确率和losslosses.append(loss.item())pred_labels = torch.argmax(pred, dim=1)acc = torch.sum(pred_labels == label).item() / len(pred_labels)accuracy.append(acc)optimizer.zero_grad()  # 优化器的梯度清零loss.backward()  # 反向传播optimizer.step()  # 参数更新# 测试集评估if i % show_step == 0:  # 控制输出间隔model.eval()ev_losses = []ev_acc = []with torch.no_grad():for batch in eval_dl:ev_pred = model(batch)ev_label = batch['labels'].to(device)ev_loss = criterion(ev_pred, ev_label)# 存入准确率和lossev_losses.append(ev_loss.item())pred_labels = torch.argmax(ev_pred, dim=1)acc = torch.sum(pred_labels == ev_label).item() / len(pred_labels)ev_acc.append(acc)elapsed_time = time.time() - start_timeprint("\nEpoch: {}/{}: ".format(i+1, epoch),"Accuracy: {:.6f}; ".format(np.mean(accuracy)),"Val Accuracy: {:.6f}; ".format(np.mean(ev_acc)),"Loss: {:.6f}; ".format(np.mean(losses)),"Val Loss: {:.6f}; ".format(np.mean(ev_losses)),'Time: {:.2f}s'.format(elapsed_time))

以上就是本人的bert第一次尝试的大部分内容了,之后应该会去试试ALBERT和RoBERTa,然后也会尝试自己重写下datasets和dataloader。未来继续加油!

学习笔记:深度学习(8)——基于PyTorch的BERT应用实践相关推荐

  1. 学习笔记--深度学习入门--基于Pyrhon的理论与实现--[日]斋藤康毅 -- 持续更新中

    关于这本 "神作" 的简介 这本书上市不到 2 年,就已经印刷 10 万册了.日本人口数量不大,但是却有这么多人读过这本书,况且它不是一本写真集,是实实在在的技术书,让人觉得很不可 ...

  2. 人工智障学习笔记——深度学习(4)生成对抗网络

    概念 生成对抗网络(GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一.模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discrimi ...

  3. 人工智障学习笔记——深度学习(2)卷积神经网络

    上一章最后提到了多层神经网络(deep neural network,DNN),也叫多层感知机(Multi-Layer perceptron,MLP). 当下流行的DNN主要分为应对具有空间性分布数据 ...

  4. 人工智障学习笔记——深度学习(1)神经网络

    一.神经网络 我们所说的深度学习,其最基础最底层的模型称之为"神经网络"(neural network),因为我们希望机器能够像我们人类大脑的神经网络处理事件一样去解决问题,最终达 ...

  5. Matlab深度学习笔记——深度学习工具箱说明

    本文是Rasmus Berg Palm发布在Github上的Deep-learning toolbox的说明文件,作者对这个工具箱进行了详细的介绍(原文链接:https://github.com/ra ...

  6. 3Blue1Brown深度学习笔记 深度学习之神经网络的结构 Part 1 ver 2.0

    神经元 3B1B先讨论最简单的MLP(多层感知器),只是经典的原版,就已经能识别手写数字. 这里一开始我们把神经元看作装有数字的容器,装着一个0~1之间的数字.但是最后更准确一些,我们把神经元看作一个 ...

  7. 人工智障学习笔记——深度学习(3)递归神经网络

    传统的神经网络模型中,输入层到隐含层再到输出层他们的层与层之间是全连接的,但是每层之间的节点是无连接的.这样就会造成一个问题,有些情况,每层之间的节点可能是存在某些影响因素的.例如,你要预测句子的下一 ...

  8. 学习笔记︱深度学习以及R中并行算法的应用(GPU)

    笔记源于一次微课堂,由数据人网主办,英伟达高级工程师ParallerR原创.大牛的博客链接:http://www.parallelr.com/training/ 由于本人白痴,不能全部听懂,所以只能把 ...

  9. 深度学习笔记其七:计算机视觉和PYTORCH

    深度学习笔记其七:计算机视觉和PYTORCH 1. 图像增广 1.1 常用的图像增广方法 1.1.1 翻转和裁剪 1.1.2 改变颜色 1.1.3 结合多种图像增广方法 1.2 使用图像增广进行训练 ...

最新文章

  1. 301重定向的好处:
  2. Spring MVC - 配置Spring MVC
  3. 【Android 逆向】ART 脱壳 ( DexClassLoader 脱壳 | exec_utils.cc 中执行 Dex 编译为 Oat 文件的 Exec 和 ExecAndReturnC函数 )
  4. Dlib学习笔记:dlib array2d与 OpenCV Mat互转
  5. windows重绘机制原理
  6. 数据库的关系运算和完整性约束
  7. JQuery实现ajax跨域
  8. python tkinter获取屏幕大小_使用Python构建属于自己的Markdown编辑器
  9. java 动态爱心代码_java swing实现动态心形图案的代码下载
  10. 海康人脸服务器型号,DS-2CD7A27FWD/F-LZ(S) 海康人脸识别摄像机 海康200万像素深眸智能人脸日夜筒型网络摄像机...
  11. 微信小程序-加载图片
  12. 微信小程序开发得会议扫码签到系统
  13. 计算机教师格言座右铭,教师格言座右铭100句
  14. 【黑金ZYNQ7000系列原创视频教程】06.ZYNQ来自FPGA的中断mdash;mdash;按键中断实验...
  15. 羊皮卷的故事-第十八章
  16. 谷歌地图卫星下周发射 分辨率提高至0.5米
  17. 微型计算机的运算器控制器及内存的总称,微型计算机的运算器、控制器及内存存储器的总称是。...
  18. 解压GZIP加密后的Response
  19. 织梦cms模块插件-阿里云短信,让织梦cms更简单
  20. 【阅读记录】3DSSD:Point-based 3D Single Stage Object Detector

热门文章

  1. 机器学习——pandas基础
  2. 浅谈在线文档的那些事儿
  3. 产品安全:短信验证码的安全防护策略
  4. maven私服nexus搭建并新建仓库使用
  5. 政策面、基本面、消息面、技术面、资金面.
  6. F - Candle Box(模拟+思维)
  7. 大话设计模式:第17章 适配器模式
  8. C语言:计算全班学生的总成绩、平均成绩和以及 140 分以下的人数。
  9. 什么叫新零售系统 新零售的特点是什么?
  10. c# 下载千千静听歌词