【NLP】看不懂bert没关系,用起来so easy!
作者:十方
bert的大名如雷贯耳,无论在比赛,还是实际上的应用早已普及开来。想到十方第一次跑bert模型用的框架还是paddlepaddle,那时候用自己的训练集跑bert还是比较痛苦的,不仅要看很多配置文件,预处理代码,甚至报错了都不知道怎么回事,当时十方用的是bert双塔做文本向量的语义召回。如今tf都已经更新到了2.4了,tensorflow-hub的出现更是降低了使用预训练模型的门槛,接下来带大家看下,如何花十分钟时间快速构建bert双塔召回模型。
tensorflow hub
打开tensorflow官网,找到tensorflow-hub点进去,我们就能看到各种预训练好的模型了,找到一个预训练好的模型(如下图),下载下来,如介绍所说,这是个12层,768维,12头的模型。
在往下看,我们看到有配套的预处理工具:
同样下载下来,然后我们就可以构建bert双塔了。
Bert双塔
import os
import shutil
import pickle
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from official.nlp import optimization
from tensorflow.keras import *
from tqdm import tqdm
import numpy as np
import pandas as pd
import json
import re
import random# 这里读你自己的文本数据集
with open('./data/train_data.pickle', 'rb') as f:train_data = pickle.load(f)# 读数据用的generater
def train_generator():np.random.shuffle(train_data)for i in range(len(train_data)):yield train_data[i][0], train_data[i][1]# 训练数据 dataset
ds_tr = tf.data.Dataset.from_generator(train_generator, output_types=(tf.string, tf.string))# bert 双塔 dim_size是维度 model_name是下载模型的路径
def get_model(dim_size, model_name):# 下载的预处理工具路径preprocessor = hub.load('./bert_en_uncased_preprocess/3')# 左边塔的文本text_source = tf.keras.layers.Input(shape=(), dtype=tf.string)# 右边塔的文本text_target = tf.keras.layers.Input(shape=(), dtype=tf.string)tokenize = hub.KerasLayer(preprocessor.tokenize)tokenized_inputs_source = [tokenize(text_source)]tokenized_inputs_target = [tokenize(text_target)]seq_length = 512 # 这里指定你序列文本的最大长度bert_pack_inputs = hub.KerasLayer(preprocessor.bert_pack_inputs,arguments=dict(seq_length=seq_length))encoder_inputs_source = bert_pack_inputs(tokenized_inputs_source)encoder_inputs_target = bert_pack_inputs(tokenized_inputs_target)# 加载预训练参数 bert_model = hub.KerasLayer(model_name)bert_encoder_source, bert_encoder_target = bert_model(encoder_inputs_source), bert_model(encoder_inputs_target)# 这里想尝试in-batch loss# 也可以直接对 bert_encoder_source['pooled_output'], bert_encoder_target['pooled_output'] 做点积操作matrix_logit = tf.linalg.matmul(bert_encoder_source['pooled_output'], bert_encoder_target['pooled_output'], transpose_a=False, transpose_b=True)matrix_logit = matrix_logit / tf.sqrt(dim_size)model = models.Model(inputs = [text_source, text_target], outputs = [bert_encoder_source['pooled_output'], bert_encoder_target['pooled_output'], matrix_logit])return modelbert_double_tower = get_model(128.0, './small_bert_bert_en_uncased_L-2_H-128_A-2_1/3')
bert_double_tower.summary()
我们看到bert双塔模型已经构建完成:
然后定义loss,就可以训练啦!
optimizer = optimizers.Adam(learning_rate=5e-5)
loss_func_softmax = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
train_loss = metrics.Mean(name='train_loss')
train_acc = metrics.CategoricalAccuracy(name='train_accuracy')def train_step(model, features):with tf.GradientTape() as tape:p_source, p_target, pred = model(features)label = tf.eye(tf.shape(pred)[0])loss = loss_func_softmax(label, pred)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))train_loss.update_state(loss)train_acc.update_state(label, pred)def train_model(model, bz, epochs):for epoch in tf.range(epochs):steps = 0for feature in ds_tr.prefetch(buffer_size = tf.data.experimental.AUTOTUNE).batch(bz):logs_s = 'At Epoch={},STEP={}'tf.print(tf.strings.format(logs_s,(epoch, steps)))train_step(model, feature)steps += 1train_loss.reset_states()train_acc.reset_states()
往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑
本站qq群851320808,加入微信群请扫码:
【NLP】看不懂bert没关系,用起来so easy!相关推荐
- 小师妹学JVM之:深入理解JIT和编译优化-你看不懂系列
文章目录 简介 JIT编译器 Tiered Compilation分层编译 OSR(On-Stack Replacement) Deoptimization 常见的编译优化举例 Inlining内联 ...
- 计算机教学得意之处,看不懂没关系,知道厉害就行了:中科大俩教授11年解了两道数学难题...
王兵教授解释"哈密尔顿-田"猜想的大致原理.新华每日电讯记者陈诺摄 新华社北京11月16日电(记者徐海涛.陈诺)11月16日,<新华每日电讯>刊载题为<穿越11年 ...
- 看不懂没关系, 知道厉害就行了! 中科大俩教授11年解了两道数学难题
"著名学者弗里曼·戴森说,有些数学家是鸟,有些是青蛙.飞鸟可以俯瞰延伸至遥远地平线的数学远景,青蛙则乐于深入探讨特定问题的细节.至于我们,就像是池塘边碰巧发现美丽花朵的青蛙." 近 ...
- NLP突破性成果 BERT 模型详细解读 bert参数微调
https://zhuanlan.zhihu.com/p/46997268 NLP突破性成果 BERT 模型详细解读 章鱼小丸子 不懂算法的产品经理不是好的程序员 关注她 82 人赞了该文章 Goo ...
- 【NLP笔记】Bert浅析
作者:20届 ERIC 写在前面:本人刚刚入门NLP,希望通过记录博客来巩固自己的知识,增进对知识的理解. 在之前的博客,我们进行了CRF的原理探寻以及借助CRF工具包实现各类序列标注任务,如中文分词 ...
- 学习Linux命令神器-看不懂直接给你解释
大家都知道,Linux 系统有非常多的命令,而且每个命令又有非常多的用法,想要全部记住所有命令的所有用法,恐怕是一件不可能完成的任务. 一般情况下,我们学习一个命令时,要么直接百度去搜索它的用法,要么 ...
- 硬盘mdr转换成gdp linux,Linux 命令学习神器!命令看不懂直接给你解释!
原标题:Linux 命令学习神器!命令看不懂直接给你解释! 转自: 良许Linux 大家都知道,Linux 系统有非常多的命令,而且每个命令又有非常多的用法,想要全部记住所有命令的所有用法,恐怕是一件 ...
- 【NLP】Google BERT详解
版权声明:博文千万条,版权第一条.转载不规范,博主两行泪 https://blog.csdn.net/qq_39521554/article/details/83062188 </div> ...
- 你以为这样写代码很6,但我看不懂
来源 | 沉默王二 责编| Carol 封图| CSDN│下载于视觉中国 为了提高 Java 编程的技艺,作者最近在 GitHub 上学习一些高手编写的代码.下面这一行代码(出自大牛之手)据说可以征服 ...
最新文章
- 深度学习布料交换:在Keras中实现条件类比GAN
- JS的编码:escape,encodeURI,encodeURIComponent,解码:unescape,decodeURI,decodeURIComp
- Python中的str与unicode处理方法
- C++ 异常变量的生命周期
- 阿里RocketMQ Quick Start
- 【程序设计】程序错误类型
- 58 MM配置-评估和科目设置-OBYC配置自动记账
- 苹果新品又要来了 下周可能推出AirPods Studio
- 18-一种准确高效的领域知识图谱构建方法
- 不是所有的美剧都适合学英语
- pandas concat “InvalidIndexError: Reindexing only valid with uniquely valued Index objects“
- 你真的理解 Kubernetes 中的 requests 和 limits 吗?
- 小米路由r2d论坛_小米路由R2D,拼夕夕翻车了没
- no such instruction问题
- 一些关于SLG手游的想法
- github如何上传代码到仓库(从本地上传代码到github)
- 8800DF020SK3N1D1E3M5艾默生涡街流量计
- 用P5 JS绘制二维动画场景——静态篇
- 晋中市中等职业学校技能大赛
- 51单片机电机测速程序c语言,单片机仿真编码器电机测速程序