python实现逻辑回归的流程_逻辑回归原理及其python实现
September 28, 2018
7 min to read
逻辑回归原理及其python实现
原理
逻辑回归模型:
$h_{\theta}(x)=\frac{1}{1+e^{-{\theta}^{T}x}}$
逻辑回归代价函数:
$J(\theta)=\frac{1}{m}\sum_{i=1}^{m}Cost(h_{\theta}(x^{(i)}),y^{(i)})$
其中:
该式子合并后:
$Cost(h_{\theta}(x),y)=-ylog(h_{\theta}(x))-(1-y)log(1-h_{\theta}(x))$
即逻辑回归的代价函数:
$J(\theta)=-\frac{1}{m}[\sum_{i=1}^{m}y^{(i)}logh_{\theta}(x^{(i)})+(1-y^{(i)})log(1-h_{\theta}(x^{(i)}))]$
最小化代价函数,使用梯度下降法(gradient descent)。
Want $min_{\theta}J(\theta):$
Repeat {
$\theta_{j}=\theta_{j}-\alpha\frac{\partial}{\partial\theta_{j}}J(\theta)$
} (simultaneously updata all $\theta_{j}$,$\alpha$为学习率)
即:
Repeat {
$\theta_{j}=\theta_{j}-\alpha\sum_{i=1}^{m}(h_{\theta}(x^{(i)})-y^{(i)})x_{j}^{(i)}$
}
正则化(Regularization)
如果我们有非常多的特征,我们通过学习得到的模型可能能够非常好地适应训练集,但是可能不能推广到新的数据集,我们把这种现象成为过拟合。
为防止过拟合,提升模型泛化能力,我们需要对所有特征参数(除$\theta_{0}$外)进行惩罚,即保留所有特征,减小参数$\theta$的值,当我们拥有很多不太有用的特征时,正则化会起到很好的作用。
$J(\theta)=-[\frac{1}{m}\sum_{i=1}^{m}(y^{(i)}log(h_{\theta}(x^{(i)}))+(1-y^{(i)})log(1-h_{\theta}(x^{(i)})))]+\frac{\lambda}{2m}\sum_{j=1}^{n}\theta_{j}^{2}$
梯度下降算法:
Repeat until convergence{
$\theta_{0}=\theta_{0}-\alpha\frac{1}{m}\sum_{i=1}^{m}((h_{\theta}(x^{(i)})-y^{(i)})\cdot x_{0}^{(i)})$
$\theta_{j}=\theta_{j}-\alpha\frac{1}{m}\sum_{i=1}^{m}((h_{\theta}(x^{(i)})-y^{(i)})\cdot x_{j}^{(i)}+\frac{\lambda}{m}\theta_{j})$ $for\hspace{1em}j=1,2,…n$
}
python实现
代码
# -*- coding:UTF-8 -*-import matplotlib.pyplot as plt
import numpy as np
"""
函数说明:梯度上升算法测试函数
求函数f(x) = -x^2 + 4x的极大值
Parameters:
无
Returns:
无
"""
def Gradient_Ascent_test():
def f_prime(x_old):#f(x)的导数return -2 * x_old + 4
x_old = -1#初始值,给一个小于x_new的值x_new = 0#梯度上升算法初始值,即从(0,0)开始alpha = 0.01#步长,也就是学习速率,控制更新的幅度presision = 0.00000001#精度,也就是更新阈值while abs(x_new - x_old) > presision:
x_old = x_new
x_new = x_old + alpha * f_prime(x_old)#上面提到的公式print(x_new)#打印最终求解的极值近似值
"""
函数说明:加载数据
Parameters:
无
Returns:
dataMat - 数据列表
labelMat - 标签列表
"""
def loadDataSet():
dataMat = []#创建数据列表labelMat = []#创建标签列表fr = open('testSet.txt')#打开文件for line in fr.readlines():#逐行读取lineArr = line.strip().split()#去回车,放入列表dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])#添加数据labelMat.append(int(lineArr[2]))#添加标签fr.close()#关闭文件return dataMat, labelMat#返回
"""
函数说明:sigmoid函数
Parameters:
inX - 数据
Returns:
sigmoid函数
"""
def sigmoid(inX):
return 1.0 / (1 + np.exp(-inX))
"""
函数说明:梯度上升算法
Parameters:
dataMatIn - 数据集
classLabels - 数据标签
Returns:
weights.getA() - 求得的权重数组(最优参数)
"""
def gradAscent(dataMatIn, classLabels):
dataMatrix = np.mat(dataMatIn)#转换成numpy的matlabelMat = np.mat(classLabels).transpose()#转换成numpy的mat,并进行转置m, n = np.shape(dataMatrix)#返回dataMatrix的大小。m为行数,n为列数。alpha = 0.001#移动步长,也就是学习速率,控制更新的幅度。maxCycles = 500#最大迭代次数weights = np.ones((n,1))
for k in range(maxCycles):
h = sigmoid(dataMatrix * weights)#梯度上升矢量化公式error = labelMat - h
weights = weights + alpha * dataMatrix.transpose() * error
return weights.getA()#将矩阵转换为数组,返回权重数组
"""
函数说明:绘制数据集
Parameters:
无
Returns:
无
"""
def plotDataSet():
dataMat, labelMat = loadDataSet()#加载数据集dataArr = np.array(dataMat)#转换成numpy的array数组n = np.shape(dataMat)[0]#数据个数xcord1 = []; ycord1 = []#正样本xcord2 = []; ycord2 = []#负样本for i in range(n):#根据数据集标签进行分类if int(labelMat[i]) == 1:
xcord1.append(dataArr[i,1]); ycord1.append(dataArr[i,2])#1为正样本else:
xcord2.append(dataArr[i,1]); ycord2.append(dataArr[i,2])#0为负样本fig = plt.figure()
ax = fig.add_subplot(111)#添加subplotax.scatter(xcord1, ycord1, s = 20, c = 'red', marker = 's',alpha=.5)#绘制正样本ax.scatter(xcord2, ycord2, s = 20, c = 'green',alpha=.5)#绘制负样本plt.title('DataSet')#绘制titleplt.xlabel('X1'); plt.ylabel('X2')#绘制labelplt.show()#显示
"""
函数说明:绘制数据集
Parameters:
weights - 权重参数数组
Returns:
无
"""
def plotBestFit(weights):
dataMat, labelMat = loadDataSet()#加载数据集dataArr = np.array(dataMat)#转换成numpy的array数组n = np.shape(dataMat)[0]#数据个数xcord1 = []; ycord1 = []#正样本xcord2 = []; ycord2 = []#负样本for i in range(n):#根据数据集标签进行分类if int(labelMat[i]) == 1:
xcord1.append(dataArr[i,1]); ycord1.append(dataArr[i,2])#1为正样本else:
xcord2.append(dataArr[i,1]); ycord2.append(dataArr[i,2])#0为负样本fig = plt.figure()
ax = fig.add_subplot(111)#添加subplotax.scatter(xcord1, ycord1, s = 20, c = 'red', marker = 's',alpha=.5)#绘制正样本ax.scatter(xcord2, ycord2, s = 20, c = 'green',alpha=.5)#绘制负样本x = np.arange(-3.0, 3.0, 0.1)
y = (-weights[0] - weights[1] * x) / weights[2]
ax.plot(x, y)
plt.title('BestFit')#绘制titleplt.xlabel('X1'); plt.ylabel('X2')#绘制labelplt.show()
if __name__ == '__main__':
dataMat, labelMat = loadDataSet()
weights = gradAscent(dataMat, labelMat)
plotBestFit(weights)
python实现逻辑回归的流程_逻辑回归原理及其python实现相关推荐
- python 评分卡_评分卡原理及Python实现
信用风险计量模型可以包括跟个人信用评级,企业信用评级和国家信用评级.人信用评级有一系列评级模型组成,常见是A卡(申请评分卡).B卡(行为模型).C卡(催收模型)和F卡(反欺诈模型). 今天我们展示的是 ...
- python网络编程要学吗_总算发现如何学习python网络编程
为了提高模块加载的速度,每个模块都会在__pycache__文件夹中放置该模块的预编译模块,命名为module.version.pyc,version是模块的预编译版本编码,一般都包含Python的版 ...
- python实现守护进程_守护进程原理及Python实现
守护进程原理及Python实现 守护进程,不依赖于终端,在后台运行的程序,通常称为daemon(ˈdiːmən或ˈdeɪmən). 一些常见的Linux软件通常都是已守护进程的方式运行,比如: ngi ...
- python刷题一亩三分地_手把手教你用python抓网页数据【一亩三分地论坛数据科学版】...
前言:. visit 1point3acres.com for more. 数据科学越来越火了,网页是数据很大的一个来源.最近很多人问怎么抓网页数据,据我所知,常见的编程语言(C++,java,pyt ...
- python的gui库哪个好_常用的13 个Python开发者必备的Python GUI库
[Python](http://www.blog2019.net/tag/Python?tagId=4)是一种高级编程语言,它用于通用编程,由Guido van Rossum 在1991年首次发布.P ...
- python要和什么一起学_跟哥一起学Python(1) - python简介
01-写在前面 我做了十几年的程序猿,码过代码.带过项目.做过产品经理.做过软件架构师.因为我是做通信设备软件的,面向底层操作系统,所以我的工作主要以C语言为主.Python在我的工作中通常用来写一些 ...
- python 内存溢出能捕获吗_从0基础学习Python (19)[面向对象开发过程中的异常(捕获异常~相关)]...
从0基础学习Python (Day19) 面向对象开发过程中的=>异常 什么是异常 当程序在运行过程中出现的一些错误,或者语法逻辑出现问题,解释器此时无法继续正常执行了,反而出现了一些错误的 ...
- python里面两个大于号_【课堂笔记】Python常用的数值类型有哪些?
学习了视频课程<财务Python基础>,小编特为大家归纳了Python常用的数值类型和运算符,大家一起来查缺补漏吧~~ 数值类型 整型(int):整型对应我们现实世界的整数,比如1,2,1 ...
- python能开发小程序吗_搭建小程序用Python语言可以搭建吗?
原标题:搭建小程序用Python语言可以搭建吗? 正如我们在学习语言编程的过程中能发现各种逻辑规律的奥妙无穷那样,当我们能掌握一种语言编程方式之后,逐渐地也能深刻地感受到如今在小程序编写上还能有着怎样 ...
最新文章
- 基于Kubernetes 的机器学习工作流
- sql server扫盲系列
- Android --- 两种设置字体加粗的方法
- saltstack 自动化运维管理
- 进程状态转换(了解)
- 递归-输出字符串所有的组合情况(代码、分析、汇编)
- .Net之美读书笔记15
- Google Gears 体验(1):本机数据库
- Android6.0之AMS启动
- Java新手编程入门
- java set retainall_Java的Set集合中的retainAll()方法
- ftp服务器默认文件路径,ftp服务器默认文件路径是
- 人脸检测——基于face_recognition库
- ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.
- C# 图片位深度转至8位灰度图像,8位灰度图像转为1位灰度图像
- STM32CubeMX的使用教程
- jsp提交判空/jsp重置
- C#实现贝塞尔曲线的算法
- 电脑录屏软件哪个好用?3款屏幕录制大师分享!
- 学上位机迎来最好的时代