目录

  • 1 引例
  • 2 牛顿迭代算法求根
  • 3 牛顿迭代优化
  • 4 代码实战:Logistic回归

1 引例

给定如图所示的某个函数,如何计算函数零点 x 0 x_0 x0​?

在数学上我们如何处理这个问题?

最简单的办法是解方程 f ( x ) = 0 f(x)=0 f(x)=0,在代数学上还有著名的零点判定定理

如果函数 y = f ( x ) y= f(x) y=f(x)在区间 [ a , b ] [a,b] [a,b]上的图象是连续不断的一条曲线,并且有 f ( a ) ⋅ f ( b ) < 0 f(a)·f(b)<0 f(a)⋅f(b)<0,那么函数 y = f ( x ) y= f(x) y=f(x)在区间 ( a , b ) (a,b) (a,b)内有零点,即至少存在一个 c ∈ ( a , b ) c∈(a,b) c∈(a,b),使得 f ( c ) = 0 f(c)=0 f(c)=0,这个 c c c也就是方程 f ( x ) = 0 f(x)= 0 f(x)=0的根。

然而,数学上的方法并不一定适合工程应用,当函数形式复杂,例如出现超越函数形式;非解析形式,例如递推关系时,精确的方程解析一般难以进行,因为代数上还没发展出任意形式的求根公式。而零点判定定理求解效率也较低,需要不停试错。

因此,引入今天的主题——牛顿迭代法,服务于工程数值计算。

2 牛顿迭代算法求根

记第 k k k轮迭代后,自变量更新为 x k x_k xk​,令目标函数 f ( x ) f(x) f(x)在 x = x k x=x_k x=xk​泰勒展开:

f ( x ) = f ( x k ) + f ′ ( x k ) ( x − x k ) + o ( x ) f\left( x \right) =f\left( x_k \right) +f'\left( x_k \right) \left( x-x_k \right) +o(x) f(x)=f(xk​)+f′(xk​)(x−xk​)+o(x)

我们希望下一次迭代到根点,忽略泰勒余项,令 f ( x k + 1 ) = 0 f(x_{k+1})=0 f(xk+1​)=0,则

x k + 1 = x k − f ( x k ) f ′ ( x k ) x_{k+1}=x_k-\frac{f(x_k)}{f'(x_k)} xk+1​=xk​−f′(xk​)f(xk​)​

不断重复运算即可逼近根点。

在几何上,上面过程实际上是在做 f ( x ) f(x) f(x)在 x = x k x=x_k x=xk​处的切线,并求切线的零点,在工程上称为局部线性化。如图所示,若 x k x_k xk​在 x 0 x_0 x0​的左侧,那么下一次迭代方向向右。

若 x k x_k xk​在 x 0 x_0 x0​的右侧,那么下一次迭代方向向左。

3 牛顿迭代优化

将优化问题转化为求目标函数一阶导数零点的问题,即可运用上面说的牛顿迭代法。

具体地,记第 k k k轮迭代后,自变量更新为 x k x_k xk​,令目标函数 f ( x ) f(x) f(x)在 x = x k x=x_k x=xk​泰勒展开:

f ( x ) = f ( x k ) + f ′ ( x k ) ( x − x k ) + 1 2 f ′ ′ ( x k ) ( x − x k ) 2 + o ( x ) f\left( x \right) =f\left( x_k \right) +f'\left( x_k \right) \left( x-x_k \right) +\frac{1}{2}f''\left( x_k \right) \left( x-x_k \right) ^2+o(x) f(x)=f(xk​)+f′(xk​)(x−xk​)+21​f′′(xk​)(x−xk​)2+o(x)

两边求导得

f ′ ( x ) = f ′ ( x k ) + f ′ ′ ( x k ) ( x − x k ) f'\left( x \right) =f'\left( x_k \right) +f''\left( x_k \right) \left( x-x_k \right) f′(x)=f′(xk​)+f′′(xk​)(x−xk​)

令 f ′ ( x k + 1 ) = f ′ ( x k ) + f ′ ′ ( x k ) ( x k + 1 − x k ) = 0 f'\left( x_{k+1} \right) =f'\left( x_k \right) +f''\left( x_k \right) \left( x_{k+1}-x_k \right) =0 f′(xk+1​)=f′(xk​)+f′′(xk​)(xk+1​−xk​)=0,从而得到

x k + 1 = x k − f ′ ( x k ) f ′ ′ ( x k ) x_{k+1}=x_k-\frac{f'\left( x_k \right)}{f''\left( x_k \right)} xk+1​=xk​−f′′(xk​)f′(xk​)​

对于向量 x = [ x 1 x 2 ⋯ x d ] T \boldsymbol{x}=\left[ \begin{matrix} x_1& x_2& \cdots& x_d\\\end{matrix} \right] ^T x=[x1​​x2​​⋯​xd​​]T,将上述迭代公式推广为

x k + 1 = x k − [ ∇ 2 f ( x k ) ] − 1 ∇ f ( x k ) {\boldsymbol{x}_{k+1}=\boldsymbol{x}_k-\left[ \nabla ^2f\left( \boldsymbol{x}_k \right) \right] ^{-1}\nabla f\left( \boldsymbol{x}_k \right) } xk+1​=xk​−[∇2f(xk​)]−1∇f(xk​)

其中 ∇ 2 f ( x k ) \nabla ^2f\left( \boldsymbol{x}_k \right) ∇2f(xk​)是Hessian矩阵,当其正定时可以保证牛顿优化算法往 减小的方向迭代

牛顿法的特点如下:

① 以二阶速率向最优点收敛,迭代次数远小于梯度下降法,优化速度快;

梯度下降法的解析参考图文详解神秘的梯度下降算法原理(附Python代码)

②学习率为 [ ∇ 2 f ( x k ) ] − 1 \left[ \nabla ^2f\left( \boldsymbol{x}_k \right) \right] ^{-1} [∇2f(xk​)]−1,包含更多函数本身的信息,迭代步长可实现自动调整,可视为自适应梯度下降算法;

③ 耗费CPU计算资源多,每次迭代需要计算一次Hessian矩阵,且无法保证Hessian矩阵可逆且正定,因而无法保证一定向最优点收敛。

在实际应用中,牛顿迭代法一般不能直接使用,会引入改进来规避其缺陷,称为拟牛顿算法簇,其中包含大量不同的算法变种,例如共轭梯度法、DFP算法等等,今后都会介绍到。

4 代码实战:Logistic回归

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from Logit import Logit'''
* @breif: 从CSV中加载指定数据
* @param[in]: file -> 文件名
* @param[in]: colName -> 要加载的列名
* @param[in]: mode -> 加载模式, set: 列名与该列数据组成的字典, df: df类型
* @retval: mode模式下的返回值
'''
def loadCsvData(file, colName, mode='df'):assert mode in ('set', 'df')df = pd.read_csv(file, encoding='utf-8-sig', usecols=colName)if mode == 'df':return dfif mode == 'set':res = {}for col in colName:res[col] = df[col].valuesreturn resif __name__ == '__main__':# ============================# 读取CSV数据# ============================csvPath = os.path.abspath(os.path.join(__file__, "../../data/dataset3.0alpha.csv"))dataX = loadCsvData(csvPath, ["含糖率", "密度"], 'df')dataY = loadCsvData(csvPath, ["好瓜"], 'df')label = np.array([1 if i == "是" else 0for i in list(map(lambda s: s.strip(), list(dataY['好瓜'])))])# ============================# 绘制样本点# ============================line_x = np.array([np.min(dataX['密度']), np.max(dataX['密度'])])mpl.rcParams['font.sans-serif'] = [u'SimHei']plt.title('对数几率回归模拟\nLogistic Regression Simulation')plt.xlabel('density')plt.ylabel('sugarRate')plt.scatter(dataX['密度'][label==0],dataX['含糖率'][label==0],marker='^',color='k',s=100,label='坏瓜')plt.scatter(dataX['密度'][label==1],dataX['含糖率'][label==1],marker='^',color='r',s=100,label='好瓜')# ============================# 实例化对数几率回归模型# ============================logit = Logit(dataX, label)# 采用牛顿迭代法logit.logitRegression(logit.newtomMethod)line_y = -logit.w[0, 0] / logit.w[1, 0] * line_x - logit.w[2, 0] / logit.w[1, 0]plt.plot(line_x, line_y, 'g-', label="牛顿迭代法")# 绘图plt.legend(loc='upper left')plt.show()

其中更新权重代码为

    '''* @breif: 牛顿迭代法更新权重* @param[in]: None* @retval: 优化参数的增量dw'''def newtomMethod(self):wTx = np.dot(self.w.T, self.X).reshape(-1, 1)p = Logit.sigmod(wTx)dw_1 = -self.X.dot(self.y - p)dw_2 = self.X.dot(np.diag((p * (1 - p)).reshape(self.N))).dot(self.X.T)dw = np.linalg.inv(dw_2).dot(dw_1)return dw


Pytorch深度学习实战1-6:图解牛顿迭代法,牛顿不止力学三定律相关推荐

  1. 实战例子_Pytorch官方力荐新书《Pytorch深度学习实战指南》pdf及代码分享

    PyTorch是目前非常流行的机器学习.深度学习算法运算框架.它可以充分利用GPU进行加速,可以快速的处理复杂的深度学习模型,并且具有很好的扩展性,可以轻松扩展到分布式系统.PyTorch与Pytho ...

  2. pytorch深度学习实战——预训练网络

    来源:<Pytorch深度学习实战>,2.1,一个识别图像主体的预训练网络 from torchvision import models from torchvision import t ...

  3. Pytorch 深度学习实战教程(二):UNet语义分割网络

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  4. Pytorch深度学习实战教程(二):UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章< ...

  5. PyTorch深度学习实战(5)——计算机视觉基础

    PyTorch深度学习实战(5)--计算机视觉基础 0. 前言 1. 图像表示 2. 将图像转换为结构化数组 2.1 灰度图像表示 2.2 彩色图像表示 3 利用神经网络进行图像分析的优势 小结 系列 ...

  6. PyTorch深度学习实战:从新手小白到数据科学家电子书

    作者:张敏 著 出版社:电子工业出版社 ISBN:9787121388293 出版时间:2020-08-01 PyTorch深度学习实战:从新手小白到数据科学家

  7. Pytorch深度学习实战教程:UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 本文的开发环境如下: 开发环境:Windows 开发语言:Python3. ...

  8. Pytorch 深度学习实战:视频自动打码

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 人脸识别 人脸识别是一门比较成熟的技术. 它的身影随处可见,刷脸支 ...

  9. PyTorch 深度学习实战 | 基于生成式对抗网络生成动漫人物

    生成式对抗网络(Generative Adversarial Network, GAN)是近些年计算机视觉领域非常常见的一类方法,其强大的从已有数据集中生成新数据的能力令人惊叹,甚至连人眼都无法进行分 ...

最新文章

  1. 出现module ‘xgboost‘ has no attribute ‘DMatrix‘的临时解决方法
  2. cbow word2vec 损失_Skip-gram和CBOW知识点
  3. 软件工程—让软件包自带commit id
  4. go语言渐入佳境[6]-operator运算符
  5. java invocationtarget,启动工程报java.lang.reflect.InvocationTargetException的解决详解
  6. linux内核驱动中对字符串的操作【转】
  7. codeblocks(其它软件)修改后缀文件的打开默认方式
  8. php 函数传值_传址_函数参数,php函数的传值与传址
  9. 【javaEE】——计算机基础知识(进程的理解和通信)01
  10. 计算机网络中的所谓资源是指硬件软件资源,计算机网络试题..doc
  11. 美国服务器用于外贸建站有哪些好处
  12. python学习的读书路线
  13. 解决Windows Update错误“80072EFD”
  14. (筆記) 如何在字串中從指定字元抓到指定字元(pointer版)? (C/C++) (C)
  15. 动态规划:钢条切割问题
  16. LoRaWAN协议-物理层(PHY)详解
  17. App 抓包利器:Charles 以及 App 爬虫心得
  18. 墨菲定律|马太效应|破窗理论|蝴蝶效应
  19. 2022年3月 python一级 程序题 【买本子和画三角形】
  20. 网络同步在游戏历史中的发展变化(三)—— 状态同步的发展历程与基本原理(上)...

热门文章

  1. VUE实战--网易云音乐
  2. 在“芯片庭院”培育一颗多核异构 RISC-V SOC种子
  3. Excel文件解析性能对比(POI,easyexcel,xlsx-streamer)
  4. 查看oracle11g的企业管理器(OEM)
  5. vivo X Fold和OPPO Find N
  6. php 获取array的长度_php中获取数组长度的方法
  7. 实习僧网站字体反爬破解思路及步骤分享
  8. 机器学习从入门到创业手记-处理数据的乐趣在于挖掘
  9. 激光雷达错位拼接技术
  10. 使用Java SE8 Streams 处理数据,Part 2