机器学习–逻辑斯谛回归(Logistic Regression)

基本概念

逻辑斯谛回归(Logistic Regression)虽然带回归,却是经典的分类方法。逻辑斯谛回归模型属于对数线性模型。它在线性模型的基础上,使用 Sigmoid 函数,将线性模型的结果映射到 [0, 1] 之间,实现了具体值到概率的转换。

线性回归:
f(x)=wTx+bf(x)=w^Tx + bf(x)=wTx+b
Sigmoid:
S(x)=11+e−xS(x) = \frac{1}{1 + e^{-x}}S(x)=1+e−x1​
l逻辑斯谛回归:
f(x)=S(wTx)=11+e−wTx+bf(x)=S(w^Tx)=\frac{1}{1+e^{-w^Tx +b}}f(x)=S(wTx)=1+e−wTx+b1​

逻辑斯谛分布

设 XXX 是连续随机变量,当 XXX 服从逻辑斯谛分布时,XXX 的分布函数和密度函数可以表示为(《统计机器学习》):
F(x)=P(X≤x)=11+e−(x−μ)/γF(x)=P(X\le x)=\frac{1}{1 + e^{-(x-\mu)/ \gamma}}F(x)=P(X≤x)=1+e−(x−μ)/γ1​
f(x)=F′(x)=e−(x−μ)/γγ(1+e−(x−μ)/γ)2f(x)=F^{'}(x)=\frac{e^{-(x-\mu)/\gamma}}{\gamma(1+e^{-(x-\mu)/\gamma})^2}f(x)=F′(x)=γ(1+e−(x−μ)/γ)2e−(x−μ)/γ​
其中,μ\muμ 为位置参数,γ>0\gamma > 0γ>0 为形状参数。

二项逻辑斯谛回归

二项逻辑斯谛回归模型(binomial logistic regression model)是一种分类模型。并且分类结果只有0和1两种。
P(Y=1∣x)=ewTx+b1+ewTx+bP(Y=1|x)=\frac{e^{w^Tx + b}}{1+e^{w^Tx + b}}P(Y=1∣x)=1+ewTx+bewTx+b​
P(Y=0∣x)=11+ewTx+bP(Y=0|x)=\frac{1}{1+e^{w^Tx + b}}P(Y=0∣x)=1+ewTx+b1​
在求得两个概率之后,逻辑斯谛回归模型比较这两个条件概率值的大小,将实例 xxx 分到概率值较大的那一类。

逻辑斯谛回归模型的特点:一个事件的几率是指该事件发生的概率与该时间不发生的概率的比值。对于逻辑斯谛回归而言,P(Y=1∣x)P(Y=1|x)P(Y=1∣x) 发生的对数几率是:
logP(Y=1∣x)1−P(Y=1∣x)=wTx+blog\frac{P(Y=1|x)}{1-P(Y=1|x)}=w^Tx+blog1−P(Y=1∣x)P(Y=1∣x)​=wTx+b
这表明,在逻辑斯谛回归模型中,输出 Y = 1 的对数几率是输入 xxx 的线性函数。另一种说法是,输出 Y = 1 的对数几率是由输入的线性函数表示的模型,即逻辑斯谛回归模型。

代价函数

如果我们使用线性回归中的代价函数进行梯度下降法计算,有可能得到局部最优解,无法求得最优的参数值,因此我们使用极大似然估计 求得代价函数。

我们以二分类为例(二项逻辑斯谛回归)。
令:
P(Y=1∣x)=π(x),P(Y=0∣x)=1−π(x)P(Y=1|x)=\pi(x),\ P(Y=0|x)=1-\pi(x)P(Y=1∣x)=π(x), P(Y=0∣x)=1−π(x)
可以得到似然函数为:
∏i=1N[π(xi)]yi[1−π(xi)]1−yi\prod_{i=1}^N[\pi(x_i)]^{y_i}[1-\pi(x_i)]^{1-y_i}i=1∏N​[π(xi​)]yi​[1−π(xi​)]1−yi​
取对数
L(w)=∑i=1N[yilogπ(xi)+(1−yi)log(1−π(xi))]=∑i=1N[yilogπ(xi)1−π(xi)+log(1−π(xi))]=∑i=1N[yi(wTxi+b)−log(1+ewTxi+b)]L(w)=\sum_{i=1}^N[y_i log \pi(x_i) + (1-y_i)log(1-\pi(x_i))] \\ \ \ \ \ =\sum_{i=1}^N [y_ilog\frac{\pi(x_i)}{1-\pi(x_i)}+log(1-\pi(x_i))]\\ =\sum_{i=1}^N[y_i(w^T x_i+b)-log(1+e^{w^Tx_i+b})] \ \ \ \ \ \ \ L(w)=i=1∑N​[yi​logπ(xi​)+(1−yi​)log(1−π(xi​))]    =i=1∑N​[yi​log1−π(xi​)π(xi​)​+log(1−π(xi​))]=i=1∑N​[yi​(wTxi​+b)−log(1+ewTxi​+b)]       
对 L(w)L(w)L(w) 求极大值,得到 www 的估计值。

令代价函数为
J(w)=−L(w)J(w)=-L(w)J(w)=−L(w)
接下来用梯度下降法或牛顿法来求解参数。我们先求出偏导数 ∂J(w)∂w\frac{\partial J(w)}{\partial w}∂w∂J(w)​ 和 ∂J(w)∂b\frac{\partial J(w)}{\partial b}∂b∂J(w)​:
∂J(w)∂w=1Nxi(wTxi+b−yi)\frac{\partial J(w)}{\partial w}=\frac{1}{N}x_i(w^Tx_i+b-y_i)∂w∂J(w)​=N1​xi​(wTxi​+b−yi​)
∂J(w)∂b=1N∑i=1N(wTxi+b−yi)\frac{\partial J(w)}{\partial b}=\frac{1}{N}\sum_{i=1}^N(w^Tx_i+b-y_i)∂b∂J(w)​=N1​i=1∑N​(wTxi​+b−yi​)

在求得偏导数之后,我们对参数进行迭代更新

w=w−lr∗∂J(w)∂ww=w-lr*\frac{\partial J(w)}{\partial w}w=w−lr∗∂w∂J(w)​
b=b−lr∗∂J(w)∂bb=b-lr*\frac{\partial J(w)}{\partial b}b=b−lr∗∂b∂J(w)​

经过多次迭代后,我们的模型能学习到一个很好的参数。

实例

# coding=utf-8
from sklearn.datasets import make_classification
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as pltdef load_data():X, y = make_classification(n_samples=10000, n_features=10, n_redundant=3, n_informative=3, n_classes=2,n_clusters_per_class=1, random_state=42)return X, yclass LogisticRegression:def __init__(self, lr=0.01, epoch=1000):# 学习率self.lr = lr# 迭代次数self.epochNum = epochself.train_loss = []# 定义 sigmoid 函数def sigmoid(self, z):return 1 / (1 + np.exp(-z))# 用来计算梯度 dw 和 dbdef gradient(self, X, y):h = self.sigmoid(np.dot(X, self.w) + self.b)dw = np.dot(X.T, (h - y)) / X.shape[0]db = np.sum(h - y) / X.shape[0]return dw, db# 计算损失def loss(self, X, y):m = len(y)h = self.sigmoid(np.dot(X, self.w) + self.b)# 参考公式cost = (-1 / m) * (np.dot(y.T, np.log(h)) + np.dot((1 - y).T, np.log(1 - h)))cost = cost.squeeze()# 带有 L2 正则化项reg = (self.lr / 2) * np.sum(np.dot(self.w.T, self.w))return cost + reg# 拟合函数def fit(self, X, y):# 输入要为二维,将标签增维y = y.reshape(-1, 1)# 初始化参数 w 和 bself.w = np.zeros((X.shape[1], 1))self.b = np.zeros((X.shape[0], 1))for epoch in range(self.epochNum):# 计算损失和梯度loss = self.loss(X, y)dw, db = self.gradient(X,  y)# 更新参数 w 和 bself.w = self.w - self.lr * dwself.b = self.b - self.lr * dbself.train_loss.append(loss)def show_parameters(self):print('w:', self.w)def predict(self, x_test):prediction =self.sigmoid(np.dot(x_test, self.w))prediction = prediction > 0.5prediction = prediction.astype(int)return prediction# 绘图def plot_loss(self):plt.figure()plt.plot(range(len(self.train_loss)), self.train_loss, 'r')plt.title("Convergence Graph")plt.xlabel("Epochs")plt.ylabel("Train Loss")plt.show()if __name__ == '__main__':X, y = load_data()x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.3)model = LogisticRegression()model.fit(x_train, y_train)model.plot_loss()pre = model.predict(x_test)acc = accuracy_score(y_test, pre)print(acc)

机器学习--逻辑斯谛回归(Logistic Regression)相关推荐

  1. 机器学习笔记:logistic regression

    1 逻辑回归介绍 logistic regressioin是一种二分类算法,通过sigmoid激活函数将线性组合压缩到0和1之间,来代表属于某一个分类的属性 虽然其中带有"回归"两 ...

  2. 瞎聊机器学习——LR(Logistic Regression)逻辑斯蒂回归(一)

    逻辑斯蒂回归是我们在学习以及工作中经常用到的一种分类模型,下面通过本文来讲解一下逻辑斯蒂回归(logistic regression,下文简称LR)的概念.数学推导. 一.逻辑斯蒂回归的概念 首先希望 ...

  3. 【ML】 李宏毅机器学习二:Logistic Regression

    我们将在分类模型基础上继续,并开始学习一种常用的分类算法--Logistic回归,逻辑回归logistic regression,虽然名字是回归,但是实际上它是处理分类问题的算法.简单的说回归问题和分 ...

  4. 机器学习实践一 logistic regression regularize

    Logistic regression 数据内容: 两个参数 x1 x2 y值 0 或 1 Potting def read_file(file):data = pd.read_csv(file, n ...

  5. FlyAI小课堂:Python机器学习笔记:Logistic Regression

    Logistic回归公式推导和代码实现 1,引言 logistic回归是机器学习中最常用最经典的分类方法之一,有人称之为逻辑回归或者逻辑斯蒂回归.虽然他称为回归模型,但是却处理的是分类问题,这主要是因 ...

  6. 2018-3-20李宏毅机器学习笔记十----------Logistic Regression

    上节讲到:既然是一个直线型,只需要求解w和b.为何还要那么费劲的使用概率??? 视频:李宏毅机器学习(2017)_哔哩哔哩 (゜-゜)つロ 干杯~-bilibili https://www.bilib ...

  7. 吴恩达机器学习 -- 逻辑回归(Logistic Regression)

    7.1  分类问题 如果要预测的变量 是离散值,此时我们应用 logistics regression. 在分类问题中,我们对某一事物进行分类,有二分类和多分类问题.此节先讨论二分类问题,即只有两个分 ...

  8. Python遇见机器学习 ---- 逻辑回归 Logistic Regression

    综述 "子非鱼,焉知鱼之乐" 本文采用编译器:jupyter 逻辑回归方法是从线性回归方法发展过来的,通常解决的是分类问题,读者或许有这样一个疑问:既然是回归算法又么解决分类问题的 ...

  9. 在opencv3中实现机器学习之:利用逻辑斯谛回归(logistic regression)分类

    logistic regression,注意这个单词logistic ,并不是逻辑(logic)的意思,音译过来应该是逻辑斯谛回归,或者直接叫logistic回归,并不是什么逻辑回归.大部分人都叫成逻 ...

最新文章

  1. NVelocity:语法
  2. snort2安装及卸载教程
  3. 实际价格计算:确定方法
  4. java socket参数详解:OOBInline和UrgentData .
  5. PostgreSQL查看版本信息
  6. 修改USB固件库的Customer_HID例程
  7. thinkphp3.2.3漏洞_Chrome新版本修复CVE202015999 0 day漏洞
  8. java中HashSet实现(转)
  9. 微信小程序布局 底部位置固定例子
  10. delphi10 ftp文件名乱码问题
  11. mysql数据库配置步骤,MySQL数据库安装配置步骤详解
  12. 常见容错机制:failfast、failsafe、failover、failback
  13. 戴德金之连续性和无理数的中文翻译
  14. 别再说不知道元空间和永久代的区别了
  15. html实现网页多人聊天,实现websocket多人聊天,很简单(示例代码)
  16. 折线迷你图怎么设置_Excel如何制作小微型的迷你图?
  17. SCCM规划 - 网络
  18. SpringBoot 实现 QQ邮箱注册和登录
  19. C++11标准模板(STL)- 算法(std::nth_element)
  20. Python: 向量、矩阵和多维数组(基于NumPy库)

热门文章

  1. SpringBoot2.0 整合 FastDFS 中间件,实现文件分布式管理
  2. Linux IPC实践(5) --System V消息队列(2)
  3. GRUB与Linux系统修复(第二版)
  4. 【JUC】JDK1.8源码分析之ConcurrentHashMap
  5. Nacos 发布 1.0.0 GA 版本,可大规模投入到生产环境
  6. c语言 main函数
  7. LINUX的简单命令
  8. web前端安全编码(模版篇)
  9. 信息系统工程监理服务及营销策略
  10. mysql异地增量备份工具_利用 xtrabackup 工具实现增量备份 mysql(附脚本)