本文根据苏剑林的基于bert的baseline进行短文本匹配的讲解,其github地址是:https://github.com/bojone/oppo-text-match/blob/main/baseline.py

赛题地址:

https://tianchi.aliyun.com/competition/entrance/531851

数据探索

下载好相关数据之后,我们先看一下数据是什么样的:

path = '/content/drive/MyDrive/oppo-text-match/baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_train_20210228.tsv'
with open(path,'r',encoding='utf-8') as f:lines = f.readlines()for i,line in enumerate(lines):print(line.split('\t'))if i == 5:break

结果:

['1 2 3 4 5 6 7', '8 9 10 4 11', '0\n']
['12 13 14 15', '12 15 11 16', '0\n']
['17 18 12 19 20 21 22 23 24', '12 23 25 6 26 27 19', '1\n']
['28 29 30 31 11', '32 33 34 30 31', '1\n']
['29 35 36 29', '29 37 36 29', '1\n']
['38 23 39 9 40', '12 19 41 42 23 43 12 23 44 41 42 19', '0\n']

数据都是脱敏的,也就是字都用数字来表示了。
统计一下text1+text2的长度:

import matplotlib.pyplot as plt
import matplotlib
from matplotlib.font_manager import FontProperties
import pandas as pd
fromtrain_path = './baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_train_20210228.tsv'
test_path = './baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_testA_20210228.tsv'def cal_len_dis(path):with open(path,'r',encoding='utf-8') as f:lines = f.readlines()len_list = []for line in lines:line = line.strip().split('\t')len_list.append(len(line[0]+line[1]))return len_listdef get_len_detail(data):df = pd.DataFrame(data)res = df.describe()return res# 设置matplotlib正常显示中文和负号
font = FontProperties(fname=r'/data02/gob/project/text-match/simhei.ttf')
matplotlib.rcParams['axes.unicode_minus']=False     # 正常显示负号def draw_hist(data):plt.hist(data, bins=40, facecolor="blue", edgecolor="black", alpha=0.7)# 显示横轴标签plt.xlabel("长度",fontproperties=font)# 显示纵轴标签plt.ylabel("数量",fontproperties=font)# 显示图标题plt.title("句子长度统计",fontproperties=font)plt.savefig('len_hist.png')plt.show()if __name__ == '__main__':len_list = cal_len_dis(train_path)res = get_len_detail(len_list)print(res)draw_hist(len_list)


相关统计量:
count 100000.0000
mean 46.8328
std 17.317
min 12
25% 35
50% 43
75% 55
max 279

baseline中值得注意的一些代码

from bert4keras.snippets import truncate_sequences
truncate_sequences(maxlen, -1, a, b)

这个函数用于截断超过最大长度的句子,如果len(a+b)>maxlen,则对句子a进行截断。

def random_mask(text_ids):"""随机mask"""input_ids, output_ids = [], []rands = np.random.random(len(text_ids))for r, i in zip(rands, text_ids):if r < 0.15 * 0.8:input_ids.append(4)output_ids.append(i)elif r < 0.15 * 0.9:input_ids.append(i)output_ids.append(i)elif r < 0.15:input_ids.append(np.random.choice(len(tokens)) + 7)output_ids.append(i)else:input_ids.append(i)output_ids.append(0)return input_ids, output_ids

这个函数用于随机将一些词mask掉。

def sample_convert(text1, text2, label, random=False):"""转换为MLM格式"""text1_ids = [tokens.get(t, 1) for t in text1]text2_ids = [tokens.get(t, 1) for t in text2]if random:if np.random.random() < 0.5:text1_ids, text2_ids = text2_ids, text1_idstext1_ids, out1_ids = random_mask(text1_ids)text2_ids, out2_ids = random_mask(text2_ids)else:out1_ids = [0] * len(text1_ids)out2_ids = [0] * len(text2_ids)token_ids = [2] + text1_ids + [3] + text2_ids + [3]segment_ids = [0] * len(token_ids)output_ids = [label + 5] + out1_ids + [0] + out2_ids + [0]return token_ids, segment_ids, output_ids

用于转换单个样本为bert的输入。

class data_generator(DataGenerator):"""数据生成器"""def __iter__(self, random=False):batch_token_ids, batch_segment_ids, batch_output_ids = [], [], []for is_end, (text1, text2, label) in self.sample(random):token_ids, segment_ids, output_ids = sample_convert(text1, text2, label, random)batch_token_ids.append(token_ids)batch_segment_ids.append(segment_ids)batch_output_ids.append(output_ids)if len(batch_token_ids) == self.batch_size or is_end:batch_token_ids = sequence_padding(batch_token_ids)batch_segment_ids = sequence_padding(batch_segment_ids)batch_output_ids = sequence_padding(batch_output_ids)yield [batch_token_ids, batch_segment_ids], batch_output_idsbatch_token_ids, batch_segment_ids, batch_output_ids = [], [], []

用于将多个样本制作成batch的格式。

# 加载预训练模型
model = build_transformer_model(config_path=config_path,checkpoint_path=checkpoint_path,with_mlm=True,keep_tokens=[0, 100, 101, 102, 103, 100, 100] + keep_tokens[:len(tokens)]
)

这个函数用于加载预训练的模型

def masked_crossentropy(y_true, y_pred):"""mask掉非预测部分"""y_true = K.reshape(y_true, K.shape(y_true)[:2])y_mask = K.cast(K.greater(y_true, 0.5), K.floatx())loss = K.sparse_categorical_crossentropy(y_true, y_pred)loss = K.sum(loss * y_mask) / K.sum(y_mask)return loss[None, None]

计算损失。

model.compile(loss=masked_crossentropy, optimizer=Adam(1e-5))
model.summary()

定义优化器和损失。

# 转换数据集
train_generator = data_generator(train_data, batch_size)
valid_generator = data_generator(valid_data, batch_size)
test_generator = data_generator(test_data, batch_size)def evaluate(data):"""线下评测函数"""Y_true, Y_pred = [], []for x_true, y_true in data:y_pred = model.predict(x_true)[:, 0, 5:7]y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)y_true = y_true[:, 0] - 5Y_pred.extend(y_pred)Y_true.extend(y_true)return roc_auc_score(Y_true, Y_pred)class Evaluator(keras.callbacks.Callback):"""评估与保存"""def __init__(self):self.best_val_score = 0.def on_epoch_end(self, epoch, logs=None):val_score = evaluate(valid_generator)if val_score > self.best_val_score:self.best_val_score = val_scoremodel.save_weights('best_model.weights')print(u'val_score: %.5f, best_val_score: %.5f\n' %(val_score, self.best_val_score))def predict_to_file(out_file):"""预测结果到文件"""F = open(out_file, 'w')for x_true, _ in tqdm(test_generator):y_pred = model.predict(x_true)[:, 0, 5:7]y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)for p in y_pred:F.write('%f\n' % p)F.close()if __name__ == '__main__':evaluator = Evaluator()model.fit(train_generator.forfit(),steps_per_epoch=len(train_generator),epochs=100,callbacks=[evaluator])else:model.load_weights('best_model.weights')

加载数据以及评估等,最后在主函数中调用。

天池oppo-text-match比赛-苏剑林baseline代码解读相关推荐

  1. 相似度衡量: 苏剑林博客-1

    测地线距离(Geodesic Distance):地表上两点之间的最短路径的距离. 如下图所示,在二维空间中,两个黑点之间的欧式距离是虚线的长度,测地线距离时d12+d23+d34+d45的距离之和. ...

  2. 学会提问的BERT:端到端地从篇章中构建问答对 By 苏剑林

    机器阅读理解任务,相比不少读者都有所了解了,简单来说就是从给定篇章中寻找给定问题的答案,即"篇章 + 问题 → 答案"这样的流程,笔者之前也写过一些关于阅读理解的文章,比如< ...

  3. 苏建林DGCNN模型代码详解

    1.说明 以下代码为苏神参加百度信息抽取比赛提出的DGCNN模型代码,其源码基本是没有注释的,所以本文对大部分代码做了注释,可能有不对的地方,可以指出.另一个就是对keras3.x版本下Attenti ...

  4. 【苏小林主页】基于TP6和光年模板的个人主页

    基于光年V4和thinkphp的个人主页 开源仓库地址:GitHub - suxaiolin/personal: 个人主页(带后台) - su personal 直链下载地址:苏小林个人主页带后台管理 ...

  5. 胡适曾劝女作家苏雪林不要骂鲁迅

    2006.8.26 16:56 "五四"运动后期,随着<新青年>杂志的分裂,胡适与鲁迅日渐分道扬镳,走进了不同营垒.胡适的"多研究问题,少谈些主义" ...

  6. 计算机教育格言,苏霍姆林斯基教育名言大全

    苏霍姆林斯基教育名言大全 1.要想自己成为幸福的人,就应当对别人关怀备至,体贴入微,赤诚相见. 2.友谊是培养人的感情的学校. 3.友谊是培养人的感情的学校.我们所以需要友谊,并不是想用它打发时间,而 ...

  7. 百行代码解读阿里 AloT 芯片平台无剑 100!

    作者 | 马超 责编 | 胡巍巍 出品 | CSDN(ID:CSDNnews)  今年以来我国IT厂商都在AIot的底层平台建设方面可谓是捷报频传,在操作系统方面有如像腾讯的Tiny OS.阿里的Al ...

  8. 百度图神经网络——论文节点比赛baseline代码注解

    文章目录 一.项目介绍 二.BaseLine内容注解 1.Config部分注解 2. 数据读取与处理部分 2.1 边数据的加载与处理 2.2 数据的完整加载与处理 2.3 数据读取与分割 3. 模型加 ...

  9. 风林评《解读基金,我的投资观与实践》|你怎么能不知道基金投资的三个思想...

        风林评<解读基金,我的投资观与实践>|你怎么能不知道基金投资的三个思想 2019股市的红红火火必然影响着基金的净值的变化.如果对股市的风险承受能力不足的话,可以到股票的二级市场基金 ...

最新文章

  1. unity水管工_我是如何从30岁的管道工转变为32岁的Web开发人员的
  2. 刚入职,就被各种 Code Review,真的有必要吗?
  3. Android线程模型和AsyncTask
  4. 超图桌面版制作分段专题图学习
  5. 【小朋友才做选择题】跟着团队一起学习人工智能,先人一步掌握最前沿知识
  6. 汉字转拼音(c#) -转载
  7. 相同update语句在MySQL,Oracle的不同表现(r12笔记第30天)
  8. Spring MVC自定义验证注释
  9. 我以为我对Mysql索引很了解,直到我遇到了阿里的面试官
  10. [lct] Luogu P4219 大融合
  11. 斯特林数-斯特林反演
  12. Visual Studio 2015打开ASP.NET MVC的View提示“Object reference not set to an instance of an object“错误的解决方案
  13. 《Java安全编码标准》一2.9 IDS08-J净化传递给正则表达式的非受信数据
  14. 增强网络安全意识——如何5分钟破解校园网上网账号和密码
  15. 将Android手机无线连接到Ubuntu实现唱跳Rap
  16. html directive 内容传递,directive的传值问题(全面解析directive的传值问题)微信分享实例...
  17. javascript常见的设计模式
  18. [转]倾斜摄影单体化实现方案
  19. 今天是植树节,你“植树”了吗?
  20. 真的不能错过的打印攻略!打印一张7分钱

热门文章

  1. 自动批量裁剪+合成+整理正反向序列
  2. “大数据”这个名字叫错了,今天数据的意义并不在于有多“大”,真正有意思的是数据变得在线了。...
  3. uniapp实现上拉刷新,下拉加载
  4. kaggle心脏病监测分析案例(数据分析+数据可视化)适合入门新手
  5. 计算机英语在线学习,英语单词记忆法超强记忆_免费背单词软件电脑版
  6. 前端diff文件对比使用worker进行优化
  7. STM32-ADC模拟数字转换器
  8. mysql offset 问题_MySQL_优化mysql的limit offset的例子, 经常碰到的一个问题是limi - phpStudy...
  9. python发牌游戏图形界面_python实现扑克牌交互式界面发牌程序
  10. LeetCodeClassification---- No.1 分治--使用递归完成对一维数组的求和--递归法完成sum函数