接着上一篇(https://blog.csdn.net/jasmine0244/article/details/108888236)

设置好参数:

max_q_len = 80
max_a_len = 80

接下来就是构建模型部分了。

首先加载词库和分词器,

# 加载并精简词表,建立分词器
token_dict, keep_tokens = load_vocab(dict_path=dict_path,simplified=True,startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)

然后是数据生成器

class data_generator(DataGenerator):"""数据生成器"""def __iter__(self, random=False):"""单条样本格式:[CLS]篇章[SEP]答案[SEP]问题[SEP]"""batch_token_ids, batch_segment_ids = [], []for is_end, (p, q, a) in self.sample(random):p_token_ids, _ = tokenizer.encode(p, maxlen=max_p_len + 1)a_token_ids, _ = tokenizer.encode(a, maxlen=max_a_len)q_token_ids, _ = tokenizer.encode(q, maxlen=max_q_len)token_ids = p_token_ids + a_token_ids[1:] + q_token_ids[1:]segment_ids = [0] * len(p_token_ids)segment_ids += [1] * (len(token_ids) - len(p_token_ids))batch_token_ids.append(token_ids)batch_segment_ids.append(segment_ids)if len(batch_token_ids) == self.batch_size or is_end:batch_token_ids = sequence_padding(batch_token_ids)batch_segment_ids = sequence_padding(batch_segment_ids)yield [batch_token_ids, batch_segment_ids], Nonebatch_token_ids, batch_segment_ids = [], []

再然后是loss部分

class CrossEntropy(Loss):"""交叉熵作为loss,并mask掉输入部分"""def compute_loss(self, inputs, mask=None):y_true, y_mask, y_pred = inputsy_true = y_true[:, 1:]  # 目标token_idsy_mask = y_mask[:, 1:]  # segment_ids,刚好指示了要预测的部分y_pred = y_pred[:, :-1]  # 预测序列,错开一位loss = K.sparse_categorical_crossentropy(y_true, y_pred)loss = K.sum(loss * y_mask) / K.sum(y_mask)

然后构建模型

def build_model():model = build_transformer_model(config_path,checkpoint_path,application='unilm',keep_tokens=keep_tokens,  # 只保留keep_tokens中的字,精简原字表)output = CrossEntropy(2)(model.inputs + model.outputs)model = Model(model.inputs, output)model.compile(optimizer=Adam(1e-5))return model

生成答案选用的是beam search

class QuestionAnswerGeneration(AutoRegressiveDecoder):"""通过beam search来生成问题"""@AutoRegressiveDecoder.wraps(default_rtype='probas')def predict(self, inputs, output_ids, states):token_ids, segment_ids = inputstoken_ids = np.concatenate([token_ids, output_ids], 1)segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)return model.predict([token_ids, segment_ids])[:, -1]def generate(self, passage, ans,topk=5):token_ids, segment_ids = tokenizer.encode(passage, maxlen=max_p_len)a_ids = tokenizer.encode(ans)[0]token_ids += list(a_ids)segment_ids += [1] * len(a_ids)q_ids = self.beam_search([token_ids, segment_ids],topk)  # 基于beam searchreturn tokenizer.decode(q_ids)

callback回调函数

class Evaluator(keras.callbacks.Callback):def __init__(self,mode):self.lowest = 1e10self.mode = modedef on_epoch_end(self, epoch, logs=None):# 保存最优if logs['loss'] <= self.lowest:self.lowest = logs['loss']

将预测结果输出到文本,用于评估

def predict_to_file(data, filename, qag,topk=1):"""将预测结果输出到文件,方便评估"""with open(filename, 'w', encoding='utf-8') as f:for d in tqdm(iter(data), desc=u'正在预测(共%s条样本)' % len(data)):q = qag.generate(d[0], d[2])s = '%s\t%s\t%s\n' % (q, d[2], d[0])f.write(s)f.flush()

最后就是K折训练

# 交叉训练
for mode in range(k_folds): train_data = [data[j] for i, j in enumerate(random_order) if i % k_folds != mode]valid_data = [data[j] for i, j in enumerate(random_order) if i % k_folds == mode]evaluator = Evaluator(mode)train_generator = data_generator(train_data, batch_size)model = build_model()model.fit_generator(train_generator.forfit(),steps_per_epoch=1000,epochs=epochs,callbacks=[evaluator])qag = QuestionAnswerGeneration(start_id=None, end_id=tokenizer._token_end_id, maxlen=max_q_len)predict_to_file(valid_data, 'qa_{}.csv'.format(mode), qag)del model

交叉训练完毕后,再来评估结果,选用最后的模型来预测test,这个等到下一篇继续。

中医药天池大数据竞赛——中医文献问题生成挑战(二)相关推荐

  1. AI比赛-NER:“万创杯”中医药天池大数据竞赛——中药说明书实体识别挑战

    大赛概况 疫情催化下,人工智能正在持续助力中医药传承创新加速发展,其中中医用药知识体系沉淀挖掘是一个基础工作.通过挖掘中药说明书构建中药合理用药的知识图谱,将为中医规范诊疗奠定较好基础.挑战旨在通过抽 ...

  2. 阿里巴巴天池大数据竞赛黄金联赛全面开战,全球同步报名,只为寻找最聪明的你!...

    阿里巴巴天池大数据竞赛黄金联赛全面开战,全球同步报名,只为寻找最聪明的你!          天池大数据竞赛是由阿里巴巴集团主办,面向全球新生代力量的高端算法竞赛.通过开放海量数据和"天池& ...

  3. 【TIANCHI】天池大数据竞赛(学习赛)--- 淘宝用户购物行为数据可视化分析

    目录 前言 一.数据集的来源和各个字段的意义 二.数据分析 1.引入库 2.读入数据 3.查看数据数量级 4.PV(Page View)/UV访问量 5.漏斗模型 6.用户购买商品的频次分析. 7.A ...

  4. 天池大数据竞赛第一名,上海交通大学人工智能实验室如何用AI定位肺结节

    癌症,犹如黑暗中的魔鬼,带给人们恐惧与绝望.而肺癌,在我国作为发病率.死亡率最高的一类癌症,伤害着无数家庭.在我国每年都有近60万人死于肺癌.然而,癌症的死亡率与首次发现癌症的时期紧密相关,早期肺结节 ...

  5. 阿里天池大数据竞赛——口碑商家客流量预测 A2

    阿里天池大赛koubeiyuce1 2017年二月份,天池大数据比赛,口碑商家客流量预测,参赛地址及详情: https://tianchi.shuju.aliyun.com/competition/i ...

  6. 9个比赛7进top10,阿里天池大数据竞赛思路分享

    向AI转型的程序员都关注了这个号

  7. 阿里天池大数据竞赛第一名,如何用AI检测肺癌

    向AI转型的程序员都关注了这个号

  8. 阿里天池大数据竞赛(一)用ODPS提取特征

    //2015年4月30日 提取特征之所以难,是因为我们写出的sql语句往往只能提取一个特征. 而在ODPS上却不一样,一条语句可以提取很多特征. 如提取最近1.2.3.4天四种行为的数量 重点是dec ...

  9. 【数据竞赛】2020年11月国内大数据竞赛信息-奖池5000万

    2020年11月:下面是截止到2020年11月国内还在进行中的大数据比赛题目,非常丰富,大家选择性参加,初学者可以作为入门练手,大佬收割奖金,平时项目不多的,比赛是知识非常好的实践方式,本号会定期发布 ...

  10. ​阿里云天池工业AI大赛暨中国“印象盐城,数创未来”大数据竞赛正式启动

    记者 | 张俊潇 官网 | www.datayuan.cn 微信公众号ID | datayuancn 10月11日,"2017杭州·云栖大会"在万众期待中盛大召开,会上马云宣布组建 ...

最新文章

  1. 迟语寒:组队学习的那些事
  2. 全网最火的Nacos源码构建,你找不到第二个有我仔细的!!
  3. 机器学习预测地震的未来
  4. 单个字段去重并保留其他字段值
  5. 面试官:为什么 wait() 方法需要写在循环里?
  6. datatable 创建列赋值_Datatable 添加新列并赋值
  7. 一个单片机的小问题。
  8. flyme8会更新Android版本吗,魅族17系列升级Flyme 8.1操作系统:终于到Android 10
  9. 3. Web Dynpro for ABAP: Web Dynpro Window Web Dynpro Program
  10. 【codevs1052】地鼠游戏
  11. Java词汇表(三)L——O
  12. 机器学习—线性回归推广及案例
  13. MySql常用函数大全
  14. python打造最强表白程序,Python 打造七夕最强表白程序
  15. java 埋点_数据采集之js埋点
  16. kafka数据同步Elasticsearch深入详解
  17. 强烈推荐这款刷题小程序
  18. postgreSQL数据类型字符串和数值相互转换
  19. OpenGL第十讲——像素图
  20. 家庭收支记账软件--Java

热门文章

  1. powerdesign如何导出数据库到mysql数据库
  2. html背景半透明 字不变,css实现背景半透明文字不透明的效果示例
  3. 使用讯飞tts实现安卓语音中文合成播报
  4. [SDOI2016]征途
  5. 备战2022软考网络管理员(1)介绍与开篇
  6. 网络带宽压力测试教程
  7. SQL Server 搭建Northwind详细教程
  8. et200sp系统服务器模块,西门子ET200
  9. 计算机专业答辩 ppt模板 免费,计算机毕业论文答辩(完整版).ppt
  10. sax解析xml详解