天池oppo-text-match比赛-苏剑林baseline代码解读
本文根据苏剑林的基于bert的baseline进行短文本匹配的讲解,其github地址是:https://github.com/bojone/oppo-text-match/blob/main/baseline.py
赛题地址:
https://tianchi.aliyun.com/competition/entrance/531851
数据探索
下载好相关数据之后,我们先看一下数据是什么样的:
path = '/content/drive/MyDrive/oppo-text-match/baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_train_20210228.tsv'
with open(path,'r',encoding='utf-8') as f:lines = f.readlines()for i,line in enumerate(lines):print(line.split('\t'))if i == 5:break
结果:
['1 2 3 4 5 6 7', '8 9 10 4 11', '0\n']
['12 13 14 15', '12 15 11 16', '0\n']
['17 18 12 19 20 21 22 23 24', '12 23 25 6 26 27 19', '1\n']
['28 29 30 31 11', '32 33 34 30 31', '1\n']
['29 35 36 29', '29 37 36 29', '1\n']
['38 23 39 9 40', '12 19 41 42 23 43 12 23 44 41 42 19', '0\n']
数据都是脱敏的,也就是字都用数字来表示了。
统计一下text1+text2的长度:
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.font_manager import FontProperties
import pandas as pd
fromtrain_path = './baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_train_20210228.tsv'
test_path = './baseline_tfidf_lr/oppo_breeno_round1_data/gaiic_track3_round1_testA_20210228.tsv'def cal_len_dis(path):with open(path,'r',encoding='utf-8') as f:lines = f.readlines()len_list = []for line in lines:line = line.strip().split('\t')len_list.append(len(line[0]+line[1]))return len_listdef get_len_detail(data):df = pd.DataFrame(data)res = df.describe()return res# 设置matplotlib正常显示中文和负号
font = FontProperties(fname=r'/data02/gob/project/text-match/simhei.ttf')
matplotlib.rcParams['axes.unicode_minus']=False # 正常显示负号def draw_hist(data):plt.hist(data, bins=40, facecolor="blue", edgecolor="black", alpha=0.7)# 显示横轴标签plt.xlabel("长度",fontproperties=font)# 显示纵轴标签plt.ylabel("数量",fontproperties=font)# 显示图标题plt.title("句子长度统计",fontproperties=font)plt.savefig('len_hist.png')plt.show()if __name__ == '__main__':len_list = cal_len_dis(train_path)res = get_len_detail(len_list)print(res)draw_hist(len_list)
相关统计量:
count 100000.0000
mean 46.8328
std 17.317
min 12
25% 35
50% 43
75% 55
max 279
baseline中值得注意的一些代码
from bert4keras.snippets import truncate_sequences
truncate_sequences(maxlen, -1, a, b)
这个函数用于截断超过最大长度的句子,如果len(a+b)>maxlen,则对句子a进行截断。
def random_mask(text_ids):"""随机mask"""input_ids, output_ids = [], []rands = np.random.random(len(text_ids))for r, i in zip(rands, text_ids):if r < 0.15 * 0.8:input_ids.append(4)output_ids.append(i)elif r < 0.15 * 0.9:input_ids.append(i)output_ids.append(i)elif r < 0.15:input_ids.append(np.random.choice(len(tokens)) + 7)output_ids.append(i)else:input_ids.append(i)output_ids.append(0)return input_ids, output_ids
这个函数用于随机将一些词mask掉。
def sample_convert(text1, text2, label, random=False):"""转换为MLM格式"""text1_ids = [tokens.get(t, 1) for t in text1]text2_ids = [tokens.get(t, 1) for t in text2]if random:if np.random.random() < 0.5:text1_ids, text2_ids = text2_ids, text1_idstext1_ids, out1_ids = random_mask(text1_ids)text2_ids, out2_ids = random_mask(text2_ids)else:out1_ids = [0] * len(text1_ids)out2_ids = [0] * len(text2_ids)token_ids = [2] + text1_ids + [3] + text2_ids + [3]segment_ids = [0] * len(token_ids)output_ids = [label + 5] + out1_ids + [0] + out2_ids + [0]return token_ids, segment_ids, output_ids
用于转换单个样本为bert的输入。
class data_generator(DataGenerator):"""数据生成器"""def __iter__(self, random=False):batch_token_ids, batch_segment_ids, batch_output_ids = [], [], []for is_end, (text1, text2, label) in self.sample(random):token_ids, segment_ids, output_ids = sample_convert(text1, text2, label, random)batch_token_ids.append(token_ids)batch_segment_ids.append(segment_ids)batch_output_ids.append(output_ids)if len(batch_token_ids) == self.batch_size or is_end:batch_token_ids = sequence_padding(batch_token_ids)batch_segment_ids = sequence_padding(batch_segment_ids)batch_output_ids = sequence_padding(batch_output_ids)yield [batch_token_ids, batch_segment_ids], batch_output_idsbatch_token_ids, batch_segment_ids, batch_output_ids = [], [], []
用于将多个样本制作成batch的格式。
# 加载预训练模型
model = build_transformer_model(config_path=config_path,checkpoint_path=checkpoint_path,with_mlm=True,keep_tokens=[0, 100, 101, 102, 103, 100, 100] + keep_tokens[:len(tokens)]
)
这个函数用于加载预训练的模型
def masked_crossentropy(y_true, y_pred):"""mask掉非预测部分"""y_true = K.reshape(y_true, K.shape(y_true)[:2])y_mask = K.cast(K.greater(y_true, 0.5), K.floatx())loss = K.sparse_categorical_crossentropy(y_true, y_pred)loss = K.sum(loss * y_mask) / K.sum(y_mask)return loss[None, None]
计算损失。
model.compile(loss=masked_crossentropy, optimizer=Adam(1e-5))
model.summary()
定义优化器和损失。
# 转换数据集
train_generator = data_generator(train_data, batch_size)
valid_generator = data_generator(valid_data, batch_size)
test_generator = data_generator(test_data, batch_size)def evaluate(data):"""线下评测函数"""Y_true, Y_pred = [], []for x_true, y_true in data:y_pred = model.predict(x_true)[:, 0, 5:7]y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)y_true = y_true[:, 0] - 5Y_pred.extend(y_pred)Y_true.extend(y_true)return roc_auc_score(Y_true, Y_pred)class Evaluator(keras.callbacks.Callback):"""评估与保存"""def __init__(self):self.best_val_score = 0.def on_epoch_end(self, epoch, logs=None):val_score = evaluate(valid_generator)if val_score > self.best_val_score:self.best_val_score = val_scoremodel.save_weights('best_model.weights')print(u'val_score: %.5f, best_val_score: %.5f\n' %(val_score, self.best_val_score))def predict_to_file(out_file):"""预测结果到文件"""F = open(out_file, 'w')for x_true, _ in tqdm(test_generator):y_pred = model.predict(x_true)[:, 0, 5:7]y_pred = y_pred[:, 1] / (y_pred.sum(axis=1) + 1e-8)for p in y_pred:F.write('%f\n' % p)F.close()if __name__ == '__main__':evaluator = Evaluator()model.fit(train_generator.forfit(),steps_per_epoch=len(train_generator),epochs=100,callbacks=[evaluator])else:model.load_weights('best_model.weights')
加载数据以及评估等,最后在主函数中调用。
天池oppo-text-match比赛-苏剑林baseline代码解读相关推荐
- 相似度衡量: 苏剑林博客-1
测地线距离(Geodesic Distance):地表上两点之间的最短路径的距离. 如下图所示,在二维空间中,两个黑点之间的欧式距离是虚线的长度,测地线距离时d12+d23+d34+d45的距离之和. ...
- 学会提问的BERT:端到端地从篇章中构建问答对 By 苏剑林
机器阅读理解任务,相比不少读者都有所了解了,简单来说就是从给定篇章中寻找给定问题的答案,即"篇章 + 问题 → 答案"这样的流程,笔者之前也写过一些关于阅读理解的文章,比如< ...
- 苏建林DGCNN模型代码详解
1.说明 以下代码为苏神参加百度信息抽取比赛提出的DGCNN模型代码,其源码基本是没有注释的,所以本文对大部分代码做了注释,可能有不对的地方,可以指出.另一个就是对keras3.x版本下Attenti ...
- 【苏小林主页】基于TP6和光年模板的个人主页
基于光年V4和thinkphp的个人主页 开源仓库地址:GitHub - suxaiolin/personal: 个人主页(带后台) - su personal 直链下载地址:苏小林个人主页带后台管理 ...
- 胡适曾劝女作家苏雪林不要骂鲁迅
2006.8.26 16:56 "五四"运动后期,随着<新青年>杂志的分裂,胡适与鲁迅日渐分道扬镳,走进了不同营垒.胡适的"多研究问题,少谈些主义" ...
- 计算机教育格言,苏霍姆林斯基教育名言大全
苏霍姆林斯基教育名言大全 1.要想自己成为幸福的人,就应当对别人关怀备至,体贴入微,赤诚相见. 2.友谊是培养人的感情的学校. 3.友谊是培养人的感情的学校.我们所以需要友谊,并不是想用它打发时间,而 ...
- 百行代码解读阿里 AloT 芯片平台无剑 100!
作者 | 马超 责编 | 胡巍巍 出品 | CSDN(ID:CSDNnews) 今年以来我国IT厂商都在AIot的底层平台建设方面可谓是捷报频传,在操作系统方面有如像腾讯的Tiny OS.阿里的Al ...
- 百度图神经网络——论文节点比赛baseline代码注解
文章目录 一.项目介绍 二.BaseLine内容注解 1.Config部分注解 2. 数据读取与处理部分 2.1 边数据的加载与处理 2.2 数据的完整加载与处理 2.3 数据读取与分割 3. 模型加 ...
- 风林评《解读基金,我的投资观与实践》|你怎么能不知道基金投资的三个思想...
风林评<解读基金,我的投资观与实践>|你怎么能不知道基金投资的三个思想 2019股市的红红火火必然影响着基金的净值的变化.如果对股市的风险承受能力不足的话,可以到股票的二级市场基金 ...
最新文章
- unity水管工_我是如何从30岁的管道工转变为32岁的Web开发人员的
- 刚入职,就被各种 Code Review,真的有必要吗?
- Android线程模型和AsyncTask
- 超图桌面版制作分段专题图学习
- 【小朋友才做选择题】跟着团队一起学习人工智能,先人一步掌握最前沿知识
- 汉字转拼音(c#) -转载
- 相同update语句在MySQL,Oracle的不同表现(r12笔记第30天)
- Spring MVC自定义验证注释
- 我以为我对Mysql索引很了解,直到我遇到了阿里的面试官
- [lct] Luogu P4219 大融合
- 斯特林数-斯特林反演
- Visual Studio 2015打开ASP.NET MVC的View提示“Object reference not set to an instance of an object“错误的解决方案
- 《Java安全编码标准》一2.9 IDS08-J净化传递给正则表达式的非受信数据
- 增强网络安全意识——如何5分钟破解校园网上网账号和密码
- 将Android手机无线连接到Ubuntu实现唱跳Rap
- html directive 内容传递,directive的传值问题(全面解析directive的传值问题)微信分享实例...
- javascript常见的设计模式
- [转]倾斜摄影单体化实现方案
- 今天是植树节,你“植树”了吗?
- 真的不能错过的打印攻略!打印一张7分钱
热门文章
- 自动批量裁剪+合成+整理正反向序列
- “大数据”这个名字叫错了,今天数据的意义并不在于有多“大”,真正有意思的是数据变得在线了。...
- uniapp实现上拉刷新,下拉加载
- kaggle心脏病监测分析案例(数据分析+数据可视化)适合入门新手
- 计算机英语在线学习,英语单词记忆法超强记忆_免费背单词软件电脑版
- 前端diff文件对比使用worker进行优化
- STM32-ADC模拟数字转换器
- mysql offset 问题_MySQL_优化mysql的limit offset的例子, 经常碰到的一个问题是limi - phpStudy...
- python发牌游戏图形界面_python实现扑克牌交互式界面发牌程序
- LeetCodeClassification---- No.1 分治--使用递归完成对一维数组的求和--递归法完成sum函数