【coding】Bert-Whitening细解
BERT-Whitening细解
@author: Heisenberg
@date: 2021-01-16
The code was shared from Jianlin Su on his blog
And This is a repo.
Data can be download from here
原jupyter notebook 格式可在github上查看。
测试任务:GLUE的STS-B句子相似性任务
测试环境:tf2.2.0+ keras2.3.1+ bert4keras 0.9.8
对向量进行线性变换(即数据挖掘中的白化操作)可达到BERT-Flow的效果
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.snippets import open, sequence_padding
from keras.models import Model
Using TensorFlow backend.
读取并加载数据集
def load_train_data(filename):"""加载训练数据(带标签)单条格式:(文本1, 文本2, 标签)"""D = []with open(filename, encoding='utf-8') as f:for i, l in enumerate(f):if i > 0:l = l.strip().split('\t')D.append((l[-2], l[-1], float(l[-3])))return Ddef load_test_data(filename):"""加载测试数据(带标签)单条格式:(文本1, 文本2, 标签)"""D = []with open(filename, encoding='utf-8') as f:for l in f:l = l.strip().split('\t')D.append((l[-2], l[-1], float(l[-3])))return D# 加载数据集
train_data = load_train_data('STS-B/original/sts-train.tsv')
test_data = load_test_data('STS-B/original/sts-dev.tsv')
#(文本1, 文本2, 相似性分数x(/5.0))
print(train_data[0])
print(test_data[0])
('A man is playing a large flute.', 'A man is playing a flute.', 3.8)
('A man with a hard hat is dancing.', 'A man wearing a hard hat is dancing.', 5.0)
加载Bert与分词器
config_path = 'bert/uncased_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'bert/uncased_L-12_H-768_A-12/bert_model.ckpt'
dict_path = 'bert/uncased_L-12_H-768_A-12/vocab.txt'tokenizer = Tokenizer(dict_path, do_lower_case=True) # 建立分词器
自定义全局池化
class GlobalAveragePooling1D(keras.layers.GlobalAveragePooling1D):"""自定义全局池化"""def call(self, inputs, mask=None):if mask is not None:mask = K.cast(mask, K.floatx())[:, :, None]return K.sum(inputs * mask, axis=1) / K.sum(mask, axis=1)else:return K.mean(inputs, axis=1)
建立模型
bert = build_transformer_model(config_path, checkpoint_path)encoder_layers, count = [], 0
while True:try:output = bert.get_layer('Transformer-%d-FeedForward-Norm' % count).outputencoder_layers.append(output)count += 1except:breakn_last, outputs = 2, []
for i in range(n_last):outputs.append(GlobalAveragePooling1D()(encoder_layers[-i]))output = keras.layers.Average()(outputs)
最后的编码器
encoder = Model(bert.inputs, output)
转换文本数据为id形式
def convert_to_vecs(data, maxlen=64):"""转换文本数据为id形式"""a_token_ids, b_token_ids, labels = [], [], []for d in data:token_ids = tokenizer.encode(d[0], maxlen=maxlen)[0]a_token_ids.append(token_ids)token_ids = tokenizer.encode(d[1], maxlen=maxlen)[0]b_token_ids.append(token_ids)labels.append(d[2])a_token_ids = sequence_padding(a_token_ids)b_token_ids = sequence_padding(b_token_ids)a_vecs = encoder.predict([a_token_ids,np.zeros_like(a_token_ids)],verbose=True)b_vecs = encoder.predict([b_token_ids,np.zeros_like(b_token_ids)],verbose=True)return a_vecs, b_vecs, np.array(labels)
计算kernel和bias
def compute_kernel_bias(vecs):"""计算kernel和bias最后的变换:y = (x + bias).dot(kernel)"""vecs = np.concatenate(vecs, axis=0)mu = vecs.mean(axis=0, keepdims=True)cov = np.cov(vecs.T)u, s, vh = np.linalg.svd(cov)W_inv = np.dot(u, np.diag(s**0.5))W = np.linalg.inv(W_inv.T)return W, -mu
变换及标准化
def transform_and_normalize(vecs, kernel=None, bias=None):"""应用变换,然后标准化"""if not (kernel is None or bias is None):vecs = (vecs + bias).dot(kernel)return vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5
语料向量化,计算变换矩阵和偏置项
a_train_vecs, b_train_vecs, train_labels = convert_to_vecs(train_data)
5551/5551 [==============================] - 41s 7ms/step
5551/5551 [==============================] - 39s 7ms/step
a_test_vecs, b_test_vecs, test_labels = convert_to_vecs(test_data)
1478/1478 [==============================] - 8s 5ms/step
1478/1478 [==============================] - 9s 6ms/step
#将训练集和测试集的句子都转化为768维的向量
print(len(a_train_vecs[0]))
print(len(a_test_vecs[0]))
768
768
kernel, bias = compute_kernel_bias([a_train_vecs, b_train_vecs, a_test_vecs, b_test_vecs])
kernel bias计算详解
all_vecs = np.concatenate([a_train_vecs, b_train_vecs, a_test_vecs, b_test_vecs], axis=0)
# all_vecs是(5551+5551+1478+1478)× 768维的矩阵
all_vecs.shape
(14058, 768)
关于768列的每个维度的均值向量 μ \mu μ
mu = all_vecs.mean(axis=0, keepdims=True)
mu.shape
(1, 768)
768列之间的相关系数矩阵 Σ \Sigma Σ
cov = np.cov(all_vecs.T)
cov.shape
(768, 768)
根据SVD分解 Σ = U Λ U ⊤ \Sigma=U\Lambda U^{\top} Σ=UΛU⊤及 Σ = ( W − 1 ) ⊤ W − 1 \Sigma=(W^{-1})^{\top}W^{-1} Σ=(W−1)⊤W−1计算得 W − 1 = U Λ W^{-1}=U\sqrt{\Lambda} W−1=UΛ
u, s, vh = np.linalg.svd(cov)
print(u.shape)
print(s.shape)
print(vh.shape)
(768, 768)
(768,)
(768, 768)
求解$ W^{-1}$
W_inv = np.dot(u, np.diag(s**0.5))
W_inv.shape
(768, 768)
求解 W W W
W = np.linalg.inv(W_inv.T)
W.shape
(768, 768)
变换及标准化详解
a_train_vecs = transform_and_normalize(a_train_vecs, kernel, bias)
b_train_vecs = transform_and_normalize(b_train_vecs, kernel, bias)
按照 x i ~ = ( x i − μ ) W \tilde{x_i}=(x_i-\mu)W xi~=(xi−μ)W进行变换
all_vecs = (all_vecs-mu).dot(W)
all_vecs.shape
(14058, 768)
按照 x i ′ = x i Σ x i 2 x_i^{\prime}=\frac{x_i}{\sqrt{\Sigma x_i^2}} xi′=Σxi2 xi进行数据标准化
all_vecs = all_vecs/(all_vecs**2).sum(axis=1,keepdims=True)**0.5
all_vecs.shape
(14058, 768)
计算训练集中句子相关分数
np.corrcoef(train_labels, train_sims)
array([[1. , 0.71206009],[0.71206009, 1. ]])
train_sims = (a_train_vecs * b_train_vecs).sum(axis=1)
print(u'训练集的相关系数:%s' % np.corrcoef(train_labels, train_sims)[0, 1])
训练集的相关系数:0.7120600911305668
关于np.corrcorf(X,Y):
返回相关系数矩阵
R ( x , y ) = C o v ( X , Y ) V a r ( X ) V a r ( Y ) = E ( X Y ) − E ( X ) E ( Y ) V a r ( X ) V a r ( Y ) R(x,y)=\frac{Cov(X,Y)}{\sqrt{Var(X)Var(Y)}}=\frac{E(XY)-E(X)E(Y)}{\sqrt{Var(X)Var(Y)}} R(x,y)=Var(X)Var(Y) Cov(X,Y)=Var(X)Var(Y) E(XY)−E(X)E(Y)
#[0,1]意为取第0行第1列的数
np.corrcoef([1,3,5],[2,4,5])[0,1]
0.9819805060619656
r = 13 − 3 ∗ 11 3 1.633 ∗ 1.247 = 0.98198 r = \frac{13-3*\frac{11}{3}}{1.633*1.247}=0.98198 r=1.633∗1.24713−3∗311=0.98198
计算测试集中句子相关分数
a_test_vecs = transform_and_normalize(a_test_vecs, kernel, bias)
b_test_vecs = transform_and_normalize(b_test_vecs, kernel, bias)
test_sims = (a_test_vecs * b_test_vecs).sum(axis=1)
print(u'测试集的相关系数:%s' % np.corrcoef(test_labels, test_sims)[0, 1])
测试集的相关系数:0.7745647933327217
【coding】Bert-Whitening细解相关推荐
- Silverlight实用窍门系列:35.细解Silverlight冒泡路由事件和注册冒泡路由事件【附带实例源码】...
Silverlight中的事件分为普通事件和冒泡路由事件,它并没有包括WPF中的隧道路由事件,在本章中将详细讲解冒泡路由事件和如何注册一个冒泡路由事件. 一.细解冒泡路由事件 冒泡路由事件可以比喻为: ...
- web前端细解cookie那些事
web前端细解cookie那些事,在互联网时代,IT行业飞速发展,带动了web前端开发行业的兴趣.由于行业新兴起时间不久,专业人才缺乏,薪资待遇较高,已成为众多IT学子选择就业的首选,今天就为分享一些 ...
- WCF从理论到实践(5):Binding细解(转)
WCF从理论到实践(5):Binding细解 本文的出发点: 通过阅读本文,您能了解以下知识: WCF中的Binding是什么? Binding的组成? Binding Element 的分类? Bi ...
- 前端flv.js设置缓冲时间和大小_好程序员web前端细解cookie那些事
好程序员web前端细解cookie那些事,在互联网时代,IT行业飞速发展,带动了web前端开发行业的兴趣.由于行业新兴起时间不久,专业人才缺乏,薪资待遇较高,已成为众多IT学子选择就业的首选,今天就为 ...
- videoleap自带素材_videoleap教程:制作电影帷幕开场效果细解
大家晚上好,我是Mr.吴 劳模吴又熬夜来给大家更新教程了 今天给大家带来的教程是 --如何制作电影开场的效果 这里我先放一个上周末出去约拍 记录的一个vlog成品杭州印打卡--WABF小分队https ...
- BERT embedding 降维--BERT whitening
利用BERT whitening可以将embedding 比如768维降到256维 def compute_kernel_bias(vecs, n_components=256):"&quo ...
- HijackThis日志细解【简明教程增强版】(一)
转的贴(偶是怕以后看不到了,所以保存下来的),原文章(By 风之咏者)地址:http://bbs.kingsoft.com/viewthread.php?tid=407983&sid=8miH ...
- HijackThis日志细解--清净网络(复杂详尽)
一.说在前面的提示(请原谅我啰嗦) 提示一:本文目的 本文的目的是帮助您进一步解读HijackThis扫描日志.如果您只是想知道HijackThis的使用方法,下面列出的2篇文章可以满足您的要求: 1 ...
- 揪出狐狸的尾巴,HijackThis日志细解【附反劫持一般建议】
HijackThis日志细解[附反劫持一般建议] 一.说在前面的提示(请原谅我啰嗦) 提示一:本文目的 本文的目的是帮助您进一步解读HijackThis扫描日志.如果您只是想知道HijackThis的 ...
最新文章
- eclipse搭建 tomcat、
- Redis实现消息队列的4种方案
- mysql中ifnull函数
- 电商系统的商品规格设计方案
- [Java]==和equals()的区别(按照数据类型区分)
- sql实现寻找中位数(使用sign、case、自定义变量等)
- Opengl编程指南第二章:状态管理、几何绘图
- 理解_RBAC基础概念_Spring Security OAuth2.0认证授权---springcloud工作笔记113
- Dubbo实战快速入门 (转)
- 根据已有的WSDL文件进行WebService服务开发和部署
- (最通俗易懂的)目标跟踪MOSSE、KCF
- 尚硅谷前端视频总结(一)
- WEB前端开发快速入门教程
- 主动轮廓模型:Snake模型的python实现
- 举个栗子~Tableau 技巧(205 ):区域地图中呈现具体位置
- scrapy 搜索关键字_基于scrapy框架输入关键字爬取有关贴吧帖子
- 使用log4j失误导致系统假死,记录一下
- linux bond四网卡绑定,Linux bond 网卡绑定配置教程
- TCP 握手没成功怎么办?
- Excel序号删除某行之后不连贯?这样做可以智能更新表格序号!