Python机器学习:线型回归法008实现多元线性回归
使用封装的:LinearRegression
import numpy as np
from Simple_linear_Regression.metrics import r2_score
class LinearRegression:def __init__(self):"""初始化Linear Regression 模型"""self.coef_ = Noneself.interception_ = Noneself._theta = Nonedef fit_normal(self,X_train,y_train):"""根据训练数据集X_train,y_train训练Linear Regression"""assert X_train.shape[0] == y_train.shape[0],\"the size of X_train must be equal to the size of y_train"X_b = np.hstack([np.ones((X_train.shape[0],1)),X_train])#X_b = np.hstack([np.ones((len(X_train)),1),X_train])self._theta = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y_train)self.interception_ = self._theta[0]self.coef_ = self._theta[1:]return selfdef predict(self,X_predict):"""给定待测数据集X_predict,返回表示X_predict 的结果向量"""assert self.interception_ is not None and self.coef_ is not None,\"must fit before predict"assert X_predict.shape[1] == len(self.coef_),\"the feature number of X_predict must be equal to X_train"X_b = np.hstack([np.ones((X_predict.shape[0], 1)), X_predict])return X_b.dot(self._theta)def score(self,X_test,y_test):"""根据测试数据集X_test和y_test确定当前模型的准确度"""y_predict = self.predict(X_test)return r2_score(y_test,y_predict)def __repr__(self):return "LinearRegression()"
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
依然房价数据
boston = datasets.load_boston()
X = boston.data
y = boston.targetX = X[y < 50]
y = y[y < 50]
from Simple_linear_Regression.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,seed = 666)
LinearRegression()
from Simple_linear_Regression.LinearRegression import LinearRegression
reg = LinearRegression()
reg.fit_normal(X_train,y_train)
截距和参数
print(reg.interception_)
print(reg.coef_)
34.117399723204585
[-1.20354261e-01 3.64423279e-02 -3.61493155e-02 5.12978140e-02-1.15775825e+01 3.42740062e+00 -2.32311760e-02 -1.19487594e+002.60101728e-01 -1.40219119e-02 -8.35430488e-01 7.80472852e-03-3.80923751e-01]
与一元线性模型比较,预测准确率上升了
reg.score(X_test,y_test)
0.8129794056212895
Python机器学习:线型回归法008实现多元线性回归相关推荐
- 《机器学习笔记(三):多元线性回归与正态分布最大似然估计》
回归问题普遍讨论的是多元线性回归,考虑多个特征可以得到更精确的模型,这其中涉及中心极限定理,正态分布,概率密度函数和最大似然估计. (一)背景--多元线性回归 1.概念 本质上就是算法(公式)变换为了 ...
- 基于逐步法思想的多元线性回归(雏形)
基于逐步法思想的多元线性回归(雏形),后续修改看时间安排,有时间会进行完善. import pandas as pd import numpy as np from scipy.stats impor ...
- python机器学习——支持向量机回归与波士顿房价案例
支持向量机回归与波士顿房价案例 一.从传统回归模型到支持向量回归模型 二.核函数 三.常用的几种核函数 四.SVM 算法的优缺点 五.建模实例 (1)导入数据 (2)划分训练集测试集 (3)数据标准化 ...
- pythonsklearn多元回归回归_sklearn入门之多元线性回归
原标题:sklearn入门之多元线性回归 本文作者:杨长青 本文编辑:胡 婧 技术总编:张学人 scikit-learn又称sklearn是基于python的一个强大的机器学习库,它建立在numpy, ...
- Python金融系列第五篇:多元线性回归和残差分析
作者:chen_h 微信号 & QQ:862251340 微信公众号:coderpai 第一篇:计算股票回报率,均值和方差 第二篇:简单线性回归 第三篇:随机变量和分布 第四篇:置信区间和假设 ...
- 吴恩达《机器学习》第四章:多元线性回归
目录 四.多元线性回归 4.1 特征缩放 4.2 学习率α 4.4 特征和多项式 4.4 正规方程 四.多元线性回归 多特征下的假设形式: 4.1 特征缩放 特征缩放:Feature Scaling, ...
- Python机器学习:线型回归法007多元线性回归和正规方程的解
- Python机器学习:线型回归法05衡量线性回归法的指标MES,RMS,MAE
import numpy as np import matplotlib.pyplot as plt from sklearn import datasets 数据 boston = datasets ...
- Python机器学习:线型回归法06最好的衡量线型回归法的指标RSquared
代码实现首先使用sklearn.metrics中的r2.score 依然使用波士顿房价数据集 #使用sklearn from sklearn.metrics import r2_score r2_sc ...
最新文章
- Windbg+sos调试.net笔记
- 减少静态链接库的体积
- 写文件头的算法流程及C代码实现
- 移动互联网下一步:“深度学习”配合大数据
- 清除java_如何在Java地毯下有效地清除问题
- 面试时遇到「看门狗」脖子上挂着「时间轮」,我就问你怕不怕?
- android opencv 银行卡识别,NDK 开发之使用 OpenCV 实现银行卡号识别
- java array 元素的位置_JAVA集合类,有这一篇就够了
- Javascript 随机验证码
- 一件程序员必备武器的诞生
- java.io.FileWriter class doesn’t use UTF-8 by default
- MPQ8633性能指标测试与调测分享
- 干货培训 | 使用OBS进行直播导播和推流(下篇)
- win10 jungo windriver
- PS cc 2019自由变换默认等比例缩放操作问题的解决方法
- android在wifi和4G网络都可以使用的情况下,设置每次请求使用的网络类型
- 计算机常见故障原因有哪些,电脑常见故障原因及解决方法
- Timeboxing——业界大佬都在用的时间管理法
- Data too long for column解决方法
- Linux的安装(一步一步教你安装Linux)
热门文章
- Spring Boot热部署
- 如何使用be动词来确认请求_12
- flutter 动画展开菜单_蒲公英 · JELLY技术周刊 Vol.34: 芜湖~ Flutter
- linux自动挂载ntfs分区,Ubuntu 12.04 开机自动挂载ntfs分区
- 【英语学习】【English L06】U03 House L3 How is your house hunt going?
- 2 安装失败_写bug日记2:PYTORCH GEOMETRIC安装失败的问题(未解决)
- 大工19春《计算机组成原理》,大工19春《计算机组成原理》在线作业3.doc
- java 指代对象_06JAVA面向对象之封装
- python打印表格_怎么使用python脚本实现表格打印?
- zypper 删除mysql_如何在 Linux 上安装/卸载一个文件中列出的软件包?