目录

1.简介

你可以使用线性模型来拟合非线性数据。
一个简单的方法就是将每个特征的幂次方添加为一个新特征,然后在此扩展特征集上训练一个线性模型。这种技术称为多项式回归

2.举例

y=0.5x^2+x+2,加一些噪声,生成一些非线性数据

import numpy as np
import numpy.random as rnd
import matplotlib.pyplot as pltnp.random.seed(42)m = 100
X = 6 * np.random.rand(m, 1) - 3
y = 0.5 * X**2 + X + 2 + np.random.randn(m, 1)plt.plot(X, y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.axis([-3, 3, 0, 10])
plt.show()

显然,一条直线永远无法正确地拟合此数据。 

让我们使用Scikit-Learn的PolynomialFeatures类来转换训练数据,将训练集中每个特征的平方添加为新特征

from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeaturespoly_features = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly_features.fit_transform(X) # X_poly现在包含X的原始特征以及该特征的平方lin_reg = LinearRegression()
lin_reg.fit(X_poly, y)
lin_reg.coef_, lin_reg.intercept_
# 输出:(array([[0.93366893, 0.56456263]]), array([1.78134581]))

模型估算结果:y=0.56x^2 + 0.93x + 1.78而实际上原始函数:y=0.5x^2+x+2+高斯噪声
我们画下图看看 

# 画预测曲线,用np.linspace以均匀步长生成数字序列,这里,-3到3分成100份
# 知道间隔用np.arrange生成序列,不知道间隔,只知道要分多少份用np.linspace
X_new=np.linspace(-3, 3, 100).reshape(100, 1)
X_new_poly = poly_features.transform(X_new)
y_new = lin_reg.predict(X_new_poly)plt.plot(X, y, "b.")
plt.plot(X_new, y_new, "r-", linewidth=2, label="Predictions")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend(loc="upper left", fontsize=14)
plt.axis([-3, 3, 0, 10])
plt.show()

当存在多个特征时,多项式回归能够找到特征之间的关系。
PolynomialFeatures还可以将特征的所有组合添加到给定的多项式阶数。
例如,如果有两个特征a和b,则degree=3的PolynomialFeatures不仅会添加特征a^2、a^3、b^2和b^3,还会添加组合ab、a^2b和ab^2。

3.学习曲线

①不同阶数的多项式回归 

from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipelinefor style, width, degree in (("g-", 1, 300), ("b--", 2, 2), ("r-+", 2, 1)):polybig_features = PolynomialFeatures(degree=degree, include_bias=False)std_scaler = StandardScaler()lin_reg = LinearRegression()polynomial_regression = Pipeline([("poly_features", polybig_features),("std_scaler", std_scaler),("lin_reg", lin_reg),])polynomial_regression.fit(X, y)y_newbig = polynomial_regression.predict(X_new)plt.plot(X_new, y_newbig, style, label=str(degree), linewidth=width)plt.plot(X, y, "b.", linewidth=3)
plt.legend(loc="upper left")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.axis([-3, 3, 0, 10])
plt.show()

可见,高阶多项式回归模型严重过拟合训练数据,而线性模型则欠拟合,最能泛化的模型是二次模型。但是一开始你不知道数据由什么函数生成,那么如何确定模型的复杂性呢?你如何判断模型是过拟合数据还是欠拟合数据呢?

之前我们使用交叉验证来估计模型的泛化性能。如果模型在训练数据上表现良好,但根据交叉验证的指标泛化较差,则你的模型过拟合。如果两者的表现均不理想,则说明欠拟合。
还有一种方法是观察学习曲线:这个曲线绘制的是模型在训练集验证集上关于训练集大小的性能函数。要生成这个曲线,只需要在不同大小的训练子集上多次训练模型。

②普通线性回归模型的学习曲线

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_splitdef plot_learning_curves(model, X, y):X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=10)train_errors, val_errors = [], []for m in range(1, len(X_train) + 1):model.fit(X_train[:m], y_train[:m])y_train_predict = model.predict(X_train[:m])y_val_predict = model.predict(X_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend(loc="upper right", fontsize=14)plt.xlabel("Training set size", fontsize=14)plt.ylabel("RMSE", fontsize=14)lin_reg = LinearRegression()
plot_learning_curves(lin_reg, X, y)
plt.axis([0, 80, 0, 3])
plt.show()

这种模型是欠拟合的:当训练集中只有一个或两个实例时,模型可以很好地拟合它们,随着将新实例添加到训练集中,它开始学习,验证错误逐渐降低,但是,直线不能很好地对数据进行建模,因此误差最终达到一个平稳的状态。
欠拟合的模型,添加更多训练示例将无济于事,你需要使用更复杂的模型或提供更好的特征。

③相同数据上的10阶多项式模型的学习曲线

from sklearn.pipeline import Pipelinepolynomial_regression = Pipeline([("poly_features", PolynomialFeatures(degree=10, include_bias=False)),("lin_reg", LinearRegression()),
])plot_learning_curves(polynomial_regression, X, y)
plt.axis([0, 80, 0, 3])
plt.show()

与线性回归模型相比,训练数据上的误差要低得多
该模型在训练集上的性能要比在验证集上的好得多,说明模型过拟合
改善过拟合模型的一种方法是:向其提供更多的训练数据

多项式回归、学习曲线相关推荐

  1. Python机器学习:多项式回归与模型泛化005学习曲线

    KNN和多项式回归不适合上面这样绘制这样,决策树适合绘制 学习曲线 CODE: #数据 import numpy as np import matplotlib.pyplot as plt x = n ...

  2. 吴恩达机器学习课后作业——偏差和方差

    1.写在前面 吴恩达机器学习的课后作业及数据可以在coursera平台上进行下载,只要注册一下就可以添加课程了.所以这里就不写题目和数据了,有需要的小伙伴自行去下载就可以了. 作业及数据下载网址:吴恩 ...

  3. 《机器学习实战:基于Scikit-Learn、Keras和TensorFlow(第2版)》学习笔记

    文章目录 书籍信息 技术和工具 Scikit-Learn TensorFlow Keras Jupyter notebook 资源 书籍配套资料 流行的开放数据存储库 元门户站点(它们会列出开放的数据 ...

  4. 第8章 多项式回归与模型泛化

    问题:线性回归要求假设我们的数据背后存在线性关系: , 如果将x的平方理解成一个特征,x理解成另一个特征:本来只有一个特征x,现在看成有两个特征的数据集,多了一个特征,就是x的平方,其实式子本身依然是 ...

  5. 十二、案例:加利福尼亚房屋价值数据集(多元线性回归) Lasso 岭回归 分箱处理非线性问题 多项式回归

    案例:加利福尼亚房屋价值数据集(线性回归)& Lasso & 岭回归 & 分箱处理非线性问题 点击标题即可获取文章源代码和笔记 1. 导入需要的模块和库 from sklear ...

  6. 机器学习-Sklearn-13(回归类大家族-下——非线性问题:多项式回归(多项式变换后形成新特征矩阵))

    机器学习-Sklearn-13(回归类大家族-下--非线性问题:多项式回归(多项式变换后形成新特征矩阵)) 5 非线性问题:多项式回归 5.1 重塑我们心中的"线性"概念 在机器学 ...

  7. 机器学习----多项式回归

    多项式回归简介 考虑下面的数据,虽然我们可以使用线性回归来拟合这些数据,但是这些数据更像是一条二次曲线,相应的方程是y=ax2+bx+c,这是式子虽然可以理解为二次方程,但是我们呢可以从另外一个角度来 ...

  8. 10、欠或过拟合的学习曲线,运用验证集选取正则化的L值

    ''' 在本练习中,您将实现正则化的线性回归和多项式回归,并使用它来研究具有不同偏差-方差属性的模型. 在前半部分的练习中,你将实现正则化线性回归,以预测水库中的水位变化,从而预测大坝流出的水量. 在 ...

  9. learning_curve(学习曲线)

    文章目录 一.学习曲线 1.1 学习曲线作用 1.2 学习曲线 1.3 表现能力 1.4 学习曲线的具体操作 二.实例 2.1 模拟数据集 2.2 绘制数据集 2.3 分割数据集 2.4 使用线性回归 ...

最新文章

  1. Nat. Commun. | 多层生物分子网络的鲁棒性研究
  2. 一篇博客带你轻松应对java面试中的多线程与高并发
  3. javaweb学习总结(三十)——EL函数库
  4. scjp考试准备 - 1 - 循环控制
  5. 浅谈BP神经网络的Matlab实现
  6. 产品研发过程管理专题——基于产品的测试管理(用友软件测试流程初探)
  7. jmoiron sqlx mysql_mysql 一(或其他数据库)
  8. 防火墙状态检测及会话表技术
  9. HDOJ试水心酸总结
  10. 厦门宏发有机器人_2020新版福建省厦门工业机器人工商企业公司名录名单黄页大全23家...
  11. ​LeetCode刷题实战248:中心对称数III
  12. 解决VUE打印时多一页空白页的问题
  13. python 象棋 ai 入门教程-用turtle画中国象棋棋盘
  14. zjb_integrated 的BLOG(学习DaVinci的好文章)
  15. VMware SDS之11: VMware SPBM之DELL SC(也即Compellent)篇
  16. js 事件回调函数的对象属性说明:clientX、screenX、offsetX、pageX
  17. 实例讲解基于 React+Redux 的前端开发流程
  18. 休闲游戏合成植物打僵尸源码-H5+安卓+IOS三端源码
  19. App ID申请(将项目中的ID向苹果申请)
  20. 学计算机cpu重要还是显卡重要,显卡处理器和内存 吃鸡时哪个最重要?

热门文章

  1. 基于Unity3D经典消消乐游戏源码,代码详细注释,c#版方块消消乐源代码
  2. java快捷键,补全
  3. RTP载荷PS流全面分析
  4. Windows启动原理
  5. mysql 备份 第三方工具_目前主流的数据库备份第三方工具都有哪些比较好用
  6. 最后一战——回顾 NOIP 2021
  7. linux win10自带浏览器,win10系统下如何安装opera浏览器
  8. 怎么合并多个PDF文件?看完这篇你就会了
  9. java反编译教程_Java反编译工具 - JD-GUI 下载地址及使用手册
  10. 在线直线度测量方法的研发方向