BERT-Whitening细解

@author: Heisenberg

@date: 2021-01-16

The code was shared from Jianlin Su on his blog

And This is a repo.

Data can be download from here

原jupyter notebook 格式可在github上查看。

测试任务:GLUE的STS-B句子相似性任务

测试环境:tf2.2.0+ keras2.3.1+ bert4keras 0.9.8

对向量进行线性变换(即数据挖掘中的白化操作)可达到BERT-Flow的效果

import numpy as np
from bert4keras.backend import keras, K
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.snippets import open, sequence_padding
from keras.models import Model
Using TensorFlow backend.

读取并加载数据集


def load_train_data(filename):"""加载训练数据(带标签)单条格式:(文本1, 文本2, 标签)"""D = []with open(filename, encoding='utf-8') as f:for i, l in enumerate(f):if i > 0:l = l.strip().split('\t')D.append((l[-2], l[-1], float(l[-3])))return Ddef load_test_data(filename):"""加载测试数据(带标签)单条格式:(文本1, 文本2, 标签)"""D = []with open(filename, encoding='utf-8') as f:for l in f:l = l.strip().split('\t')D.append((l[-2], l[-1], float(l[-3])))return D# 加载数据集
train_data = load_train_data('STS-B/original/sts-train.tsv')
test_data = load_test_data('STS-B/original/sts-dev.tsv')
#(文本1, 文本2, 相似性分数x(/5.0))
print(train_data[0])
print(test_data[0])
('A man is playing a large flute.', 'A man is playing a flute.', 3.8)
('A man with a hard hat is dancing.', 'A man wearing a hard hat is dancing.', 5.0)

加载Bert与分词器

config_path = 'bert/uncased_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'bert/uncased_L-12_H-768_A-12/bert_model.ckpt'
dict_path = 'bert/uncased_L-12_H-768_A-12/vocab.txt'tokenizer = Tokenizer(dict_path, do_lower_case=True)  # 建立分词器

自定义全局池化

class GlobalAveragePooling1D(keras.layers.GlobalAveragePooling1D):"""自定义全局池化"""def call(self, inputs, mask=None):if mask is not None:mask = K.cast(mask, K.floatx())[:, :, None]return K.sum(inputs * mask, axis=1) / K.sum(mask, axis=1)else:return K.mean(inputs, axis=1)

建立模型

bert = build_transformer_model(config_path, checkpoint_path)encoder_layers, count = [], 0
while True:try:output = bert.get_layer('Transformer-%d-FeedForward-Norm' % count).outputencoder_layers.append(output)count += 1except:breakn_last, outputs = 2, []
for i in range(n_last):outputs.append(GlobalAveragePooling1D()(encoder_layers[-i]))output = keras.layers.Average()(outputs)

最后的编码器

encoder = Model(bert.inputs, output)

转换文本数据为id形式

def convert_to_vecs(data, maxlen=64):"""转换文本数据为id形式"""a_token_ids, b_token_ids, labels = [], [], []for d in data:token_ids = tokenizer.encode(d[0], maxlen=maxlen)[0]a_token_ids.append(token_ids)token_ids = tokenizer.encode(d[1], maxlen=maxlen)[0]b_token_ids.append(token_ids)labels.append(d[2])a_token_ids = sequence_padding(a_token_ids)b_token_ids = sequence_padding(b_token_ids)a_vecs = encoder.predict([a_token_ids,np.zeros_like(a_token_ids)],verbose=True)b_vecs = encoder.predict([b_token_ids,np.zeros_like(b_token_ids)],verbose=True)return a_vecs, b_vecs, np.array(labels)

计算kernel和bias

def compute_kernel_bias(vecs):"""计算kernel和bias最后的变换:y = (x + bias).dot(kernel)"""vecs = np.concatenate(vecs, axis=0)mu = vecs.mean(axis=0, keepdims=True)cov = np.cov(vecs.T)u, s, vh = np.linalg.svd(cov)W_inv = np.dot(u, np.diag(s**0.5))W = np.linalg.inv(W_inv.T)return W, -mu

变换及标准化

def transform_and_normalize(vecs, kernel=None, bias=None):"""应用变换,然后标准化"""if not (kernel is None or bias is None):vecs = (vecs + bias).dot(kernel)return vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5

语料向量化,计算变换矩阵和偏置项

a_train_vecs, b_train_vecs, train_labels = convert_to_vecs(train_data)
5551/5551 [==============================] - 41s 7ms/step
5551/5551 [==============================] - 39s 7ms/step
a_test_vecs, b_test_vecs, test_labels = convert_to_vecs(test_data)
1478/1478 [==============================] - 8s 5ms/step
1478/1478 [==============================] - 9s 6ms/step
#将训练集和测试集的句子都转化为768维的向量
print(len(a_train_vecs[0]))
print(len(a_test_vecs[0]))
768
768
kernel, bias = compute_kernel_bias([a_train_vecs, b_train_vecs, a_test_vecs, b_test_vecs])

kernel bias计算详解

all_vecs = np.concatenate([a_train_vecs, b_train_vecs, a_test_vecs, b_test_vecs], axis=0)
# all_vecs是(5551+5551+1478+1478)× 768维的矩阵
all_vecs.shape
(14058, 768)

关于768列的每个维度的均值向量 μ \mu μ

mu = all_vecs.mean(axis=0, keepdims=True)
mu.shape
(1, 768)

768列之间的相关系数矩阵 Σ \Sigma Σ

cov = np.cov(all_vecs.T)
cov.shape
(768, 768)

根据SVD分解 Σ = U Λ U ⊤ \Sigma=U\Lambda U^{\top} Σ=UΛU⊤及 Σ = ( W − 1 ) ⊤ W − 1 \Sigma=(W^{-1})^{\top}W^{-1} Σ=(W−1)⊤W−1计算得 W − 1 = U Λ W^{-1}=U\sqrt{\Lambda} W−1=UΛ ​

u, s, vh = np.linalg.svd(cov)
print(u.shape)
print(s.shape)
print(vh.shape)
(768, 768)
(768,)
(768, 768)

求解$ W^{-1}$

W_inv = np.dot(u, np.diag(s**0.5))
W_inv.shape
(768, 768)

求解 W W W

W = np.linalg.inv(W_inv.T)
W.shape
(768, 768)

变换及标准化详解

a_train_vecs = transform_and_normalize(a_train_vecs, kernel, bias)
b_train_vecs = transform_and_normalize(b_train_vecs, kernel, bias)

按照 x i ~ = ( x i − μ ) W \tilde{x_i}=(x_i-\mu)W xi​~​=(xi​−μ)W进行变换

all_vecs = (all_vecs-mu).dot(W)
all_vecs.shape
(14058, 768)

按照 x i ′ = x i Σ x i 2 x_i^{\prime}=\frac{x_i}{\sqrt{\Sigma x_i^2}} xi′​=Σxi2​ ​xi​​进行数据标准化

all_vecs = all_vecs/(all_vecs**2).sum(axis=1,keepdims=True)**0.5
all_vecs.shape
(14058, 768)

计算训练集中句子相关分数

np.corrcoef(train_labels, train_sims)
array([[1.        , 0.71206009],[0.71206009, 1.        ]])
train_sims = (a_train_vecs * b_train_vecs).sum(axis=1)
print(u'训练集的相关系数:%s' % np.corrcoef(train_labels, train_sims)[0, 1])
训练集的相关系数:0.7120600911305668

关于np.corrcorf(X,Y):

返回相关系数矩阵

R ( x , y ) = C o v ( X , Y ) V a r ( X ) V a r ( Y ) = E ( X Y ) − E ( X ) E ( Y ) V a r ( X ) V a r ( Y ) R(x,y)=\frac{Cov(X,Y)}{\sqrt{Var(X)Var(Y)}}=\frac{E(XY)-E(X)E(Y)}{\sqrt{Var(X)Var(Y)}} R(x,y)=Var(X)Var(Y) ​Cov(X,Y)​=Var(X)Var(Y) ​E(XY)−E(X)E(Y)​

#[0,1]意为取第0行第1列的数
np.corrcoef([1,3,5],[2,4,5])[0,1]
0.9819805060619656

r = 13 − 3 ∗ 11 3 1.633 ∗ 1.247 = 0.98198 r = \frac{13-3*\frac{11}{3}}{1.633*1.247}=0.98198 r=1.633∗1.24713−3∗311​​=0.98198

计算测试集中句子相关分数

a_test_vecs = transform_and_normalize(a_test_vecs, kernel, bias)
b_test_vecs = transform_and_normalize(b_test_vecs, kernel, bias)
test_sims = (a_test_vecs * b_test_vecs).sum(axis=1)
print(u'测试集的相关系数:%s' % np.corrcoef(test_labels, test_sims)[0, 1])
测试集的相关系数:0.7745647933327217

【coding】Bert-Whitening细解相关推荐

  1. Silverlight实用窍门系列:35.细解Silverlight冒泡路由事件和注册冒泡路由事件【附带实例源码】...

    Silverlight中的事件分为普通事件和冒泡路由事件,它并没有包括WPF中的隧道路由事件,在本章中将详细讲解冒泡路由事件和如何注册一个冒泡路由事件. 一.细解冒泡路由事件 冒泡路由事件可以比喻为: ...

  2. web前端细解cookie那些事

    web前端细解cookie那些事,在互联网时代,IT行业飞速发展,带动了web前端开发行业的兴趣.由于行业新兴起时间不久,专业人才缺乏,薪资待遇较高,已成为众多IT学子选择就业的首选,今天就为分享一些 ...

  3. WCF从理论到实践(5):Binding细解(转)

    WCF从理论到实践(5):Binding细解 本文的出发点: 通过阅读本文,您能了解以下知识: WCF中的Binding是什么? Binding的组成? Binding Element 的分类? Bi ...

  4. 前端flv.js设置缓冲时间和大小_好程序员web前端细解cookie那些事

    好程序员web前端细解cookie那些事,在互联网时代,IT行业飞速发展,带动了web前端开发行业的兴趣.由于行业新兴起时间不久,专业人才缺乏,薪资待遇较高,已成为众多IT学子选择就业的首选,今天就为 ...

  5. videoleap自带素材_videoleap教程:制作电影帷幕开场效果细解

    大家晚上好,我是Mr.吴 劳模吴又熬夜来给大家更新教程了 今天给大家带来的教程是 --如何制作电影开场的效果 这里我先放一个上周末出去约拍 记录的一个vlog成品杭州印打卡--WABF小分队https ...

  6. BERT embedding 降维--BERT whitening

    利用BERT whitening可以将embedding 比如768维降到256维 def compute_kernel_bias(vecs, n_components=256):"&quo ...

  7. HijackThis日志细解【简明教程增强版】(一)

    转的贴(偶是怕以后看不到了,所以保存下来的),原文章(By 风之咏者)地址:http://bbs.kingsoft.com/viewthread.php?tid=407983&sid=8miH ...

  8. HijackThis日志细解--清净网络(复杂详尽)

    一.说在前面的提示(请原谅我啰嗦) 提示一:本文目的 本文的目的是帮助您进一步解读HijackThis扫描日志.如果您只是想知道HijackThis的使用方法,下面列出的2篇文章可以满足您的要求: 1 ...

  9. 揪出狐狸的尾巴,HijackThis日志细解【附反劫持一般建议】

    HijackThis日志细解[附反劫持一般建议] 一.说在前面的提示(请原谅我啰嗦) 提示一:本文目的 本文的目的是帮助您进一步解读HijackThis扫描日志.如果您只是想知道HijackThis的 ...

最新文章

  1. eclipse搭建 tomcat、
  2. Redis实现消息队列的4种方案
  3. mysql中ifnull函数
  4. 电商系统的商品规格设计方案
  5. [Java]==和equals()的区别(按照数据类型区分)
  6. sql实现寻找中位数(使用sign、case、自定义变量等)
  7. Opengl编程指南第二章:状态管理、几何绘图
  8. 理解_RBAC基础概念_Spring Security OAuth2.0认证授权---springcloud工作笔记113
  9. Dubbo实战快速入门 (转)
  10. 根据已有的WSDL文件进行WebService服务开发和部署
  11. (最通俗易懂的)目标跟踪MOSSE、KCF
  12. 尚硅谷前端视频总结(一)
  13. WEB前端开发快速入门教程
  14. 主动轮廓模型:Snake模型的python实现
  15. 举个栗子~Tableau 技巧(205 ):区域地图中呈现具体位置
  16. scrapy 搜索关键字_基于scrapy框架输入关键字爬取有关贴吧帖子
  17. 使用log4j失误导致系统假死,记录一下
  18. linux bond四网卡绑定,Linux bond 网卡绑定配置教程
  19. TCP 握手没成功怎么办?
  20. Excel序号删除某行之后不连贯?这样做可以智能更新表格序号!

热门文章

  1. 李想又要赴美上市了,高中辍学的他凭什么?
  2. AI又来割韭菜了 280亿估值的寒武纪科创板IPO堪忧?
  3. 堆叠泛化(Stacking Generalization)
  4. 有什么好用的网站导航?
  5. Cartographer学习总结
  6. react native : Implementing unavailable method
  7. Docker Swarm-Docker
  8. JDBC连接数据库 代码及解释说明
  9. 一个简单的留言微信小程序
  10. windows编程之计时器