bert中文短文本句向量生成、相似度计算(GPU版、windows、win10、linux、django和flask可用)
BERT句向量GPU线上调用等。出现Floating point exception and SystemError: error return without exception set 。
最近上线需要用到bert,走过了很多坑,有的甚至是不知道怎么回事,而且也很容易从解决一个问题,跳到另外一个问题,巨坑呀有木有。https://github.com/hanxiao/bert-as-service这种做成服务的,其实还是挺好的,但对做成服务的,完全无感呀。
又比如这种https://github.com/terrifyzhao/bert-utils,生成的句向量和相似度计算可调用的。但是不知道是不是yield、队列queue或者gpu、cuda、cudnn的问题,Linux的GPU上有时候会报: Floating point exception. win10和linux上debug会报: SystemError: error return without exception set 。不太敢用呀。
一.方案Keras+修改(项目地址在https://github.com/yongzhuo/nlp_xiaojiang/tree/master/FeatureProject/bert):
左思右想,只能默默地上线我一直不太爱用地keras版本了。谁让google的tensorflow也这么做呢,趋势也去迎合迎合吧。keras版本的bert和gpt-2,https://github.com/CyberZHG/keras-bert这个项目其实还很不错啦。
不说废话,直接上代码:
二、代码:
其实这种直接调用google训练好模型的,不微调的,简单的cpu也可以调用,还不费多少内存,就是速度慢些。
2.1 首先是模型,google预训练好的模型你得下载吧,可以去官方地址下,也可以来我这里前往链接: https://pan.baidu.com/s/1I3vydhmFEQ9nuPG2fDou8Q 提取码: rket
2.2 然后是主要的代码,extract_keras_bert_feature.py
# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time :2019/5/8 20:04
# @author :Mo
# @function :extract feature of bert and kerasimport codecs
import osimport keras.backend.tensorflow_backend as ktf_keras
import numpy as np
import tensorflow as tf
from keras.layers import Add
from keras.models import Model
from keras_bert import load_trained_model_from_checkpoint, Tokenizerfrom FeatureProject.bert.layers_keras import NonMaskingLayer
from conf.feature_config import gpu_memory_fraction, config_name, ckpt_name, vocab_file, max_seq_len, layer_indexes# 全局使用,使其可以django、flask、tornado等调用
graph = None
model = None# gpu配置与使用率设置
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
sess = tf.Session(config=config)
ktf_keras.set_session(sess)class KerasBertVector():def __init__(self):self.config_path, self.checkpoint_path, self.dict_path, self.max_seq_len = config_name, ckpt_name, vocab_file, max_seq_len# 全局使用,使其可以django、flask、tornado等调用global graphgraph = tf.get_default_graph()global modelmodel = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path,seq_len=self.max_seq_len)print(model.output)print(len(model.layers))# lay = model.layers#一共104个layer,其中前八层包括token,pos,embed等,# 每4层(MultiHeadAttention,Dropout,Add,LayerNormalization)# 一共24层layer_dict = [7]layer_0 = 7for i in range(12):layer_0 = layer_0 + 4layer_dict.append(layer_0)# 输出它本身if len(layer_indexes) == 0:encoder_layer = model.output# 分类如果只有一层,就只取最后那一层的weight,取得不正确elif len(layer_indexes) == 1:if layer_indexes[0] in [i+1 for i in range(12)]:encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).outputelse:encoder_layer = model.get_layer(index=layer_dict[-2]).output# 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数else:# layer_indexes must be [1,2,3,......12...24]# all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(12)]else model.get_layer(index=layer_dict[-1]).output #如果给出不正确,就默认输出最后一层for lay in layer_indexes]print(layer_indexes)print(all_layers)# 其中layer==1的output是格式不对,第二层输入input是listall_layers_select = []for all_layers_one in all_layers:all_layers_select.append(all_layers_one)encoder_layer = Add()(all_layers_select)print(encoder_layer.shape)print("KerasBertEmbedding:")print(encoder_layer.shape)output_layer = NonMaskingLayer()(encoder_layer)model = Model(model.inputs, output_layer)# model.summary(120)# reader tokenizerself.token_dict = {}with codecs.open(self.dict_path, 'r', 'utf8') as reader:for line in reader:token = line.strip()self.token_dict[token] = len(self.token_dict)self.tokenizer = Tokenizer(self.token_dict)def bert_encode(self, texts):# 文本预处理input_ids = []input_masks = []input_type_ids = []for text in texts:print(text)tokens_text = self.tokenizer.tokenize(text)print('Tokens:', tokens_text)input_id, input_type_id = self.tokenizer.encode(first=text, max_len=self.max_seq_len)input_mask = [0 if ids == 0 else 1 for ids in input_id]input_ids.append(input_id)input_type_ids.append(input_type_id)input_masks.append(input_mask)input_ids = np.array(input_ids)input_masks = np.array(input_masks)input_type_ids = np.array(input_type_ids)# 全局使用,使其可以django、flask、tornado等调用with graph.as_default():predicts = model.predict([input_ids, input_type_ids], batch_size=1)print(predicts.shape)for i, token in enumerate(tokens_text):print(token, [len(predicts[0][i].tolist())], predicts[0][i].tolist())# 相当于pool,采用的是https://github.com/terrifyzhao/bert-utils/blob/master/graph.pymul_mask = lambda x, m: x * np.expand_dims(m, axis=-1)masked_reduce_mean = lambda x, m: np.sum(mul_mask(x, m), axis=1) / (np.sum(m, axis=1, keepdims=True) + 1e-9)pools = []for i in range(len(predicts)):pred = predicts[i]masks = input_masks.tolist()mask_np = np.array([masks[i]])pooled = masked_reduce_mean(pred, mask_np)pooled = pooled.tolist()pools.append(pooled[0])print('bert:', pools)return poolsif __name__ == "__main__":bert_vector = KerasBertVector()pooled = bert_vector.bert_encode(['你是谁呀', '小老弟'])print(pooled)while True:print("input:")ques = input()print(bert_vector.bert_encode([ques]))
2.3 再就是layers_keras.py
# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time :2019/5/10 10:49
# @author :Mo
# @function :create model of keras-bert for get [-2] layersfrom keras.engine import Layerclass NonMaskingLayer(Layer):"""fix convolutional 1D can't receive masked input, detail: https://github.com/keras-team/keras/issues/4978thanks for https://github.com/jacoxu"""def __init__(self, **kwargs):self.supports_masking = Truesuper(NonMaskingLayer, self).__init__(**kwargs)def build(self, input_shape):passdef compute_mask(self, input, input_mask=None):# do not pass the mask to the next layersreturn Nonedef call(self, x, mask=None):return xdef compute_output_shape(self, input_shape):return input_shape
2.4 最后是配置文件
# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time :2019/5/10 9:13
# @author :Mo
# @function :path of FeatureProjectimport os# path of BERT model
file_path = os.path.dirname(__file__)
file_path = file_path.replace('conf', '') + 'Data'
model_dir = os.path.join(file_path, 'chinese_L-12_H-768_A-12/')
config_name = os.path.join(model_dir, 'bert_config.json')
ckpt_name = os.path.join(model_dir, 'bert_model.ckpt')
vocab_file = os.path.join(model_dir, 'vocab.txt')
# gpu使用率
gpu_memory_fraction = 0.2
# 默认取倒数第二层的输出值作为句向量
layer_indexes = [-2]
# 序列的最大程度,单文本建议把该值调小
max_seq_len = 26
希望对你有所帮助!
bert中文短文本句向量生成、相似度计算(GPU版、windows、win10、linux、django和flask可用)相关推荐
- snownlp 中文文本情感分析、相似度计算、分词等
snownlp 官网:https://pypi.org/project/snownlp/ SnowNLP是一个python写的类库,可以方便的处理中文文本内容,是受到了[TextBlob](https ...
- gensim词向量Word2Vec安装及《庆余年》中文短文本相似度计算 | CSDN博文精选
作者 | Eastmount 来源 | CSDN博文精选 (*点击阅读原文,查看作者更多精彩文章) 本篇文章将分享gensim词向量Word2Vec安装.基础用法,并实现<庆余年>中文短文 ...
- 在线部分:werobot服务、主要逻辑服务、句子相关模型服务、BERT中文预训练模型+微调模型(目的:比较两句话text1和text2之间是否有关联)、模型在Flask部署
日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 智能对话系统:Unit对话API 在线聊天的总体架构与工具介绍 ...
- [Python人工智能] 九.gensim词向量Word2Vec安装及《庆余年》中文短文本相似度计算
从本专栏开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇详细讲解了卷积神经网络CNN原理,并通过TensorFlow编写CNN实现了MNIST分类学习案例.本篇文章将分享 ...
- 使用BERT做中文文本相似度计算与文本分类
转载请注明出处,原文地址: https://terrifyzhao.github.io/2018/11/29/使用BERT做中文文本相似度计算.html 简介 最近Google推出了NLP大杀器BER ...
- ccks2020中文短文本实体链接任务测评论文--小米团队--第一名
测评论文名:面向中文短文本的多因子融合实体链指研究 官网文档链接:http://sigkg.cn/ccks2020/?page_id=700 本笔记主要将测评论文中的主要内容提炼,方便后续借鉴,读者可 ...
- 【论文翻译】2020.8 清华大学AI课题组——大型中文短文本对话数据集(A Large-Scale Chinese Short-Text Conversation Dataset)
大型中文短文本对话数据集 写在前面: 研究用,原创翻译,转载请标明出处:第一次译文,之后会跟进完善.侵删. 今年暑假末,清华大学公开了大型对话数据集及预训练模型.该数据集融合各大社交媒体对话数据库 ...
- 新闻上的文本分类:机器学习大乱斗 王岳王院长 王岳王院长 5 个月前 目标 从头开始实践中文短文本分类,记录一下实验流程与遇到的坑 运用多种机器学习(深度学习 + 传统机器学习)方法比较短文本分类处
新闻上的文本分类:机器学习大乱斗 王岳王院长 5 个月前 目标 从头开始实践中文短文本分类,记录一下实验流程与遇到的坑 运用多种机器学习(深度学习 + 传统机器学习)方法比较短文本分类处理过程与结果差 ...
- LSF-SCNN:一种基于 CNN 的短文本表达模型及相似度计算的全新优化模型
欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 本篇文章是我在读期间,对自然语言处理中的文本相似度问题研究取得的一点小成果.如果你对自然语言处理 (natural language proc ...
最新文章
- office不能安装问题
- JFreeChart 1.0.6 用户开发指南(中文)
- springboot socket服务端_从零开始学SpringBoot之Spring Boot WebSocket:编码分析
- Oracle 11g RAC features
- java 多模块项目 包路径冲突_多智能体仿真建模在交通中的应用|MATSim入门指南...
- web框架和后台开发_Web开发框架–第1部分:选项和标准
- linux下web压力测试工具ab使用及详解
- mysql stdistance_postgis的geography_columns和geometry_columns有什么区别
- 作者:郑理,男,南京邮电大学计算机学院项目研究员。
- mysql进阶-02-事务的引入与基本的使用
- java中的装箱及拆箱
- 6实验心得_看县委书记如何写“水平高”“亮点足”的考察心得体会!
- 便携式CAN分析仪与毫米波雷达搭配使用
- 文件上传和下载的常用测试点
- 负反馈放大电路实验报告
- Excel表格导入CAD后,表格内数字后的小数点怎么消除呢?
- 操作系统中多生产者多消费者问题中,关于生产者或消费者中的两个P操作是否可以互换问题
- python的数据拼接和融合
- java循环嵌套语句示范_java的三种循环结构与循环嵌套
- mysql SQL命令查看Mysql数据库磁盘使用量
热门文章
- fluent二维叶型仿真_即将直播:虎门大桥异常抖动原因查明!流固耦合仿真与工程应用直播(5月21日)...
- linux恢复硬盘数据
- NLP_1:语法树和N_gram模型
- poi-tl实现自定义RenderPolicy实现对表格进行列表
- Marathon主要功能介绍(一)
- df -h执行卡住不动问题解决
- 使用Selenium爬取豆瓣电影前100的爱情片相关信息
- java web pdf 打印预览_java原装代码完成pdf在线预览和pdf打印及下载
- 如何拷贝VCD里面加密隐藏的文件
- 用 Python 验证股神巴菲特的投资经验