贝叶斯神经网络 BNN
1. 简介
贝叶斯神经网络不同于一般的神经网络,其权重参数是随机变量,而非确定的值。如下图所示:
也就是说,和传统的神经网络用交叉熵,mse等损失函数去拟合标签值相反,贝叶斯神经网络拟合后验分布。
这样做的好处,就是降低过拟合。
2. BNN模型
BNN 不同于 DNN,可以对预测分布进行学习,不仅可以给出预测值,而且可以给出预测的不确定性。这对于很多问题来说非常关键,比如:机器学习中著名的 Exploration & Exploitation (EE)的问题,在强化学习问题中,agent 是需要利用现有知识来做决策还是尝试一些未知的东西;实验设计问题中,用贝叶斯优化来调超参数,选择下一个点是根据当前模型的最优值还是利用探索一些不确定性较高的空间。比如:异常样本检测,对抗样本检测等任务,由于 BNN 具有不确定性量化能力,所以具有非常强的鲁棒性。
概率建模:
在这里,选择似然分布的共轭分布,这样后验可以分析计算。
比如,beta分布的先验和伯努利分布的似然,会得到服从beta分布的后验。
由于共轭分布,需要对先验分布进行约束。因此,我们尝试使用采用和变分推断来近似后验分布。
神经网络:
使用全连接网络来拟合数据,相当于使用多个全连接网络。
但是神经网络容易过拟合,泛化性差;并且对预测的结果无法给出置信度。
BNN:
把概率建模和神经网络结合起来,并能够给出预测结果的置信度。
先验用来描述关键参数,并作为神经网络的输入。神经网络的输出用来描述特定的概率分布的似然。通过采样或者变分推断来计算后验分布。
同时,和神经网络不同,权重 W 不再是一个确定的值,而是一个概率分布。
BNN建模如下:
假设 NN 的网络参数为 WWW,p(W)p(W)p(W) 是参数的先验分布,给定观测数据 D=X,YD={X,Y}D=X,Y,这里 XXX 是输入数据,YYY 是标签数据。BNN 希望给出以下的分布:
也就是我们预测值为:
P(Y⋆∣X⋆,D)=∫P(Y⋆∣X⋆,W)P(W∣D)dW(1)P\left(Y^{\star} | X^{\star}, D\right)=\int P\left(Y^{\star} | X^{\star}, W\right) P(W | D) d W (1) P(Y⋆∣X⋆,D)=∫P(Y⋆∣X⋆,W)P(W∣D)dW(1)
由于,WWW是随机变量,因此,我们的预测值也是个随机变量。
其中:
P(W∣D)=P(W)P(D∣W)P(D)(2)P(W | D)=\frac{P(W) P(D | W)}{P(D)} (2) P(W∣D)=P(D)P(W)P(D∣W)(2)
这里 P(W∣D)P(W|D)P(W∣D) 是后验分布,P(D∣W)P(D|W)P(D∣W) 是似然函数,P(D)P(D)P(D) 是边缘似然。
从公式(1)中可以看出,用 BNN 对数据进行概率建模并预测的核心在于做高效近似后验推断,而 变分推断 VI 或者采样是一个非常合适的方法。
如果采样的话:
我们通过采样后验分布P(W∣D)P(W \vert \mathcal{D})P(W∣D) 来评估 P(W∣D)P(W \vert \mathcal{D})P(W∣D) , 每个样本计算 f(X∣w)f(X \vert w)f(X∣w), 其中 f 是我们的神经网络。
正是我们的输出是一个分布,而不是一个值,我们可以估计我们预测的不确定度。
3. 基于变分推断的BNN训练
如果直接采样后验概率 p(W∣D)p(W|D)p(W∣D) 来评估 p(Y∣X,D)p(Y|X, D)p(Y∣X,D)的话,存在后验分布多维的问题,而变分推断的思想是使用简单分布去近似后验分布。
表示θ=(μ,σ)\theta = (\mu, \sigma)θ=(μ,σ), 每个权重 wiw_iwi 从正态分布(μi,σi)(\mu_i, \sigma_i)(μi,σi) 中采样。
希望 q(w∣θ)q(w \vert \theta)q(w∣θ) 和 P(w∣D)P(w \vert \mathcal{D})P(w∣D) 相近,并使用 KL 散度来度量这两个分布的距离。
也就是优化:
θ∗=argminθKL[q(w∣θ)∣∣P(w∣D)](3)\theta^* = \underset{\theta}{\mathrm{argmin}} \text{ KL}\left[q(w \vert \theta) \vert \vert P(w \vert \mathcal{D})\right] \; (3) θ∗=θargmin KL[q(w∣θ)∣∣P(w∣D)](3)
进一步推导:
θ∗=argminθKL[q(w∣θ)∣∣P(w∣D)]=argminθEq(w∣θ)[log[q(w∣θ)P(w∣D)]](definition of KL divegence)=argminθEq(w∣θ)[log[q(w∣θ)P(D)P(D∣w)P(w)]](Bayes Theorem)=argminθEq(w∣θ)[log[q(w∣θ)P(D∣w)P(w)]](Drop P(D)because it doesn’t depend on θ)(4)\begin{array}{l} \theta^* &= \underset{\theta}{\mathrm{argmin}} \text{ KL}\left[q(w \vert \theta) \vert \vert P(w \vert \mathcal{D})\right] & \\\\ &= \underset{\theta}{\mathrm{argmin}} \text{ }\mathbb{E}_{q(w \vert \theta)}\left[ \log\left[\frac{ q(w \vert \theta) }{P( w \vert \mathcal{D})}\right]\right] & \text{(definition of KL divegence)} \\\\ &= \underset{\theta}{\mathrm{argmin}} \text{ }\mathbb{E}_{q(w \vert \theta)}\left[ \log\left[\frac{ q(w \vert \theta)P(\mathcal{D}) }{P( \mathcal{D} \vert w)P(w)}\right]\right] & \text{(Bayes Theorem)} \\\\ &= \underset{\theta}{\mathrm{argmin}} \text{ }\mathbb{E}_{q(w \vert \theta)}\left[ \log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right]\right] & \text{(Drop }P(\mathcal{D})\text{ because it doesn't depend on } \theta) \end{array} \;(4) θ∗=θargmin KL[q(w∣θ)∣∣P(w∣D)]=θargmin Eq(w∣θ)[log[P(w∣D)q(w∣θ)]]=θargmin Eq(w∣θ)[log[P(D∣w)P(w)q(w∣θ)P(D)]]=θargmin Eq(w∣θ)[log[P(D∣w)P(w)q(w∣θ)]](definition of KL divegence)(Bayes Theorem)(Drop P(D) because it doesn’t depend on θ)(4)
公式中,
q(w∣θ)q(w|\theta)q(w∣θ) 表示给定正态分布的参数后,权重参数的分布;
P(D∣w)P(D|w)P(D∣w) 表示给定网络参数后,观测数据的似然;
P(w)P(w)P(w) 表示权重的先验,这部分可以作为模型的正则化。
并且使用
L=−Eq(w∣θ)[log[q(w∣θ)P(D∣w)P(w)]](5)\mathcal{L} = - \mathbb{E}_{q(w \vert \theta)}\left[ \log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right]\right] \;(5) L=−Eq(w∣θ)[log[P(D∣w)P(w)q(w∣θ)]](5)
来表示变分下界ELBO, 也就是公式(4)等价于最大化ELBO:
L=∑ilogq(wi∣θi)−∑ilogP(wi)−∑jlogP(yj∣w,xj)(6)\mathcal{L} = \sum_i \log q(w_i \vert \theta_i) - \sum_i \log P(w_i) - \sum_j \log P(y_j \vert w, x_j) \;(6) L=i∑logq(wi∣θi)−i∑logP(wi)−j∑logP(yj∣w,xj)(6)
其中,D={(x,y)}D =\{ (x, y)\}D={(x,y)}
我们需要对公式(4)中的期望进行求导,但是,这里,我们使用对权重进行重参数的技巧:
wi=μi+σi×ϵi(7)w_i = \mu_i + \sigma_i \times \epsilon_i \; (7) wi=μi+σi×ϵi(7)
其中, ϵi∼N(0,1)\epsilon_i \sim \mathcal{N}(0,1)ϵi∼N(0,1).
于是,用 ϵ\epsilonϵ代 替 www 后有:
∂∂θEq(ϵ)[log[q(w∣θ)P(D∣w)P(w)]]=Eq(ϵ)[∂∂θlog[q(w∣θ)P(D∣w)P(w)]](8)\frac{\partial}{\partial \theta}\mathbb{E}_{q(\epsilon)}\left[ \log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right]\right] =\mathbb{E}_{q(\epsilon)}\left[ \frac{\partial}{\partial \theta}\log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right]\right] \; (8) ∂θ∂Eq(ϵ)[log[P(D∣w)P(w)q(w∣θ)]]=Eq(ϵ)[∂θ∂log[P(D∣w)P(w)q(w∣θ)]](8)
也就是说,我们可以通过 多个不同的 ϵ∼N(0,1)\epsilon \sim \mathcal{N}(0,1)ϵ∼N(0,1) ,求取∂∂θlog[q(w∣θ)P(D∣w)P(w)]\frac{\partial}{\partial \theta}\log\left[\frac{ q(w \vert \theta) }{P( \mathcal{D} \vert w)P(w)}\right]∂θ∂log[P(D∣w)P(w)q(w∣θ)] 的平均值,来近似 KL 散度对 θ\thetaθ 的求导。
此外,除了对 www 进行重采样之外,为了保证 θ\thetaθ 参数取值范围包含这个实轴,对 σ\sigmaσ 进行重采样,可以令,
σ=log(1+eρ)(9)\sigma = \log (1 + e^{\rho}) \;\;\; (9) σ=log(1+eρ)(9)
然后,θ=(μ,ρ)\theta = (\mu, \rho)θ=(μ,ρ),这里的 θ\thetaθ 已经和原来定义的θ=(μ,σ)\theta = (\mu, \sigma)θ=(μ,σ) 不一样了。
4. BNN实践
算法:
- 从 N(μ,log(1+eρ))N(\mu, log(1+e^\rho))N(μ,log(1+eρ)) 中采样,获得 www;
- 分别计算 logq(w∣θ)\log q(w|\theta)logq(w∣θ)、 logp(w)\log p(w)logp(w)、 logp(y∣w,x)\log p(y|w,x)logp(y∣w,x).
其中,计算 logp(y∣w,x)\log p(y|w,x)logp(y∣w,x) 实际计算 logp(y∣ypred)\log p(y|y_{pred})logp(y∣ypred), ypred=w∗xy_{pred} = w*xypred=w∗x.
也就可以得到 L=∑ilogq(wi∣θi)−∑ilogP(wi)−∑jlogP(yj∣w,xj)\mathcal{L} = \sum_i \log q(w_i \vert \theta_i) - \sum_i \log P(w_i) - \sum_j \log P(y_j \vert w, x_j)L=∑ilogq(wi∣θi)−∑ilogP(wi)−∑jlogP(yj∣w,xj)。 - 重复更新参数θ’=θ−α∇θL\theta’ = \theta -\alpha \nabla_\theta \mathcal{L}θ’=θ−α∇θL.
Pytorch实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as pltclass Linear_BBB(nn.Module):"""Layer of our BNN."""def __init__(self, input_features, output_features, prior_var=1.):"""Initialization of our layer : our prior is a normal distributioncentered in 0 and of variance 20."""# initialize layerssuper().__init__()# set input and output dimensionsself.input_features = input_featuresself.output_features = output_features# initialize mu and rho parameters for the weights of the layerself.w_mu = nn.Parameter(torch.zeros(output_features, input_features))self.w_rho = nn.Parameter(torch.zeros(output_features, input_features))#initialize mu and rho parameters for the layer's biasself.b_mu = nn.Parameter(torch.zeros(output_features))self.b_rho = nn.Parameter(torch.zeros(output_features)) #initialize weight samples (these will be calculated whenever the layer makes a prediction)self.w = Noneself.b = None# initialize prior distribution for all of the weights and biasesself.prior = torch.distributions.Normal(0,prior_var)def forward(self, input):"""Optimization process"""# sample weightsw_epsilon = Normal(0,1).sample(self.w_mu.shape)self.w = self.w_mu + torch.log(1+torch.exp(self.w_rho)) * w_epsilon# sample biasb_epsilon = Normal(0,1).sample(self.b_mu.shape)self.b = self.b_mu + torch.log(1+torch.exp(self.b_rho)) * b_epsilon# record log prior by evaluating log pdf of prior at sampled weight and biasw_log_prior = self.prior.log_prob(self.w)b_log_prior = self.prior.log_prob(self.b)self.log_prior = torch.sum(w_log_prior) + torch.sum(b_log_prior)# record log variational posterior by evaluating log pdf of normal distribution defined by parameters with respect at the sampled valuesself.w_post = Normal(self.w_mu.data, torch.log(1+torch.exp(self.w_rho)))self.b_post = Normal(self.b_mu.data, torch.log(1+torch.exp(self.b_rho)))self.log_post = self.w_post.log_prob(self.w).sum() + self.b_post.log_prob(self.b).sum()return F.linear(input, self.w, self.b)class MLP_BBB(nn.Module):def __init__(self, hidden_units, noise_tol=.1, prior_var=1.):# initialize the network like you would with a standard multilayer perceptron, but using the BBB layersuper().__init__()self.hidden = Linear_BBB(1,hidden_units, prior_var=prior_var)self.out = Linear_BBB(hidden_units, 1, prior_var=prior_var)self.noise_tol = noise_tol # we will use the noise tolerance to calculate our likelihooddef forward(self, x):# again, this is equivalent to a standard multilayer perceptronx = torch.sigmoid(self.hidden(x))x = self.out(x)return xdef log_prior(self):# calculate the log prior over all the layersreturn self.hidden.log_prior + self.out.log_priordef log_post(self):# calculate the log posterior over all the layersreturn self.hidden.log_post + self.out.log_postdef sample_elbo(self, input, target, samples):# we calculate the negative elbo, which will be our loss function#initialize tensorsoutputs = torch.zeros(samples, target.shape[0])log_priors = torch.zeros(samples)log_posts = torch.zeros(samples)log_likes = torch.zeros(samples)# make predictions and calculate prior, posterior, and likelihood for a given number of samplesfor i in range(samples):outputs[i] = self(input).reshape(-1) # make predictionslog_priors[i] = self.log_prior() # get log priorlog_posts[i] = self.log_post() # get log variational posteriorlog_likes[i] = Normal(outputs[i], self.noise_tol).log_prob(target.reshape(-1)).sum() # calculate the log likelihood# calculate monte carlo estimate of prior posterior and likelihoodlog_prior = log_priors.mean()log_post = log_posts.mean()log_like = log_likes.mean()# calculate the negative elbo (which is our loss function)loss = log_post - log_prior - log_likereturn lossdef toy_function(x):return -x**4 + 3*x**2 + 1# toy dataset we can start with
x = torch.tensor([-2, -1.8, -1, 1, 1.8, 2]).reshape(-1,1)
y = toy_function(x)net = MLP_BBB(32, prior_var=10)
optimizer = optim.Adam(net.parameters(), lr=.1)
epochs = 2000
for epoch in range(epochs): # loop over the dataset multiple timesoptimizer.zero_grad()# forward + backward + optimizeloss = net.sample_elbo(x, y, 1)loss.backward()optimizer.step()if epoch % 10 == 0:print('epoch: {}/{}'.format(epoch+1,epochs))print('Loss:', loss.item())
print('Finished Training')# samples is the number of "predictions" we make for 1 x-value.
samples = 100
x_tmp = torch.linspace(-5,5,100).reshape(-1,1)
y_samp = np.zeros((samples,100))
for s in range(samples):y_tmp = net(x_tmp).detach().numpy()y_samp[s] = y_tmp.reshape(-1)
plt.plot(x_tmp.numpy(), np.mean(y_samp, axis = 0), label='Mean Posterior Predictive')
plt.fill_between(x_tmp.numpy().reshape(-1), np.percentile(y_samp, 2.5, axis = 0), np.percentile(y_samp, 97.5, axis = 0), alpha = 0.25, label='95% Confidence')
plt.legend()
plt.scatter(x, toy_function(x))
plt.title('Posterior Predictive')
plt.show()
这里是重复计算100次的平均值和100次平均值的97.5%大和2.5%小的区域线图(即置信度95%)。
最近开通了个公众号,主要分享深度学习相关内容,推荐系统,风控等算法相关的内容,感兴趣的伙伴可以关注下。
参考:
- 变分推断;
- Weight Uncertainty in Neural Networks Tutorial;
- Bayesian Neural Networks;
- 原论文
贝叶斯神经网络 BNN相关推荐
- 贝叶斯神经网络BNN
反向传播网络在优化完毕后,其权重是一个固定的值,而贝叶斯神经网络把权重看成是服从均值为 μ ,方差为 δ 的高斯分布,每个权重服从不同的高斯分布,反向传播网络优化的是权重,贝叶斯神经网络优化的是权重的 ...
- 贝叶斯神经网络计算核裂变碎片产额
作者丨庞龙刚 单位丨华中师范大学 研究方向丨高能核物理.人工智能 今天介绍一篇北京大学物理系使用贝叶斯神经网络计算核裂变碎片产额的文章.这篇文章发表在 PRL 上,业内同行都很感兴趣.这里对我们大同行 ...
- 贝叶斯神经网络的辩论
贝叶斯概率体系的研究有一段时间了,目前在推进贝叶斯神经网络,看到这篇文章的辩论,这里保存下. https://mp.weixin.qq.com/s?__biz=MzI5NTIxNTg0OA==& ...
- 结合随机微分方程,多大Duvenaud团队提出无限深度贝叶斯神经网络
©作者 | 小舟.陈萍 来源 | 机器之心 来自多伦多大学和斯坦福大学的研究者开发了一种在连续深度贝叶斯神经网络中进行近似推理的实用方法. 把神经网络的限制视为无限多个残差层的组合,这种观点提供了一种 ...
- 贝叶斯神经网络对梯度攻击的鲁棒性
©PaperWeekly 原创 · 作者|尹娟 学校|北京理工大学博士生 研究方向|随机过程.复杂网络单位 引言 贝叶斯神经网络(BNN)在最近几年得到了一定的重视,因为其具有一定的推断能力.BNN ...
- 贝叶斯神经网络最新综述
©PaperWeekly 原创 · 作者|尹娟 学校|北京理工大学博士生 研究方向|随机过程.复杂网络 论文标题:Bayesian Neural Networks: An Introduction a ...
- 贝叶斯神经网络----从贝叶斯准则到变分推断
前言 在认识贝叶斯神经网络之前,建议先复习联合概率,条件概率,边缘概率,极大似然估计,最大后验估计,贝叶斯估计这些基础 极大似然估计 一个神经网络模型可以视为一个条件分布模型 p ( y ∣ x , ...
- 07. 贝叶斯神经网络
算法思路 普通的神经网络的权值是确定的,而贝叶斯神经网络的权值是不确定的,他服从于一个概率分布,这便是贝叶斯神经网络和普通神经网络的差别. 可以简单认为,贝叶斯神经网络是无穷个神经网络的融合,不过给每 ...
- pytorch贝叶斯网络_贝叶斯神经网络:2个在TensorFlow和Pytorch中完全连接
pytorch贝叶斯网络 贝叶斯神经网络 (Bayesian Neural Net) This chapter continues the series on Bayesian deep learni ...
- 贝叶斯深度神经网络_深度学习为何胜过贝叶斯神经网络
贝叶斯深度神经网络 Recently I came across an interesting Paper named, "Deep Ensembles: A Loss Landscape ...
最新文章
- RIPv1与RIPv2互通
- 学习javascript数据结构(三)——集合
- GitLab修改用户密码
- 会计电算化的重要物质基础计算机和,湖北工业大学工程技术学院会计电算化管理办法...
- aws lambda_适用于无服务器Java开发人员的AWS Lambda:它为您提供了什么?
- rocketmq 消息指定_进大厂必备的RocketMQ你会吗?
- TypeScript函数
- 多规则策略如何筛选|视频版
- 完全备份、差异备份以及增量备份的区别
- 使用java语言操作,如何来实现MySQL中Blob字段的存取
- arcgis10之将多个shp文件合并成一个shp文件
- 大数据开发工程师是做什么的?岗位要求高吗?
- Python爬虫-豆瓣电影排行榜TOP250
- 信道容量受哪三个要素_连续信道容量将受到“三要素”的限制,其“三要素”是...
- java对外接口安全问题_怎么保证对外暴露接口的安全性(调用频率限制)
- 基于JAVA景区售票系统设计与实现 开题报告
- 信息抽取之实体消歧,统一
- C语言——数组指针篇
- System memory 249364480 must be at least 471859200
- 延锋安道拓:简化工作流程 实现研发数据外发安全可控