【MLDL】logistics regression理解
以前有学过linear classification、linear regression和logistics regression,这次做一下总结,并主要推导一下交叉熵损失函数的由来和梯度下降法。
一、概述
开头先祭出林轩田老师讲义中的一张图
PLA、Linear Regression到logistics regression的区别。
误差函数由0/1误差演变为均方误差到交叉熵误差。
1.1 PLA/Pocket
PLA是针对线性可分的数据,进行二分类,使用0/1误差,初始化权重,然后迭代更新,当有一个分类错误点时,就纠正权重,Wt+1=Wt+Yn(t)*Xn(t),直到没有错误为止。
后来为了处理非线性可分数据,引入pocket,不再是找那个没有分类错误的权重,而是在迭代过程中记录每次权重放错的次数,经过足够多的权重后,去犯错最后的那个权重作为结果。
1.2 Linear regression
linear regression主要可以用来解决预测银行卡额度问题、预测房价问题等。使用均方误差。
暂时先不做过多说明。
直接进入logistics regression。
二、logistics regression
2.1 基本介绍
当我们在做预测心脏病是否复发的问题时,我们不可能给出一个是或否的回答,只能说,有多少的概率会复发。但是,我们的训练数据只有复发或不复发两种,而我们希望拿到的训练数据是有概率的。
这样就引入了logistics 函数,通过一个映射,将其转换为0-1之间的数,用来表示概率。
logistics 函数: f(x)=11+e−xf(x)=11+e−xf(x)= \frac{1}{1+e^{-x}}
于是,得到假设函数:
那我们如何来优化这个假设函数呢,使用什么样的误差函数呢?这里就引入了交叉熵损失函数。
2.2 推导交叉熵损失函数
假设我们有这样一堆数据,
我们的目标函数产生这个数据集的概率为:
P(D)=p(x1O)p(x2X)p...p(xNX)P(D)=p(x1O)p(x2X)p...p(xNX)P(D)=p(x1O)p(x2X)p...p(x_NX)
公式中,大写O为正分类O,大写X为负分类X
由于我们知道,已知数据x1x1x_1,产生O 的可能性就是我们的目标函数f(x),索所以可得:
由条件概率的公式可得:
P(B|A)=P(AB)P(A)P(B|A)=P(AB)P(A)P(B|A) = \frac{ P(AB)}{P(A)}
所以,产生数据集D的概率公式可表示为:
P(D)=p(x1)f(x1)∗p(x2)(1−f(x2))∗...∗p(xN)(1−f(xN))P(D)=p(x1)f(x1)∗p(x2)(1−f(x2))∗...∗p(xN)(1−f(xN))P(D)=p(x_1)f(x_1)* p(x_2)(1-f(x_2))*...*p(x_N)(1-f(x_N))
但我们并不知道目标函数f,我们只能通过假设函数h,让其去逼近f,这样,我们可以推测,假设函数h产生数据集D的概率与目标函数f产生数据集D的概率逼近。
于是得到:
P(D)=p(x1)h(x1)∗p(x2)(1−h(x2))∗...∗p(xN)(1−h(xN))P(D)=p(x1)h(x1)∗p(x2)(1−h(x2))∗...∗p(xN)(1−h(xN))P(D)=p(x_1)h(x_1)* p(x_2)(1-h(x_2))*...*p(x_N)(1-h(x_N))
由于logistics函数的特性,1−h(x)=h(−x)1−h(x)=h(−x)1-h(x)=h(-x),则
P(D)=p(x1)h(x1)∗p(x2)h(−x2)∗...∗p(xN)h(−xN)P(D)=p(x1)h(x1)∗p(x2)h(−x2)∗...∗p(xN)h(−xN)P(D)=p(x_1)h(x_1)* p(x_2)h(-x_2)*...*p(x_N)h(-x_N)
由于p(x)这一项对我们的概率没有影响,可以去掉,变为:
P(D)=h(x1)∗h(−x2)∗...∗h(−xN)P(D)=h(x1)∗h(−x2)∗...∗h(−xN)P(D)=h(x_1)* h(-x_2)*...*h(-x_N)
将里面的正负号去掉,可以添加ynyny_n,代表二分类的0/1。
于是变为:
P(D)=∏Nn=0h(ynxn)P(D)=∏n=0Nh(ynxn)P(D)=\prod_{n=0}^N h(y_nx_n)
我们求最优的假设函数时,即是从假设空间H中选一个产生D概率最高的函数,在数学上,称之为似然,即求最大似然。
于是,我们的目标现在变为:
而由于我们现在在做logistics,是求一个权值W,可以将W替代上图中的h,将有关W的式子带入:
于是得到:
但是这个式子我们不好处理连乘,于是加一个去对数,于是变为:
于是,将ln放进去后,连乘就变成了连加。于是变为:
maxwmaxwmax_w ∑Nn=0lnθ(ynwTxn)∑n=0Nlnθ(ynwTxn)\sum_{n=0}^N ln \theta (y_nw^Tx_n)
但是我们不想求max,求min比较容易,可以添加一个负号,并添加1/N,比较容易好计算,于是变为:
又由于θ(s)=11+e−sθ(s)=11+e−s \theta(s)= \frac {1}{1+e^{-s}} ,
所以上式可以化简为:
最终,得到交叉熵损失函数:
这就是交叉熵损失函数的由来。
2.3 优化方法:GD/SGD
有了损失函数后,我们就要对这个损失函数进行优化,尽可能降低这个损失,这里就是用到了经典的优化方法:梯度下降法。
我们得到损失函数后,
我们发现这个损失函数是连续可微凸函数,则只需找使梯度为0的权值w就可。即
Ein很容易求导,只需要运用链式求导法则即可,得到梯度为:
这样只要求解梯度为0的时候就可以。
我们可以看出,在梯度公式中,θθ\theta函数如果全部为0,则梯度会等于0,所以知道,当 −ynw′xn−ynw′xn-y_n w^{'}x_n 趋向于负无穷大时,其 θ(−ynw′xn)θ(−ynw′xn)\theta(-y_nw^{'}x_n) 会等于0,这就意味着,ynw′xnynw′xny_n w^{'}x_n要趋向于正无穷,这意味着所有的ynyny_n 与 w′xnw′xnw^{'}x_n同号,即数据集D线性可分,才会发生这种情况。
但如果不是线性可分呢,这种方法求梯度等于0就不可行。
这个梯度不是一个线性的方程式,直接求解梯度等于0会比较困难。
那这怎么办?
我们想到以前用过得PLA,
每次碰到错误分类时,就修正W,即
可以改成下面这种形式:
如果是一项错误项,w需要修正,取符号那里为1,可以进行修正,如果是正确的,取符号那里为0,不必进行修正,这样就可以把两种情况统一,即迭代时随机取一个样本进行更新。
进一步地,我们可以把这个公式总结为:
在更新时,一个是更新的方向,另一个是更新时要走多大一步。如下图:
进而地,我们可以用这种方法来得到logistics regression中的梯度为0的权重。
每次迭代时,选取合适的步长,合适的方向,这样得到最后想要的w,这种方法就是梯度下降法。
在PLA中,方向是我们错误修正的方向,步长是1。
那在logistics regression中,如何选择步长和方向呢?
那对于这种平滑的损失函数,我们可以怎么做呢?我们知道Ein长得就像一个山谷的形状,只要找到谷底的地方,就可以。直观上理解,就是迈着合适的步子,沿着最快到达谷底的方向。
为了推导出我们的v和ηη\eta,我们使用控制变量法。
首先推导v(滚下去的方向),则假设v的长度||v||=1,将走的长度都放到ηη\eta中去。
那我怎么找到这个最好的方向呢,假设我很贪心,当然是沿着最陡的方向往下走。即在方圆ηη\eta的空间中,我们沿着最陡的方向走。
但找这个最陡方向仍然很困难,容易做的是求解线性的问题。所以,我们将这个式子改为对于变量v是线性的。
对于弯弯曲曲的曲线,我们只看一小段的话,那他跟线段是一样的。就可以表示出一个线性关系。这里用到的思想就是泰勒展开式。那么上式就变成(如果ηη\eta够小的话):
此时,我们的问题变为一个近似线性的问题。
在这个式子中,只有v是未知的。我们可以把多余的内容忽略掉,Ein(wt)Ein(wt)Ein(w_t)是一个常数,ηη\eta是我们给定的一个正数。所以我们只关注一下实体字的部分即可。找到一个v向量,能使得这个两个向量相乘的式子越小越好。
那两个向量相乘怎么样越小呢?就是两个向量方向相反,其内积就会最小。所以说向量v的方向就是梯度的反方向,但我们要求v是一个单位向量,所以v为:
这样就得到我们那个式子的最优解了。这就是我们想要的最好的v。
此时,我们的更新式子变为:
就是沿着梯度的反方向移一小步。
那我们现在再来推导ηη\eta,什么样的ηη\eta最好。
那什么样的ηη\eta不好呢?一种是每次走一小步,走了好久。另一种是走大步一点呢,那刚才我们用泰勒展开式,用线段来代替一小段的曲线就不准了,就好比,你以为你走了一大步,可以到达谷底,其实你跨到了另一个坡上,说不定到对面坡上没有下降,反而爬的更高。仍然没有下山。这种情况就没法确定。
那什么是好的步长呢?当然是适中的一步。就是说,当我的梯度比较大时,我可以跨大步一点,当梯度比较小时,可以跨小一点,
所以说ηη\eta是在变化的,随着梯度的大小变化的。即他们是正相关的关系。那如果η||Ein(wt)||η||Ein(wt)||\frac {\eta}{||Ein(w_t)||}是成比例的关系,我们可以用新的ηη\eta来表示更新的式子,即:
所以,我们现在得到优化方法:
综上,优化方法介绍完毕。
2.4 logistics regression步骤总结
就这下图这么个过程。计算中最主要就是计算梯度。
后续:
1、GD就是在更新参数的时候,每次都是用全部的样本来进行计算。
优点是得到全局最优解,易于并行实现,缺点是当样本数多的时候,训练速度会很慢。
2、于是提出SGD,随机梯度下降法,每次随机是用一个样本来更新,更新多次。
优点是训练速度快,缺点是准确率下降,容易陷入局部最优解,不易于并行实现。
3、mini-batch GD:每次更新选择一部分样本,前两者的结合。
当然,后来又衍生了更多了优化方法,如Momentum、Nesterov、Adagrad、Adadelta、RMSprop、Adam、Adamax、NAdam。
具体可看博客:https://zhuanlan.zhihu.com/p/22252270,讲解的非常清楚。
【MLDL】logistics regression理解相关推荐
- 逻辑斯蒂回归(Logistics Regression)
Author: 吕雪杰,xiaoran; Datawhale Logistics Regression简介 逻辑回归是在数据服从伯努利分布的假设下,通过极大似然的方法,运用梯度下降法来求解参数,从而达 ...
- # Logistics Regression
目录 一 logistics regression 1 一点介绍 2 评价 3 一点应用 4 代码流程 5 上代码!!! 一 logistics regression 1 一点介绍 逻辑回归是一个分类 ...
- 机器学习:逻辑回归(logistics regression)
title: 机器学习:逻辑回归(logistics regression) date: 2019-11-30 20:55:06 mathjax: true categories: 机器学习 tags ...
- 逻辑回归(logistics regression)与 softmax
文章目录 分类与回归的区别 逻辑回归 逻辑回归不是回归 逻辑回归算法原理及公式推导 求解目标 损失函数 什么是交叉熵损失函数 梯度下降 逻辑回归怎么处理非线性分类问题 总结 softmax简介 sof ...
- python logistics regression_Python——sklearn 中 Logistics Regression 的 coef_ 和 intercept_ 的具体意义...
sklearn 中 Logistics Regression 的 coef_ 和 intercept_ 的具体意义 使用sklearn库可以很方便的实现各种基本的机器学习算法,例如今天说的逻辑斯谛 ...
- 高斯分布Gaussian distribution、线性回归、逻辑回归logistics regression
高斯分布Gaussian distribution/正态分布Normal distribution 1.广泛的存在 2020年11月24日,探月工程嫦娥五号探测器发射成功.其运转轨道至关重要,根据开普 ...
- 复盘:手推LR(逻辑回归logistics regression),它和线性回归linear regression的区别是啥
复盘:手推LR(逻辑回归logistics regression),它和线性回归linear regression的区别是啥? 提示:系列被面试官问的问题,我自己当时不会,所以下来自己复盘一下,认真学 ...
- 学习笔记6-ML(classify)-Logistics Regression
1.Logistics Regression 考虑这样一个问题: 当一堆给定的数据集XXX分别只属于class1和class2,那么对于另一个给定测试数据集xxx,XXX不包含xxx,那么xxx中各个 ...
- 机器学习必备算法之(一)逻辑回归(logistics regression)及Python实现
笔者为数学系的一个小白,最近系统的在复习机器学习以及一些深度学习的内容,准备开个博记录一下这个有趣又痛苦的过程~hiahiahia,主要记录机器学习的几大经典算法的理论以及Python的实现.非计算机 ...
- 对数几率回归——Logistics Regression原理
Logistic Regression 简介 对数几率回归,也称为逻辑回归,虽然名为"回归",但实际上是分类学习方法. 优点 不仅可以预测类别,还可以得到近似概率,对许多需要利用概 ...
最新文章
- ASP.NET中如何实现负载均衡
- 【Linux】22.当前运行的docker修改环境后,想在本地保存为镜像的方法
- Python2.7安装ncmbot时提示:Microsoft Visual C++9.0 is required
- TextWatcher() 的用法
- 即将到来的 ECMAScript 2022 新特性
- Android pda出入库管理,出入库PDA管理系统软件
- 苹果的倔强!今秋新iPhone外观设计将与2018年款非常相似
- Fatal error: Maximum function nesting level of '100' reached, aborting!
- BERT大魔王为何在商业环境下碰壁?
- [Misc]IE浏览器真正全屏幕操作技巧
- 基于二叉链表的二叉树最长路径的求解
- 基于stm32单片机srf04超声波传感器测距Proteus仿真
- html八边形怎么显示,如何用几何画板自定义工具画正八边形
- android 英语词库
- C语言 冒泡法排序,选择法排序和插入排序
- java自动转换与强制转换
- 英语题目作业(10)
- 使用better-scroll插件进行左右联动布局效果
- 告别获取不安全的相对路径-取当前类的Classpath
- 联想笔记本电脑昭阳E40睡眠恢复时蓝屏并自动快速重启的解决办法
热门文章
- 四六级、考研英语单词记忆---知米背单词APP推荐!
- 全球及中国菖蒲根提取物行业发展规模及投资方向分析报告2022-2028年
- Fingerprint2 生成浏览器指纹应用
- macbook linux 双系统,mac上安装ubuntu双系统教程
- 英文名为什么最好不用joe?JOE英文名的寓意是什么?
- 在云服务器上(Windows)手动搭建FTP站点
- index.highlight.max_analyzed_offset 偏移量设置
- 量化交易实战【1】自己搭建一个的股票交易回测框架,并通过均线择时策略进行回测
- UiPath Excel内容去重操作
- Android平台压缩纹理ETC2 VS ASTC