背景介绍:(了解采样的可以跳过)

1)为什么需要采样:

简单的分布,比如高斯、exponential、gamma等等的样本都可以直接用numpy.random生成,但复杂的分布需要采样器生成。在贝叶斯、概率编程里面,有很多复杂的分布,而贝叶斯更新需要对这些复杂分布再采样。

2)最简单的3种采样:

  • CDF采样:(把pdf转成cdf,然后对0-1区间均匀采样)
  • 拒绝采样:

蓝线是真实分布,红线是一个处处y坐标值大于蓝线的函数(正比于一个简单分布,比如均匀分布,或者高斯分布,或者一个分段函数),对红线的分布采样得到样本x,然后计算蓝线(x)/红线(x)作为接受这个样本的概率。

  • 重要性采样

类似拒绝采样,不需要红线函数处处大于蓝线,但每个样本x的权重是蓝线(x)/红线(x)

3)以上三种采样的问题:

  • cdf采样需要一个数值积分,所以在高维空间采样的时候计算量太大
  • 拒绝采样需要先构造一个总是值大于自己的函数
  • 重要性采样在对分布情况不太清楚的时候采样方差可能很大,即大部分样本的权重都很低
  • 都或多或少需要对分布本身有一定了解,在高维空间不够高效

4)Metropolis-Hastings:

特点:用一个多元正态分布随机游走,不需要对分布本身有太多了解,高维空间时有效,分布本身不需要积分=1,基本只需要保证f(x)在-∞和+∞处=0,0<f(x)<upper bound就行了。

方法:初始点=>随机游走=>如果

则接受这个新的样本,否则按照
的概率接受这个采样。

问题:随机游走的效率仍然不够高。

代码:

def binormal_draw(xprev,beta=1):mean = [0, 0]cov = [[beta**2,0],[0,(1.5*beta)**2]]binormal = scipy.stats.multivariate_normal(mean,cov)return xprev + binormal.rvs()def metropolis(F, qdraw, nsamp, x_init, burnin, thinning=2, beta=1):samples=np.empty((nsamp,2))x_prev = x_initaccepted = 0j = 0for i in range((nsamp+burnin)*thinning):x_star = qdraw(x_prev, beta)logp_star = np.log(F(x_star[0], x_star[1]))logp_prev = np.log(F(x_prev[0], x_prev[1]))logpdfratio_p = logp_star-logp_prevu = np.random.uniform()if np.log(u) <= logpdfratio_p:x_prev = x_starif i >= burnin*thinning and i%thinning == 0:samples[j] = x_starj += 1accepted += 1else:#we always get a sampleif i >= burnin*thinning and i%thinning == 0:samples[j]= x_prevj += 1return samples, accepted

可视化:

http://elevanth.org/blog/2017/11/28/build-a-better-markov-chain/​elevanth.org


HMC:

1)直观理解原理:把MCMC的样本随机游走改成在一个势场中具有动能的质点,势场由分布函数f(x)决定,初始动能随机给定,这个质点在运动时间t后的位置记录下来,并且作为下次采样的初始点。

2)相比MH的优点:更新样本点的时候使用了梯度,所以对复杂分布采样比随机试的速度要快。

2)公式、可视化:参考下列文章

Markov Chains: Why Walk When You Can Flow?​elevanth.orgXinyu Chen:如何简单地理解「哈密尔顿蒙特卡洛 (HMC)」?​zhuanlan.zhihu.com

https://theclevermachine.wordpress.com/2012/11/18/mcmc-hamiltonian-monte-carlo-a-k-a-hybrid-monte-carlo/​theclevermachine.wordpress.com

MCMC: Hamiltonian Monte Carlo (a.k.a. Hybrid Monte Carlo)​theclevermachine.wordpress.com

https://arxiv.org/pdf/1701.02434.pdf​arxiv.orgHamiltonian Monte Carlo explained by Alex Rogozhnikov​arogozhnikov.github.io

3)代码:

def HMC(F, u0, n_iter, N_iter, h=0.01): # part 1: 初始化一个orbit的变量,第一行是初始点u0的坐标orbit = torch.zeros((n_iter+1, 2))orbit[0] = u0.detach()u = orbit[0].unsqueeze(0)shape = u.size()# part 2: 循环n_iter次,每次初始化一个v0和u0,然后让leapfrog走N_iter次,每次时长h默认=0.01for k in tqdm(range(n_iter)):v0 = torch.randn(size=u0.shape)u0 = torch.randn(size=u0.shape)*3u, v = leapfrog(F, u0, v0, h, N_iter,shape)u0 = orbit[k]a = float(ratio(F, u0, v0, u, v, shape))r = np.random.rand()if r < a:orbit[k+1] = uelse:orbit[k+1] = u0return orbit #[10:, :]# part 3: leapfrog算法,此外还有velocity verlet等等都可以尝试,但不要用euler
def leapfrog(F, u, v, h, N_iter,shape):v = v - h/2 * grad(F, u)for i in range(N_iter-1):u = u + h * vv = v - h * grad(F, u)u = u + h * vv = v - h/2 * grad(F, u)return u, v# part 4: 对函数F求梯度
def grad(F, u):u = u.detach()if u.requires_grad == False:u = u.requires_grad_()output = -torch.log(F(u))output.backward()ugrad = u.grad.squeeze(0)u = u.squeeze(0)return ugrad# part 5: 细节处理
def unsqueeze(tensor, shape):if len(list(tensor.size())) < len(list(shape)):tensor = tensor.unsqueeze(0)return tensor# part 6: 由于leapfrog计算结果仍然不是绝对的稳定,所以接受概率是min{运行前后总能量的比值,1}
#,大多数情况下都≈1
def ratio(F, u0, v0, u, v, shape):u0 = unsqueeze(u0, shape)u = unsqueeze(u, shape)v0 = unsqueeze(v0, shape)v = unsqueeze(v, shape)w0 = - torch.log(F(u0)) + 0.5*torch.mm(v0, torch.t(v0))w1 = - torch.log(F(u)) + 0.5*torch.mm(v, torch.t(v))return torch.exp(w0-w1)

4)调参细节:

一共有5组主要的超参要调,v0的标准差,u0的位置和标准差,时间步长h,总步长数N_iter,burning。

5)效果:

此处暂时无图。。不过采样的速度总体比Metropolis Hastings要稳定,样本质量较高,使用了梯度所以效率高。

6)进阶:NUTS采样

The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo​arxiv.org

如何自己去写一个鼠标驱动_为什么要用哈密顿采样器(Hamiltonian Monte Carlo),以及如何自己写一个...相关推荐

  1. rust实现一个mysql驱动_使用Rust编写用户态驱动程序

    概览 在云计算技术的发展史上,如何提高单个服务器的并发度,一直是热门的研究课题.在20年前,就有著名的"C10K"问题,即如何利用单个服务器每秒应对10K个客户端的同时访问.这么多 ...

  2. 判断一组多选框至少有一个被选中_想不想拥有自己的篆刻印章?那就PS一个吧...

    篆刻本身是书法和国画中必用的元素之一,但是在摄影后期中如果你想将作品做成仿国画效果那就离不开篆刻.不是每个人都有自己的篆刻,如果你不涉及到国画或者书法,我猜你是没有篆刻的. 那今天亮亮老师就带你学习一 ...

  3. delphi dbgrideh 遍历每一个单元格_利用财务函数制作贷款计算器,让你了解还款的每一个细节...

    大家好我是践行计算机教育刘老师,今天跟大家分享利用财务函数可以制作贷款计算机,以方便了解还款过程中的每一个细节. 贷款示例效果图 制作贷款计算器 制作贷款计算器-计算每月还款额 在C6单元格中输入公式 ...

  4. correlation 蒙特卡洛_蒙特卡洛模拟法及其matlab案例(Monte Carlo simulation method and its matlab case).doc...

    蒙特卡洛模拟法及其matlab案例(Monte Carlo simulation method and its matlab case) 蒙特卡洛模拟法及其matlab案例(Monte Carlo s ...

  5. C语言做每点击鼠标一下变量加一,用C语言写一个鼠标连点器!再也不要担心红包抢不过了~...

    C语言是面向过程的,而C++是面向对象的 C和C++的区别: C是一个结构化语言,它的重点在于算法和数据结构.C程序的设计首要考虑的是如何通过一个过程,对输入(或环境条件)进行运算处理得到输出(或实现 ...

  6. linux鼠标驱动程序,Linux usb子系统(一) _写一个usb鼠标驱动

    USB总线是一种典型的热插拔的总线标准,由于其优异的性能几乎成为了当下大小设备中的标配. USB的驱动可以分为3类:SoC的USB控制器的驱动,主机端USB设备的驱动,设备上的USB Gadget驱动 ...

  7. cesium鼠标左键获取经纬度_用C语言写一个鼠标连点器!再也不要担心红包抢不过了~...

    C语言是面向过程的,而C++是面向对象的 C和C++的区别: C是一个结构化语言,它的重点在于算法和数据结构.C程序的设计首要考虑的是如何通过一个过程,对输入(或环境条件)进行运算处理得到输出(或实现 ...

  8. 实战!手把手教你如何编写一个Linux驱动并写一个支持物联网的LED演示demo

    目录 一.开发环境 二. 准备工作: 1. 创建一个项目工程目录 2. 创建输出与目标目录 3.头文件目录 4. 建立源代码src目录 5. 使用git管理你的项目 三.编写LED驱动 三.一 准备工 ...

  9. mysql与php驱动程序_用PHP和MySQL构建一个数据库驱动的网站_php

    在我们目前的情况下,我们所需要的列是Jokes表中的JokeText列以及Authors表中的Name列和Email列.Jokes表和Authors表的关联条件是Jokes表中的AID列的值等于Aut ...

最新文章

  1. PCA、LDA、MDS、LLE、TSNE等降维算法的Python实现
  2. php人民币转换,PHP字符串转换RMB形式数字
  3. maven项目 ant_将大型项目从Ant迁移到Maven
  4. .sdp文件格式介绍
  5. GitHub 5.9K,目标检测、跟踪、关键点全覆盖的年度开源项目来了!
  6. 初中毕业学计算机在哪学,初中毕业要学计算机要去哪个里学呢
  7. 终于下决心写一写自己的博客了!
  8. Python代码覆盖性测试入门
  9. BeagleBone Black 板第三课:Debian7.5系统安装和远程控制BBB板
  10. 电子类经典书籍汇总(转 )
  11. SAKAI OAE汉化
  12. 2022-2027年中国石油装备制造市场竞争态势及行业投资前景预测报告
  13. C语言——求2-1000之间的素数,每行打印8个
  14. java jmf mp3,java播发mp3(不用jmf)
  15. java -cp 与 java -Djava.ext.dirs的区别与坑
  16. IOS 发布被拒 PLA 1.2问题 整个过程介绍 02 个人账户升级公司账户
  17. Missing Tag Identification in COTS RFID Systems: Bridging the Gap between Theory and Practice 理解+笔记
  18. 对参考文献格式的一些举例
  19. c语言指定外设访问宽度 强制,《C语言程序设计》第2章 简单的C程序设计.ppt
  20. python正则表达式删除指定符号及其中的内容

热门文章

  1. 浮点加法器计算机组成原理,计算机组成原理 第二章运算方法与运算器
  2. Win10系统如何退出桌面磁贴功能
  3. TIM怎么显示每条信息的时间
  4. js滚动,滑动,幻灯片,轮播,swipe js滚动,滑动,幻灯片,轮播
  5. Shiro介绍及主要流程
  6. php一句话怎么写_PHP一句话木马后门
  7. windows分屏_windows内到底藏了多少好东西?
  8. JVM001_类文件结构
  9. java时间聚类_mongodb 按照时间聚类 java
  10. python怎么引入os模块的函数_Python里的OS模块常用函数说明