BERT句向量GPU线上调用等。出现Floating point exception and SystemError: error return without exception set 。

最近上线需要用到bert,走过了很多坑,有的甚至是不知道怎么回事,而且也很容易从解决一个问题,跳到另外一个问题,巨坑呀有木有。https://github.com/hanxiao/bert-as-service这种做成服务的,其实还是挺好的,但对做成服务的,完全无感呀。

又比如这种https://github.com/terrifyzhao/bert-utils,生成的句向量和相似度计算可调用的。但是不知道是不是yield、队列queue或者gpu、cuda、cudnn的问题,Linux的GPU上有时候会报: Floating point exception.   win10和linux上debug会报: SystemError: error return without exception set 。不太敢用呀。

一.方案Keras+修改(项目地址在https://github.com/yongzhuo/nlp_xiaojiang/tree/master/FeatureProject/bert):

左思右想,只能默默地上线我一直不太爱用地keras版本了。谁让google的tensorflow也这么做呢,趋势也去迎合迎合吧。keras版本的bert和gpt-2,https://github.com/CyberZHG/keras-bert这个项目其实还很不错啦。

不说废话,直接上代码:

二、代码:

其实这种直接调用google训练好模型的,不微调的,简单的cpu也可以调用,还不费多少内存,就是速度慢些。

2.1   首先是模型,google预训练好的模型你得下载吧,可以去官方地址下,也可以来我这里前往链接: https://pan.baidu.com/s/1I3vydhmFEQ9nuPG2fDou8Q 提取码: rket

2.2    然后是主要的代码,extract_keras_bert_feature.py

# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time     :2019/5/8 20:04
# @author   :Mo
# @function :extract feature of bert and kerasimport codecs
import osimport keras.backend.tensorflow_backend as ktf_keras
import numpy as np
import tensorflow as tf
from keras.layers import Add
from keras.models import Model
from keras_bert import load_trained_model_from_checkpoint, Tokenizerfrom FeatureProject.bert.layers_keras import NonMaskingLayer
from conf.feature_config import gpu_memory_fraction, config_name, ckpt_name, vocab_file, max_seq_len, layer_indexes# 全局使用,使其可以django、flask、tornado等调用
graph = None
model = None# gpu配置与使用率设置
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
sess = tf.Session(config=config)
ktf_keras.set_session(sess)class KerasBertVector():def __init__(self):self.config_path, self.checkpoint_path, self.dict_path, self.max_seq_len = config_name, ckpt_name, vocab_file, max_seq_len# 全局使用,使其可以django、flask、tornado等调用global graphgraph = tf.get_default_graph()global modelmodel = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path,seq_len=self.max_seq_len)print(model.output)print(len(model.layers))# lay = model.layers#一共104个layer,其中前八层包括token,pos,embed等,# 每4层(MultiHeadAttention,Dropout,Add,LayerNormalization)# 一共24层layer_dict = [7]layer_0 = 7for i in range(12):layer_0 = layer_0 + 4layer_dict.append(layer_0)# 输出它本身if len(layer_indexes) == 0:encoder_layer = model.output# 分类如果只有一层,就只取最后那一层的weight,取得不正确elif len(layer_indexes) == 1:if layer_indexes[0] in [i+1 for i in range(12)]:encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).outputelse:encoder_layer = model.get_layer(index=layer_dict[-2]).output# 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数else:# layer_indexes must be [1,2,3,......12...24]# all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(12)]else model.get_layer(index=layer_dict[-1]).output  #如果给出不正确,就默认输出最后一层for lay in layer_indexes]print(layer_indexes)print(all_layers)# 其中layer==1的output是格式不对,第二层输入input是listall_layers_select = []for all_layers_one in all_layers:all_layers_select.append(all_layers_one)encoder_layer = Add()(all_layers_select)print(encoder_layer.shape)print("KerasBertEmbedding:")print(encoder_layer.shape)output_layer = NonMaskingLayer()(encoder_layer)model = Model(model.inputs, output_layer)# model.summary(120)# reader tokenizerself.token_dict = {}with codecs.open(self.dict_path, 'r', 'utf8') as reader:for line in reader:token = line.strip()self.token_dict[token] = len(self.token_dict)self.tokenizer = Tokenizer(self.token_dict)def bert_encode(self, texts):# 文本预处理input_ids = []input_masks = []input_type_ids = []for text in texts:print(text)tokens_text = self.tokenizer.tokenize(text)print('Tokens:', tokens_text)input_id, input_type_id = self.tokenizer.encode(first=text, max_len=self.max_seq_len)input_mask = [0 if ids == 0 else 1 for ids in input_id]input_ids.append(input_id)input_type_ids.append(input_type_id)input_masks.append(input_mask)input_ids = np.array(input_ids)input_masks = np.array(input_masks)input_type_ids = np.array(input_type_ids)# 全局使用,使其可以django、flask、tornado等调用with graph.as_default():predicts = model.predict([input_ids, input_type_ids], batch_size=1)print(predicts.shape)for i, token in enumerate(tokens_text):print(token, [len(predicts[0][i].tolist())], predicts[0][i].tolist())# 相当于pool,采用的是https://github.com/terrifyzhao/bert-utils/blob/master/graph.pymul_mask = lambda x, m: x * np.expand_dims(m, axis=-1)masked_reduce_mean = lambda x, m: np.sum(mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + 1e-9)pools = []for i in range(len(predicts)):pred = predicts[i]masks = input_masks.tolist()mask_np = np.array([masks[i]])pooled = masked_reduce_mean(pred, mask_np)pooled = pooled.tolist()pools.append(pooled[0])print('bert:', pools)return poolsif __name__ == "__main__":bert_vector = KerasBertVector()pooled = bert_vector.bert_encode(['你是谁呀', '小老弟'])print(pooled)while True:print("input:")ques = input()print(bert_vector.bert_encode([ques]))

        2.3    再就是layers_keras.py

            

# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time     :2019/5/10 10:49
# @author   :Mo
# @function :create model of keras-bert for get [-2] layersfrom keras.engine import Layerclass NonMaskingLayer(Layer):"""fix convolutional 1D can't receive masked input, detail: https://github.com/keras-team/keras/issues/4978thanks for https://github.com/jacoxu"""def __init__(self, **kwargs):self.supports_masking = Truesuper(NonMaskingLayer, self).__init__(**kwargs)def build(self, input_shape):passdef compute_mask(self, input, input_mask=None):# do not pass the mask to the next layersreturn Nonedef call(self, x, mask=None):return xdef compute_output_shape(self, input_shape):return input_shape

2.4    最后是配置文件

# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time     :2019/5/10 9:13
# @author   :Mo
# @function :path of FeatureProjectimport os# path of BERT model
file_path = os.path.dirname(__file__)
file_path = file_path.replace('conf', '') + 'Data'
model_dir = os.path.join(file_path, 'chinese_L-12_H-768_A-12/')
config_name = os.path.join(model_dir, 'bert_config.json')
ckpt_name = os.path.join(model_dir, 'bert_model.ckpt')
vocab_file = os.path.join(model_dir, 'vocab.txt')
# gpu使用率
gpu_memory_fraction = 0.2
# 默认取倒数第二层的输出值作为句向量
layer_indexes = [-2]
# 序列的最大程度,单文本建议把该值调小
max_seq_len = 26

希望对你有所帮助!

       

bert中文短文本句向量生成、相似度计算(GPU版、windows、win10、linux、django和flask可用)相关推荐

  1. snownlp 中文文本情感分析、相似度计算、分词等

    snownlp 官网:https://pypi.org/project/snownlp/ SnowNLP是一个python写的类库,可以方便的处理中文文本内容,是受到了[TextBlob](https ...

  2. gensim词向量Word2Vec安装及《庆余年》中文短文本相似度计算 | CSDN博文精选

    作者 | Eastmount 来源 | CSDN博文精选 (*点击阅读原文,查看作者更多精彩文章) 本篇文章将分享gensim词向量Word2Vec安装.基础用法,并实现<庆余年>中文短文 ...

  3. 在线部分:werobot服务、主要逻辑服务、句子相关模型服务、BERT中文预训练模型+微调模型(目的:比较两句话text1和text2之间是否有关联)、模型在Flask部署

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 智能对话系统:Unit对话API 在线聊天的总体架构与工具介绍 ...

  4. [Python人工智能] 九.gensim词向量Word2Vec安装及《庆余年》中文短文本相似度计算

    从本专栏开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇详细讲解了卷积神经网络CNN原理,并通过TensorFlow编写CNN实现了MNIST分类学习案例.本篇文章将分享 ...

  5. 使用BERT做中文文本相似度计算与文本分类

    转载请注明出处,原文地址: https://terrifyzhao.github.io/2018/11/29/使用BERT做中文文本相似度计算.html 简介 最近Google推出了NLP大杀器BER ...

  6. ccks2020中文短文本实体链接任务测评论文--小米团队--第一名

    测评论文名:面向中文短文本的多因子融合实体链指研究 官网文档链接:http://sigkg.cn/ccks2020/?page_id=700 本笔记主要将测评论文中的主要内容提炼,方便后续借鉴,读者可 ...

  7. 【论文翻译】2020.8 清华大学AI课题组——大型中文短文本对话数据集(A Large-Scale Chinese Short-Text Conversation Dataset)

    大型中文短文本对话数据集 写在前面: 研究用,原创翻译,转载请标明出处:第一次译文,之后会跟进完善.侵删.   今年暑假末,清华大学公开了大型对话数据集及预训练模型.该数据集融合各大社交媒体对话数据库 ...

  8. 新闻上的文本分类:机器学习大乱斗 王岳王院长 王岳王院长 5 个月前 目标 从头开始实践中文短文本分类,记录一下实验流程与遇到的坑 运用多种机器学习(深度学习 + 传统机器学习)方法比较短文本分类处

    新闻上的文本分类:机器学习大乱斗 王岳王院长 5 个月前 目标 从头开始实践中文短文本分类,记录一下实验流程与遇到的坑 运用多种机器学习(深度学习 + 传统机器学习)方法比较短文本分类处理过程与结果差 ...

  9. LSF-SCNN:一种基于 CNN 的短文本表达模型及相似度计算的全新优化模型

    欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 本篇文章是我在读期间,对自然语言处理中的文本相似度问题研究取得的一点小成果.如果你对自然语言处理 (natural language proc ...

最新文章

  1. office不能安装问题
  2. JFreeChart 1.0.6 用户开发指南(中文)
  3. springboot socket服务端_从零开始学SpringBoot之Spring Boot WebSocket:编码分析
  4. Oracle 11g RAC features
  5. java 多模块项目 包路径冲突_多智能体仿真建模在交通中的应用|MATSim入门指南...
  6. web框架和后台开发_Web开发框架–第1部分:选项和标准
  7. linux下web压力测试工具ab使用及详解
  8. mysql stdistance_postgis的geography_columns和geometry_columns有什么区别
  9. 作者:郑理,男,南京邮电大学计算机学院项目研究员。
  10. mysql进阶-02-事务的引入与基本的使用
  11. java中的装箱及拆箱
  12. 6实验心得_看县委书记如何写“水平高”“亮点足”的考察心得体会!
  13. 便携式CAN分析仪与毫米波雷达搭配使用
  14. 文件上传和下载的常用测试点
  15. 负反馈放大电路实验报告
  16. Excel表格导入CAD后,表格内数字后的小数点怎么消除呢?
  17. 操作系统中多生产者多消费者问题中,关于生产者或消费者中的两个P操作是否可以互换问题
  18. python的数据拼接和融合
  19. java循环嵌套语句示范_java的三种循环结构与循环嵌套
  20. mysql SQL命令查看Mysql数据库磁盘使用量

热门文章

  1. fluent二维叶型仿真_即将直播:虎门大桥异常抖动原因查明!流固耦合仿真与工程应用直播(5月21日)...
  2. linux恢复硬盘数据
  3. NLP_1:语法树和N_gram模型
  4. poi-tl实现自定义RenderPolicy实现对表格进行列表
  5. Marathon主要功能介绍(一)
  6. df -h执行卡住不动问题解决
  7. 使用Selenium爬取豆瓣电影前100的爱情片相关信息
  8. java web pdf 打印预览_java原装代码完成pdf在线预览和pdf打印及下载
  9. 如何拷贝VCD里面加密隐藏的文件
  10. 用 Python 验证股神巴菲特的投资经验