1. 3 线性回归的sklearn实现

导入必要的模块

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression#搭建模型
from sklearn.metrics import mean_squared_error#计算评价指标mse

数据集

x = np.array([50, 30, 15, 40, 55, 20, 45, 10, 60, 25])
y = np.array([5.9, 4.6, 2.7, 4.8, 6.5, 3.6, 5.1, 2.0, 6.3, 3.8])

画出数据集的散点图

plt.scatter(x, y)
plt.grid(True)
plt.xlabel('area')
plt.ylabel('price')
plt.show()

数据划分

划分训练集和测试集

使用到的api:

数据划分sklearn.model_selection.train_test_split

用到的参数:

  • *arrays:输入数据集。

  • test_size:划分出来的测试集占总数据量的比例,取值0~1。

  • shuffle:是否在划分前打乱数据的顺序,默认True。

  • random_state:shuffle的随机种子,取值正整数。

返回:

  • splitting:列表包含划分后的训练集与测试集。
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, shuffle=True, random_state=23)

查看训练集的散点图

plt.scatter(x_train,y_train)
plt.grid('True')
plt.xlabel('area')
plt.ylabel('price')
plt.show()

查看测试集的散点图

plt.scatter(x_test,y_test)
plt.grid('True')
plt.xlabel('area')
plt.ylabel('price')
plt.show()

模型搭建

使用到的api:

线性回归sklearn.linear_model.LinearRegression

model = LinearRegression()

模型训练

使用到的api:

线性回归模型训练sklearn.linear_model.LinearRegression.fit

用到的参数:

  • X:输入特征,如果输入是np.array格式,shape必须是(n_sample, n_feature)。

  • y:输入标签。

# x_train的shape由(7,)变为(7,1)
x_train = x_train.reshape(-1,1)
print(x_train)
model.fit(X=x_train, y=y_train)

模型预测

对测试集做预测

使用到的api:

线性回归模型预测sklearn.linear_model.LinearRegression.predict

用到的参数:

  • X:输入特征,如果输入是np.array格式,shape必须是(n_sample, n_feature)。

返回:

  • C:预测结果。
# x_test的shape由(10,1)变为(10,)
x_test = x_test.reshape(-1,1)y_test_pred = model.predict(x_test)
print(y_test_pred)

预测结果:

[3.16995192 6.82572115 2.71298077]

画出数据集的散点图和预测直线

x_test = x_test.reshape(-1)plt.scatter(x_test, y_test, color='g', label='test dataset')
plt.scatter(x_train, y_train, color='b',label='train dataset')
plt.plot(np.sort(x_test), y_test_pred[np.argsort(x_test)], color='r', label='linear regression')
plt.legend()
plt.show()

计算评价指标mse

使用到的api:

均方误差sklearn.metrics.mean_squared_error

用到的参数:

  • y_true:真实值(ground truth)。

  • y_pred:预测值。

返回:

  • loss:mse计算结果。
mse = mean_squared_error(y_true=y_test, y_pred=y_test_pred)
print('MSE: {}'.format(mse))
MSE: 0.15383086014546365

查看线性回归模型的系数w和截距b

使用到的api:

回归系数sklearn.linear_model.LinearRegression.coef_

截距项sklearn.linear_model.LinearRegression.intercept_

w, b = model.coef_[0], model.intercept_
print('Weight={0} bias={1}'.format(w, b))
Weight=0.09139423076923077 bias=1.3420673076923069
```

1.3 线性回归的sklearn实现相关推荐

  1. 机器学习之线性回归 (Python SKLearn)

    import numpy as np import pandas as pd import matplotlib.pyplot as plt from matplotlib.font_manager ...

  2. Python中的线性回归:Sklearn与Excel

    内部AI (Inside AI) Around 13 years ago, Scikit-learn development started as a part of Google Summer of ...

  3. 使用线性回归识别sklearn中的手写数字digit

    从昨天晚上,到今天上午12点半左右吧,一直在调这个代码.最开始训练的时候,老是说loss:nan 查了资料,因为是如果损失函数使用交叉熵,如果预测值为0或负数,求log的时候会出错.需要对预测结果进行 ...

  4. python预测模型_Python多元线性回归-sklearn.linear_model,并对其预测结果评估

    在前面的博客已经介绍过多元回归模型,现在这里粗略介绍如下 python 实现案例 1.选取数据 执行代码#!usr/bin/env python#_*_ coding:utf-8 _*_import  ...

  5. 线性回归原理与spark/sklearn实现

    线性回归原理与spark/sklearn实现 @(SPARK)[spark, ML] 一.算法原理 1.线程回归与逻辑回归的区别 线性回归是一种很直观的数值拟合方式,它认为目标变量和属性值之间存在线性 ...

  6. sklearn 中的线性回归、岭回归、Lasso回归参数配置及示例

    文章目录 线性回归 引入 重要参数 重要属性 重要方法 例子 岭回归 引入 重要参数 重要属性 重要方法 示例 Lasso 回归 引入 重要参数 重要属性 重要方法 示例 本文主要讲一些sklearn ...

  7. 【sklearn】线性回归、最小二乘法、岭回归、Lasso回归

    文章目录 机器学习的sklearn库 一.回归分析 <1.1>线性回归 1.1.1.Python实现线性回归 <1.2>最小二乘法 1.2.1.MATLAB实现最小二乘法 1. ...

  8. 机器学习算法——线性回归的详细介绍 及 利用sklearn包实现线性回归模型

    目录 1.线性回归简介 1.1 线性回归应用场景 1.2 什么是线性回归 1.2.1 定义与公式 1.2.2 线性回归的特征与目标的关系分析 2.线性回归api初步使用 2.1 线性回归API 2.2 ...

  9. sklearn实现一元线性回归

    sklearn实现一元线性回归 导入sklearn以及相关库 from sklearn.linear_model import LinearRegression import numpy as np ...

最新文章

  1. CGIC简明教程(转摘)
  2. JavaScript深入理解对象方法——Object.entries()
  3. 【转】GPS定位原理
  4. SVN下最高效打基线方法
  5. 大数据视域下网络涉军舆情管控研究
  6. 解决vue router使用 history 模式刷新后404问题
  7. C/C++语言函数参数里的“...”作用,va_list的使用(stdarg.h)
  8. 基于JAVA+Servlet+JSP+MYSQL的教师教学评价系统
  9. 银监会计算机知识点,2015国考银监会计算机专业考试分析
  10. Linux进程管理及作业控制(转)
  11. KITTI激光雷达点云解析与图像反投影
  12. TCP的短链接和长连接
  13. jQuery print 去掉页眉页脚
  14. provision文件路径
  15. 第一章 80C51单片机概述
  16. 分布式系统高可用实战之限流器(Go 版本实现)
  17. IntelliJ IDEA更换代码字体为Consolas
  18. 七年级计算机教案部编,七年级信息技术教案新部编本.docx
  19. 你做一篇微信公众号文章要多久?
  20. linux备份mysql文件并恢复的脚本,以及其中出现的错误:ERROR: ASCII '\0' appeared in the statement...

热门文章

  1. OpenGL基础52:阴影映射(上)
  2. 2018年全国多校算法寒假训练营练习比赛(第一场)D. N阶汉诺塔变形(找规律)
  3. 2017CCPC哈尔滨 M:Geometry Problem(随机)
  4. HDU 5976 2016ICPC大连 F: Detachment(找规律)
  5. matlab2c使用c++实现matlab函数系列教程- poly函数
  6. python实现将文件夹/子文件夹中内容清空
  7. (Electronic WorkBench)EWB仿真JK触发器
  8. 数电渣渣的一点学习感想(更新中)
  9. Jumpserver安装和总结
  10. CognitiveJ一个Java的人脸图像识别开源分析库