量化分析师的Python日记【Q Quant兵器谱之偏微分方程2】
这是量化分析师的偏微分方程系列的第二篇,在这一篇中我们将解决上一篇显式格式留下的稳定性问题。本篇将引入隐式差分算法,读者可以学到:
- 隐式差分格式描述
- 三对角矩阵求解
- 如何使用
scipy
加速算法实现在完成两天的基础学习之后,在下一天中,我们将把已经学到的知识运用到金融定价领域最重要的方程之一:Black - Shcoles - Merton 偏微分方差
注意: 下文中需要的自建库 utilities,可以从链接 https://uqer.io/community/share/568dfcbe228e5b18e3ba2980 中克隆得到。如何保存library,请见:https://uqer.io/help/faq/#什么是Library
from CAL.PyCAL import * from matplotlib import pylab import seaborn as sns import numpy as np np.set_printoptions(precision = 4) font.set_size(20)def initialCondition(x):return 4.0*(1.0 - x) * x
1. 隐式差分格式
像上一天一样,我们从差分格式的数学表述开始。隐式格式与显式格式的区别,在于我们时间方向选择的基准点。显式格式使用k,而隐式格式选择k+1:
\begin{align}\frac{\partial u(x_j, \tau_{k+1})}{\partial\tau} &= \frac{u_{j,k+1} - u_{j,k}}{\Delta \tau} + O(\Delta \tau) \\\\\frac{\partial^2 u(x_j, \tau_{k+1})}{\partial x^2} &= \frac{u_{j-1,k+1} - 2u_{j,k+1} + u_{j+1,k+1}}{\Delta x^2} + O(\Delta x^2) \\\\\end{align}
剩下的推到过程我完全一样,我们看到无论隐式格式还是显式格式,它们的截断误差是一样的:
\begin{align}u_{\tau}(x_j,\tau_{k+1}) - \kappa u_{xx}(x_j,\tau_{k+1}) &= 0 \\\\\frac{u_{j,k+1} - u_{j,k}}{\Delta \tau} - \kappa \frac{u_{j-1,k+1} - 2u_{j,k+1} + u_{j+1,k+1}}{\Delta x^2} &= O(\Delta \tau) + O(\Delta x^2)\end{align}
用离散值Uj,k替换uj,k,我们得到差分方程:
\begin{align}&\frac{U_{j,k+1} - U_{j,k}}{\Delta \tau} - \kappa \frac{U_{j-1,k+1} - 2U_{j,k+1} + U_{j+1,k+1}}{\Delta x^2} &= 0, \\\\\Rightarrow& \quad U_{j,k+1} - U_{j,k} - \frac{\kappa\Delta \tau}{\Delta x^2}(U_{j-1,k+1} - 2U_{j,k+1} + U_{j+1,+1k}) &= 0, \\\\\Rightarrow& \quad U_{j,k+1} - U_{j,k} - \rho(U_{j-1,k+1} - 2U_{j,k+1} + U_{j+1,k+1}) &= 0.\end{align}
最后,到这里我们得到一个迭代方程组:
−ρUj−1,k+1+(1+2ρ)Uj,k+1−ρUj+1,k+1=Uj,k,1≤j≤N−1,0≤k≤M−1其中 ρ=κΔτΔx2。
N = 500 # x方向网格数 M = 500 # t方向网格数T = 1.0 X = 1.0xArray = np.linspace(0,X,N+1) yArray = map(initialCondition, xArray)starValues = yArray U = np.zeros((N+1,M+1)) U[:,0] = starValues
dx = X / N dt = T / M kappa = 1.0 rho = kappa * dt / dx / dx
1.1 矩阵求解(
TridiagonalSystem
)虽然看上去形式只是变了一点,但是求解的问题有很大的变化。在每个时间点上,我们需要求解如下的一个线性方程组:
AUk+1=Uk这里 A为:
[\mathbf{A} = \left(
1+2ρ−ρ0−ρ1+2ρ⋱⋯⋯−ρ⋱−ρ0⋯−ρ1+2ρ\right) .]
幸运的是,这个是个三对角矩阵,可以很简单的利用Gauss消去法求解。我们这里不会详细讨论算法的描述,细节都可以在下面的
python
类TridiagonalSystem
中了解到:class TridiagonalSystem:def __init__(self, udiag, cdiag, ldiag):'''三对角矩阵:udiag -- 上对角线cdiag -- 对角线ldiag -- 下对角线'''assert len(udiag) == len(cdiag)assert len(cdiag) == len(ldiag)self.udiag = udiagself.cdiag = cdiagself.ldiag = ldiagself.length = len(self.cdiag)def solve(self, rhs):'''求解以下方程组A \ dot x = rhs'''assert len(rhs) == len(self.cdiag)udiag = self.udiag.copy()cdiag = self.cdiag.copy()ldiag = self.ldiag.copy()b = rhs.copy()
# 消去下对角元for i in range(1, self.length):cdiag[i] -= udiag[i-1] * ldiag[i] / cdiag[i-1]b[i] -= b[i-1] * ldiag[i] / cdiag[i-1]# 从最后一个方程开始求解x = np.zeros(self.length)x[self.length-1] = b[self.length - 1] / cdiag[self.length - 1]for i in range(self.length - 2, -1, -1):x[i] = (b[i] - udiag[i]*x[i+1]) / cdiag[i]return xdef multiply(self, x):'''矩阵乘法:rhs = A \dot x'''assert len(x) == len(self.cdiag)rhs = np.zeros(self.length) rhs[0] = x[0] * self.cdiag[0] + x[1] * self.udiag[0]for i in range(1, self.length - 1):rhs[i] = x[i-1] * self.ldiag[i] + x[i] * self.cdiag[i] + x[i+1] * self.udiag[i]rhs[self.length - 1] = x[self.length - 2] * self.ldiag[self.length - 1] + x[self.length - 1] * self.cdiag[self.length - 1]return rhs
1.2 隐式格式求解
for k in range(0, M):udiag = - np.ones(N-1) * rholdiag = - np.ones(N-1) * rhocdiag = np.ones(N-1) * (1.0 + 2. * rho)mat = TridiagonalSystem(udiag, cdiag, ldiag)rhs = U[1:N,k]x = mat.solve(rhs)U[1:N, k+1] = xU[0][k+1] = 0.U[N][k+1] = 0.
from lib.utilities import plotLines plotLines([U[:,0], U[:, int(0.10/ dt)], U[:, int(0.20/ dt)], U[:, int(0.50/ dt)]], xArray, title = u'一维热传导方程', xlabel = '$x$', ylabel = r'$U(\dot, \tau)$', legend = [r'$\tau = 0.$', r'$\tau = 0.10$', r'$\tau = 0.20$', r'$\tau = 0.50$'])
from lib.utilities import plotSurface tArray = np.linspace(0, 0.2, int(0.2 / dt) + 1) tGrids, xGrids = np.meshgrid(tArray, xArray)plotSurface(xGrids, tGrids, U[:,:int(0.2 / dt) + 1], title = u"热传导方程 $u_\\tau = u_{xx}$,隐式格式($\\rho = 50$)", xlabel = "$x$", ylabel = r"$\tau$", zlabel = r"$U$")
2. 继续组装
像我们在显示格式那一节介绍的同样做法,我们把之前的代码整合起来,归集与一个完整的类
ImplicitEulerScheme
中:from lib.utilities import HeatEquation
上面的代码(使用
library
功能,关于该功能的具体介绍请见帮助 — Library是干什么的)导入我们在上一期中已经定义过的类HeatEquation
,避免代码重复。class ImplicitEulerScheme: 2def __init__(self, M, N, equation): 3self.eq = equation 4self.dt = self.eq.T / M 5self.dx = self.eq.X / N 6self.U = np.zeros((N+1, M+1)) 7self.xArray = np.linspace(0,self.eq.X,N+1) 8self.U[:,0] = map(self.eq.ic, self.xArray) 9self.rho = self.eq.kappa * self.dt / self.dx / self.dx 10self.M = M 11self.N = N 1213def roll_back(self): 14for k in range(0, self.M): 15udiag = - np.ones(self.N-1) * self.rho 16ldiag = - np.ones(self.N-1) * self.rho 17cdiag = np.ones(self.N-1) * (1.0 + 2. * self.rho) 1819mat = TridiagonalSystem(udiag, cdiag, ldiag) 20rhs = self.U[1:self.N,k] 21x = mat.solve(rhs) 22self.U[1:self.N, k+1] = x 23self.U[0][k+1] = self.eq.bcl(self.xArray[0]) 24self.U[self.N][k+1] = self.eq.bcr(self.xArray[-1]) 2526def mesh_grids(self): 27tArray = np.linspace(0, self.eq.T, M+1) 28tGrids, xGrids = np.meshgrid(tArray, self.xArray) 29return tGrids, xGrids
然后我们可以使用下面的三行简单调用完成功能:
ht = HeatEquation(1.,X, T) scheme = ImplicitEulerScheme(M,N, ht) scheme.roll_back() scheme.U
array([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 0.0000e+00],[ 7.9840e-03, 7.2843e-03, 6.9266e-03, ..., 3.8398e-07,3.7655e-07, 3.6926e-07],[ 1.5936e-02, 1.4567e-02, 1.3852e-02, ..., 7.6795e-07,7.5308e-07, 7.3851e-07],..., [ 1.5936e-02, 1.4567e-02, 1.3852e-02, ..., 7.6795e-07,7.5308e-07, 7.3851e-07],[ 7.9840e-03, 7.2843e-03, 6.9266e-03, ..., 3.8398e-07,3.7655e-07, 3.6926e-07],[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,0.0000e+00, 0.0000e+00]])
3. 使用
scipy
加速
软件工程行业里有句老话,叫做:“不要重复发明轮子!”。实际上,之前的代码里面,我们就造了自己的轮子:
TridiagonalSystem
。三对角矩阵作为最最常见的稀疏矩阵,关于它的线性方程组求解算法实际上早已为业界熟知,也已经有很多库内置了工业级别强度实现。这里我们取scipy
作为例子,来展示使用外源库实现的好处:
- 更加稳健的算法: 知名库算法由于使用者广泛,有更大的概率发现一些极端情形下的bug。库作者可以根据用户反馈,及时调整算法;
- 更高的性能: 由于库的使用更为广泛,库作者有更大的动力去使用各种技术去提高算法的性能:例如使用更高效的语言实现,例如C。scipy中的情形就是一例。
- 持续的维护: 库的受众范围广,社区的力量会推动库作者持续维护。
下面的代码展示,如何使用
scipy
中的solve_banded
算法求解三对角矩阵:import scipy as sp from scipy.linalg import solve_bandedA = np.zeros((3, 5)) A[0, :] = np.ones(5) * 1. # 上对角线 A[1, :] = np.ones(5) * 3. # 对角线 A[2, :] = np.ones(5) * (-1.) # 下对角线b = [1.,2.,3.,4.,5.] x = solve_banded ((1,1), A,b) print 'x = A^-1b = ',x
我们使用上面的算法替代我们之前的
TridiagonalSystem
,import scipy as sp from scipy.linalg import solve_banded for k in range(0, M):udiag = - np.ones(N-1) * rholdiag = - np.ones(N-1) * rhocdiag = np.ones(N-1) * (1.0 + 2. * rho)mat = np.zeros((3,N-1))mat[0,:] = udiagmat[1,:] = cdiagmat[2,:] = ldiagrhs = U[1:N,k]x = solve_banded ((1,1), mat,rhs)U[1:N, k+1] = xU[0][k+1] = 0.U[N][k+1] = 0.
plotLines([U[:,0], U[:, int(0.10/ dt)], U[:, int(0.20/ dt)], U[:, int(0.50/ dt)]], xArray, title = u'一维热传导方程,使用scipy', xlabel = '$x$', ylabel = r'$U(\dot, \tau)$', legend = [r'$\tau = 0.$', r'$\tau = 0.10$', r'$\tau = 0.20$', r'$\tau = 0.50$'])
同样的我们定义一个新类ImplicitEulerSchemeWithScipy
使用scipy
的算法:class ImplicitEulerSchemeWithScipy: def __init__(self, M, N, equation):self.eq = equationself.dt = self.eq.T / Mself.dx = self.eq.X / Nself.U = np.zeros((N+1, M+1))self.xArray = np.linspace(0,self.eq.X,N+1)self.U[:,0] = map(self.eq.ic, self.xArray)self.rho = self.eq.kappa * self.dt / self.dx / self.dxself.M = Mself.N = Ndef roll_back(self):for k in range(0, self.M):udiag = - np.ones(self.N-1) * self.rholdiag = - np.ones(self.N-1) * self.rhocdiag = np.ones(self.N-1) * (1.0 + 2. * self.rho)mat = np.zeros((3,self.N-1))mat[0,:] = udiagmat[1,:] = cdiagmat[2,:] = ldiagrhs = self.U[1:self.N,k]x = solve_banded((1,1), mat, rhs)self.U[1:self.N, k+1] = xself.U[0][k+1] = self.eq.bcl(self.xArray[0])self.U[self.N][k+1] = self.eq.bcr(self.xArray[-1])def mesh_grids(self):tArray = np.linspace(0, self.eq.T, M+1)tGrids, xGrids = np.meshgrid(tArray, self.xArray)return tGrids, xGrids
下面的代码,比较了两种做法的性能。可以看到仅仅简单的替代三对角矩阵算法,我们就获得了接近8倍的性能提升
import time startTime = time.time() loop_round = 10# 不使用scipy for k in range(loop_round):ht = HeatEquation(1.,X, T)scheme = ImplicitEulerScheme(M,N, ht)scheme.roll_back() endTime = time.time() print '{0:<40}{1:.4f}'.format('执行时间(s) -- 不使用scipy.linalg: ', endTime - startTime)# 使用scipy startTime = time.time() for k in range(loop_round):ht = HeatEquation(1.,X, T)scheme = ImplicitEulerSchemeWithScipy(M,N, ht)scheme.roll_back() endTime = time.time() print '{0:<40}{1:.4f}'.format('执行时间(s) -- 使用scipy.linalg: ', endTime - startTime)
4. 尾声
到这里为止,我们已经结束了偏微分方差差分格式的基础学习。这是一个很大的学科,这两天也只能做到“管中窥豹”,更多知识请移步“https://uqer.io/community/list”。但是有了以上的基础知识,读者已经有了足够的积累,可以处理一些金融工程中会实际遇到的方程。在下一天中,我们将把这两天学习到的知识运用到金融工程史上最重要的方程:Black - Scholes - Merton 偏微分方程。
量化分析师的Python日记【Q Quant兵器谱之偏微分方程2】相关推荐
- 量化分析师的Python日记 系列
量化分析师的Python日记 系列 转发,原作者 薛昆Kelvin 为方便学习,整理一下学习材料.持续更新. [第1天:谁来给我讲讲Python?] https://uqer.io/community ...
- 量化分析师的Python日记-CSDN公开课-专题视频课程
量化分析师的Python日记-7882人已学习 课程介绍 以完全初学者的角度入手来认识Python这个在量化领域日益重要的语言. 课程收益 课程先从介绍Python本身一些基本 ...
- 量化分析师的Python日记【Q Quant兵器谱之二叉树】
通过之前几天的学习,Q Quant们应该已经熟悉了Python的基本语法,也了解了Python中常用数值库的算法.到这里为止,小Q们也许早就对之前简单的例子不满意,希望能在Python里面玩票大的!O ...
- 量化分析师的Python日记【Q Quant 之初出江湖】
本篇中,作为Quant中的Q宗(P Quant 和 Q Quant 到底哪个是未来?),我们将尝试把之前的介绍的工具串联起来,小试牛刀. 您将可以体验到: 如何使用python内置的数学函数计算期权的 ...
- python量化分析系列(第一篇)_量化分析师的 Python 日记 [第 1 天:谁来给我讲讲 Python?]...
45 条回复 • 2016-05-25 11:10:23 +08:00 1 2015-04-08 21:42:42 +08:00 这里竟然有Quant 2 2015-04-08 22:49:51 +0 ...
- 量化分析师的Python日记【Q Quant兵器谱 -之偏微分方程1】
从今天开始我们将进入一个系列 -- 偏微分方程.作为这一系列的开篇,我们以热传导方差为引子,引出: 如何提一个偏微分方程的初边值问题: 利用差分格式将偏微分方程离散化: 显示差分格式: 显示差分格式的 ...
- 量化分析师的Python日记【Q Quant兵器谱之函数插值】
在本篇中,我们将介绍Q宽客常用工具之一:函数插值.接着将函数插值应用于一个实际的金融建模场景中:波动率曲面构造. 通过本篇的学习您将学习到: 如何在scipy中使用函数插值模块:interpolate ...
- 量化分析师的Python日记【Q Quant兵器谱之偏微分方程3的具体金融学运用】
欢迎来到 Black - Scholes - Merton 的世界!本篇中我们将把第11天学习到的知识应用到这个金融学的具体方程之上! import numpy as np import math i ...
- 量化分析师的python日记_量化分析师的Python日记【第1天:谁来给我讲讲Python?】...
"谁来给我讲讲Python?" 作为无基础的初学者,只想先大概了解一下Python,随便编个小程序,并能看懂一般的程序,那些什么JAVA啊.C啊.继承啊.异常啊通通不懂怎么办,于是 ...
最新文章
- MailKit帮助类
- Python3破冰人工智能,你需要掌握一些数学方法
- 【js细节剖析】通过=操作符为对象添加新属性时,结果会受到原型链上的同名属性影响...
- 区块链教程Fabric1.0源代码分析Chaincode(链码)体系总结
- php mysql 字段备注_MySQL下读取 表/字段 的说明备注信息
- 为什么机器人发展了几十年感觉还是没太大进展
- leetcode 88 Merge Sorted Array
- ASP.NET MVC中ViewData、ViewBag和TempData
- LeetCode之Two Sum II - Input array is sorted
- 商城商品购买数量增减的完美JS效果
- 有关SQL Server事务日志的十大文章
- Matlab计算空间权重矩阵(地理距离和经济地理距离)
- Python实现自己的分布式区块链视频教程-张敏-专题视频课程
- 打开php网页中木马,常见PHP网页木马
- Unity 之 月签到累计签到代码实现(ScriptableObject应用 | DoTween入场动画)
- 计算机网络安全 填空题,计算机网络安全技术选择填空复习题
- Google 机器学习术语表
- appcan中的微信分享与qq分享
- cortex m3/m4处理器的复位设计
- 微服务架构师封神之路09-Springboot多数据源,Hikari连接池和事务配置
热门文章
- Vondrak滤波及测试(python)
- 2021.10.27-28科研日志
- 绪论(p1-p2) author:run
- 用python画简单花瓣_花瓣网花瓣爬虫
- 几年后的 JavaScript 会是什么样子?
- VB.net MessageBox弹出的确认对话框点击确定按钮
- SymPy:如何用 Python 求解微积分
- Ultimaker2 3D打印机源文件在线公布
- 如何用python做考勤_【python爬虫教程 考勤】如何用Python实现一只小爬虫,爬取拉勾网...
- 一个领域中的红海和蓝海