共轭梯度法的简单直观理解

  • 参考资料
  • What is: 什么是共轭梯度法?(简单直观理解)
    • 共轭向量
    • 共轭方向
    • 误差与残差
    • 搜索方向的确定
    • 步长,或者说系数alpha
  • How to: 怎么用共轭梯度法?(完整算法)
  • Python numpy代码
  • Why: 为什么共轭梯度法能求解Ax=b?
    • 二次型
    • 将Ax=b问题转化为最优化问题
  • 拓展:改进——预处理的共轭梯度法

参考资料

本文是参考以下内容,结合自己的理解做的笔记。请尽量直接访问下述网页。

  1. 矩阵与数值计算(11)——共轭梯度法
  2. 共轭梯度法(一):线性共轭梯度
  3. 无痛版共轭梯度法介绍(更新到第五章)
  4. 共轭梯度法通俗讲义
    第4个资料尤其清晰完备,很多都是参考它的。

What is: 什么是共轭梯度法?(简单直观理解)

共轭梯度法可以看作是梯度下降法(又称最速下降法)的一个改进。

对梯度下降来说
x⃗i+1=x⃗i−α∇f\vec x_{i+1}=\vec x_i - \alpha\nabla f xi+1​=xi​−α∇f
其中α\alphaα控制了一步要走多远,因此被称为步长,在机器学习里面又称为学习率

梯度下降法x移动的方向正是函数f的负梯度方向,这代表了局部上f减小最快的方向。

但是局部上减小最快的方向并不代表全局上指向最终解的方向。所以梯度下降法会出现像醉汉下山一样走出zig-zag的路线。如下图

图1 梯度下降法在2维解空间(也就是解向量只有两个维度)走出的路径示意图。
注:假如A是正定对称阵,其2维解空间必定是椭圆的。

为什么会走出这一Z形线呢?因为梯度下降的方向恰好与f垂直,也就是说和等高线垂直。沿着垂直于等高线的方向,一定能让函数减小,也就是最快地下了一个台阶。但是最快下台阶并不意味着最快到达目标位置(即最优解),因为你最终的目标并不是直对着台阶的。

为了修正这一路线,采用另一个方向:即共轭向量的方向。

我们先暂且给出共轭梯度法最后的形式,方便字母的定义:
x⃗i+1=x⃗i−αd⃗i\vec x_{i+1}=\vec x_i - \alpha \vec d_i xi+1​=xi​−αdi​
对照梯度下降法,每次向下走的方向不是梯度了,而是专门的一个方向d⃗\vec dd。除此之外和梯度下降法几乎一样。

在推进下一步之前,我们来看看什么是向量共轭。

共轭向量

下面先简要介绍共轭向量

所谓共轭向量,在数学上即:
piTApj=0p_i^TAp_j=0 piT​Apj​=0

其中A是一个对称正定矩阵。
pip_ipi​和pjp_jpj​是一对共轭的向量

可见,共轭是正交的推广化,因为向量正交的定义为:
piTpj=0p_i^Tp_j=0 piT​pj​=0
共轭比正交中间只多了个矩阵A,而矩阵的几何意义正是对一个向量进行线性变换(可见Gilber Strang的线代公开课)。因此共轭向量的意思就是一个向量经过线性变换(缩放剪切和旋转)之后与另一个向量正交。

共轭方向

言归正传,如何找到一个更好的方向呢?我们首先可以看看最完美的方向是什么样的。

下面这张图展示的就是最完美的方向。图中向量e代表的是误差。向量d就是方向向量。

误差e即当前迭代所得到的解与精确解的差值:
e⃗i=x⃗i−x⃗∗\vec e_i=\vec x_i- \vec x^* ei​=xi​−x∗

可惜我们并不能找到误差向量e,因为我们不知道精确解。

那么退而求其次,我们就找误差向量的共轭向量。因为图中可以看出,误差向量是与方向向量垂直的,即正交。刚才说了,共轭就是正交的推广。一个向量乘以一个矩阵之后与另一个方向正交,就是共轭。

即找到
d⃗TAe⃗=0\vec d ^T A \vec e =0 dTAe=0

但是这个公式里面仍然含有e,我们必须想办法去掉它,换成一个我们可以计算的量。

在推进下一步之前,我们先来看看误差与残差这两个概念的区别。

误差与残差

前面写道:

误差error 即当前迭代所得到的解与精确解的差值:
e⃗i=x⃗i−x⃗∗\vec e_i=\vec x_i- \vec x^* ei​=xi​−x∗

但是这种定义显然是没法直接用的,因为我们不知道精确解x∗x^*x∗

那么退而求其次,我们想到,当误差收敛为0的时候,必然满足方程Ax=b,那么由此就可以定义出残差residual

r⃗i=b⃗−Ax⃗i\vec r_i=\vec b-A\vec x_i ri​=b−Axi​

这个定义的精妙之处在于,它定义了Ax接近b的距离,当距离为0的时候,恰好就是精确解。但是又能避开精确解本身。

在实际的程序中,我们还常常定义相对残差,即上一步迭代和这一步迭代的残差的相对变化率,这里就不再赘述。

显然,误差和残差之间就差了一个矩阵A,他们两者的关系是这样的:

r⃗i=b⃗−A(e⃗i+x⃗∗)=b⃗−Ax⃗∗−Ae⃗i=−Ae⃗i\vec r_i=\vec b - A(\vec e_i+\vec x^*)=\vec b - A \vec x^* -A\vec e_i = -A\vec e_i ri​=b−A(ei​+x∗)=b−Ax∗−Aei​=−Aei​

可见除了A,还多了个负号。

搜索方向的确定

言归正传,利用残差,我们终于可以把误差e给替换掉了:
于是前面的式子就变成了
d⃗iTAe⃗i=−d⃗iTr⃗i=0\vec d_i ^T A \vec e_i =-\vec d_i ^T \vec r_i=0 diT​Aei​=−diT​ri​=0

那么,这告诉我们:方向向量d,正是与残差向量正交的方向!

接下来我们只需要构建一个与残差正交的向量就可以了。这部分内容是由施密特正交化(更严谨一点的说法是共轭格莱姆-施密特过程)完成的。由于只是一个计算的方法,对概念的理解没有帮助,所以我们跳过,直接给出结论。

每一步搜索方向的时候,这一方向与残差以及前一步的方向有关
d⃗i+1=r⃗i+1+βi+1d⃗i\vec d_{i+1} = \vec r_{i+1} +\beta_{i+1} \vec d_i di+1​=ri+1​+βi+1​di​
其中系数β\betaβ可以这样计算:
βi+1=r⃗i+1Tr⃗i+1r⃗iTr⃗i\beta_{i+1} = \frac{ \vec r_{i+1}^T \vec r_{i+1} } {\vec r_{i}^T \vec r_{i} } βi+1​=riT​ri​ri+1T​ri+1​​

这个系数beta其实很好记,因为分子就是残差的内积(下一步),分母也是残差的内积(这一步)。
或者说分子就是残差长度的平方(下一步),分母也是残差长度的平方(这一步)。(向量自己和自己的内积就是它的长度)

从另一个角度额外补充一点理解:
每次走的方向恰好是与残差正交的,这意味着:
每走一步恰好能消除残差的一个方向!
所以,当消除了残差所有投影方向上的值,那么就消除了整个残差!

步长,或者说系数alpha

实际上,还有一点没有解决,就是系数α\alphaα怎么算?

这点的解释我们以后有机会再说,直接给出结论。
αi=r⃗i+1Tr⃗i+1d⃗iTAd⃗i\alpha_i = \frac{ \vec r_{i+1}^T \vec r_{i+1} } {\vec d_{i}^T A\vec d_{i} } αi​=diT​Adi​ri+1T​ri+1​​
这个alpha的分子和beta一样,就是残差的内积。分母则是方向向量在乘以矩阵A之后的内积。

How to: 怎么用共轭梯度法?(完整算法)

  1. 设定初值
    d⃗0=r⃗0=b⃗−Ax⃗0\vec d_0=\vec r_0 = \vec b - A \vec x_0 \\ d0​=r0​=b−Ax0​

  2. 计算系数alpha
    αi=r⃗i+1Tr⃗i+1d⃗iTAd⃗i\alpha_i = \frac{ \vec r_{i+1}^T \vec r_{i+1} } {\vec d_{i}^T A\vec d_{i} } αi​=diT​Adi​ri+1T​ri+1​​

  3. 迭代一步(向下走一步)
    x⃗i+1=x⃗i−αid⃗i\vec x_{i+1}=\vec x_i - \alpha_i \vec d_i xi+1​=xi​−αi​di​

  4. 计算残差(此处已经被修改,原文没有被50整除那一个公式 2022-05-27)
    如果迭代次数可以被50整除
    r⃗i+1=r⃗i−αiAx⃗\vec r_{i+1}=\vec r_i - \alpha_i A\vec x ri+1​=ri​−αi​Ax
    否则
    r⃗i+1=r⃗i−αiAd\vec r_{i+1}=\vec r_i - \alpha_i A d ri+1​=ri​−αi​Ad

  5. 计算系数beta
    βi+1=r⃗i+1Tr⃗i+1r⃗iTr⃗i\beta_{i+1} = \frac{ \vec r_{i+1}^T \vec r_{i+1} } {\vec r_{i}^T \vec r_{i} } βi+1​=riT​ri​ri+1T​ri+1​​

  6. 计算搜索方向d⃗\vec dd
    d⃗i+1=r⃗i+1+βi+1d⃗i\vec d_{i+1} = \vec r_{i+1} +\beta_{i+1} \vec d_i di+1​=ri+1​+βi+1​di​

重复2~6,直到残差足够小

Python numpy代码

import numpy as np
import scipy.linalg as sl
import matplotlib.pyplot as pltnn=4 #矩阵的规模 FIXME: 当规模>5的时候会出现震荡,为什么?
accuracy = 1e-6
#使用共轭梯度法, A矩阵有两个条件:1. 正定(特征值全为正数) 2.对称
A = sl.pascal(nn, exact=False) # A是对称正定矩阵, 10阶帕斯卡矩阵, exact=False示用float元素而非默认的uint
b = np.arange(1., nn+1., 1.)
x0 = np.array([2.0]*nn) #x0def Conjugate_Gradient_Method(A,b,x0):#1. Set initial valuex = x0r = b - matrixVecProd(A, x)d = rrr = vecVecProd(r, r) #在计算beta的时候可以复用iter = 0relativeResidual = 0.1while( relativeResidual > accuracy or iter <1): #2. Compute alphaAd = matrixVecProd(A, d) # 在计算r_new和alpha的时候可以复用alpha =  rr / vecVecProd(d, Ad)#3. step forwardx = x + alpha * d#4. compute residualif iter % 50 == 0 :r_new = b - matrixVecProd(A, x)else :r_new = r - alpha * Ad#5. compute betarr_new = vecVecProd(r_new, r_new) beta = rr_new / rrrr = rr_new#6. compute search directiond = r_new + beta * diter += 1 relativeResidual = np.linalg.norm(r_new) / np.linalg.norm(r)r = r_newprint("iter",iter, "relativeResidual",relativeResidual)return xdef matrixVecProd(A, vec):res = np.dot(A, vec)return resdef vecVecProd(vec1, vec2):res = np.dot(vec1, vec2)return res# ----------------TEST-------------
def TEST_A(A,b):print(A)#A矩阵有两个条件:1. 正定(特征值全为正数) 2.对称eig = np.linalg.eig(A)print("eig=",eig[0]) #eig[0]取的是特征值print("b",b)
# ----------------ENDTEST-------------def main():res = Conjugate_Gradient_Method(A, b, x0)print("-------------numerical result-------------------")print(res)print("-------------accurate result-------------------")accRes=np.dot(np.linalg.inv(A), b)print(accRes)if __name__ == "__main__":# TEST_A(A,b) # 可以先打印出来看看main()

这个代码仍然是有问题的,主要是矩阵规模大的时候就会震荡,我也不清楚为什么。
这里照抄一下刘天添课上的算法

Why: 为什么共轭梯度法能求解Ax=b?

说了这么多,其实有一个关键问题没有讲,那就是:为什么共轭梯度法能求解Ax=b?

按理说,共轭梯度法是函数最优化的方法,怎么就扯上了求解Ax=b了呢?

实际上使用共轭梯度法的两个条件

  1. A是对称的
  2. A是正定的

也和这个原理有关。

数学家求解问题的思路是:把不会的问题转化成会的问题,再套用会的问题的思路求解问题。

为了说明这一点,我们要从线性代数的二次型入手。我们可以先复习一下二次型,了解一下它是什么。

二次型

二次型就是关于向量的二次函数。

我们高中学过的二次函数通用表达式为
f(x)=ax2+bx+cf(x) = a x^2 +bx+c f(x)=ax2+bx+c

如果把其中的x替换为向量x,并且把a x^2 替换为
x^T A x 就得到了

f(x)=xTAx+bx+cf(x) = x^T A x +bx+c f(x)=xTAx+bx+c

这就是二次型。

二次型求导得到
f′(x)=12(Ax+ATx)+bf'(x) = \frac{1}{2}( A x + A^T x)+b f′(x)=21​(Ax+ATx)+b

将Ax=b问题转化为最优化问题

我们本来求解的是
Ax=bA\mathbf x=\bf b Ax=b

这个问题被转化为了求某个函数的导数等于0的问题,即驻值问题。

我们设这个函数为g(x)。我们的问题即:

g′(x)=0x∗=argminxg(x)g'(x)=0\\ x^*=argmin_x g(x) g′(x)=0x∗=argminx​g(x)
argmin_x的意思就是我们求取最小值的时候的x,而不是最小值本身。

这个x∗x^*x∗就是最终解。

那么怎么联系到Ax=b呢?

我们只要改造这个函数g,让它的导数恰好就是Ax-b=0就好了!!
而这个函数,恰好就是二次型函数!


g′(x)=Ax−bg'(x)=Ax-b g′(x)=Ax−b

于是求最小值得问题就能够被转化为求Ax=b的问题!

这里有个小小的瑕疵:
实际上,二次型g(x)的导数是
g′(x)=1/2(AT+A)x−bg'(x)=1/2 (A^T+A)x-b g′(x)=1/2(AT+A)x−b

所以我们就要限定AT=AA^T=AAT=A,即限定A是对称的。这就是第一个条件的由来!

to be continued
2022-05-20

拓展:改进——预处理的共轭梯度法

共轭梯度法的简单直观理解相关推荐

  1. 简单直观理解形态学中的开运算和闭运算

    开运算是先腐蚀,再膨胀.闭运算是先膨胀,再腐蚀. 注意上图,开运算去除了图像中比较小的点,闭运算将图中的多个圆变成了一个整体. 具体matlab代码如下: close all;clear all;cl ...

  2. 最简单直观理解为什么补码=反码+1

    假设前提:在一个字节里 原码+补码一定是溢出(可以想象一个满圆) 又反码是原码取反的,所以反码+原码一定是最大值(绝对值255),那么再加1就是溢出,所以得出反码+1就是补码. 总结:原码和补码的关系 ...

  3. RNN循环神经网络的直观理解:基于TensorFlow的简单RNN例子

    RNN 直观理解 一个非常棒的RNN入门Anyone Can learn To Code LSTM-RNN in Python(Part 1: RNN) 基于此文章,本文给出我自己的一些愚见 基于此文 ...

  4. (通过简单直观的推导理解卡尔曼基础)Understanding the Basis of the Kalman Filter Via a Simple and Intuitive Derivation

    通过简单直观的推导理解卡尔曼基础 RELEVANCE PREREQUISITES PROBLEM STATEMENT SOLUTIONS REFERENCE 本文提供了卡尔曼滤波器的简单直观的推导,目 ...

  5. 梯度的直观理解_关于梯度、旋度和散度的直观理解

    关于梯度.旋度和散度的直观理解 散度为零,说明是无源场:散度不为零时,则说明是有源场(有正源或负源) 若你的场是一个流速场,则该场的散度是该流体在某一点单位时间流出单位体积的净流量. 如果在某点,某场 ...

  6. SVM支持向量机【直观理解】

    转载文章:https://baijiahao.baidu.com/s?id=1607469282626953830&wfr=spider&for=pc 如果你曾经使用机器学习解决分类问 ...

  7. spark任务shell运行_《Spark快速大数据分析》- 根据简单例子理解RDD

    1. RDD简介 RDD,弹性分布式数据集(Resiliennt Distributed Datasets),是Spark中最重要的核心概念,是Spark应用中存储数据的数据结构. RDD 其实就是分 ...

  8. 深度学习与计算机视觉(四)反向传播及其直观理解

    四.反向传播及其直观理解 4.1 引言 问题描述和动机: 大家都知道,其实我们就是在给定的图像像素向量x和对应的函数f(x)f(x)f(x),然后我们希望能够计算fff在x上的梯度∇f(x)" ...

  9. delphi 发送网络消息_《新手学习ISO网络模型》(1)如何直观理解物理层?

    新手向,以入门为主,建立对物理层的直观理解 网络就是一组互相连接的通信设备.如何实现网络可以让两台计算机传达消息. 协议:决定两个人或两台设备交流信息都要遵守的一个规则. 我们可以通过制定自己的协议来 ...

  10. 简单地理解 Python 的装饰器

    关于decorator说的比较透彻,作者是一位很善于讲课的人. 本文系转载,作者:0xFEE1C001 原文链接 www.lightxue.com/understand-python-decorato ...

最新文章

  1. 算法:三角形最小路径和
  2. Android LayoutInflater源码解析:你真的能正确使用吗?
  3. c/c++如何正确使用结构体?
  4. linux idea 启动报错StartupAbortedException: Fatal error initializing plugin idea.plugin.protoeditor
  5. 算法笔记_028:字符串转换成整数(Java)
  6. python以什么表示代码层次_python 中几个层次的中文编码.md
  7. Java标识符和关键字(static,final,abstract,interface)
  8. 利用html sessionStorge 来保存局部页面在刷新后回显,保留
  9. mysql高可用方案之主从架构(master-slave)
  10. 微软十二月补丁星期二修复58个漏洞
  11. oracle取月去0,Oracle取月份,不带前面的0
  12. 试题 基础练习 特殊回文数
  13. 【Android 安装包优化】WebP 图片格式 ( WebP 图片格式简介 | 使用 Android Studio 转换 WebP 图片格式 )
  14. redis应用之安装配置介绍
  15. CentOS7使用firewalld打开关闭防火墙与端口
  16. FCFS,SJF以及PSA进程调度算法效率的比较
  17. Beta冲刺-第四天
  18. 全球与中国废电池回收市场现状及未来发展趋势2022
  19. GetLastError 错误返回码
  20. 江南大学大作业答案 计算机网络,江南大学大作业答案 计算机网络

热门文章

  1. 使用Excel进行线性规划
  2. php 富文本编辑器,曾经用过的十大富文本编辑器
  3. 【AD】Altium Designer 原理图的绘制
  4. C#生成Excel出现8000401a的错误的另一种解决办法。
  5. dmx512如何帧同步_stm32实现DMX512协议发送与接收(非标)
  6. 课程设计(飞机订票系统) 超全
  7. 杂项-数学软件:MATLAB
  8. geetest极验空间推理验证码破解与研究
  9. 用友nccloud 虚拟机
  10. 抽象工厂模式类图及代码示例