KNN和多项式回归不适合上面这样绘制这样,决策树适合绘制

学习曲线




CODE:

#数据
import numpy as np
import matplotlib.pyplot as plt
x = np.random.uniform(-3,3,size=100)
#在最新版本的sklearn中,所有的数据都应该是二维矩阵,哪怕它只是单独一行或一列。
X = x.reshape(-1,1)
y = 0.5 * x ** 2 + x + 2 +np.random.normal(0,1,size=100)
plt.scatter(x,y)#非线性关系
#print(X)


tranin_test_split

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=10)
print(X_train.shape)
(75, 1)

绘制学习曲线

#学习曲线
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
train_score = []
test_score = []for i in range(1,76):lin_reg = LinearRegression()lin_reg.fit(X_train[:i],y_train[:i])y_train_predict = lin_reg.predict(X_train[:i])train_score.append(mean_squared_error(y_train[:i],y_train_predict))y_test_predict = lin_reg.predict(X_test)test_score.append(mean_squared_error(y_test,y_test_predict))

训练数据集误差逐渐累积。。到了一定程度拟合稳定
测试数据集 逐渐减小 然后 稳定
测试误差还是比 训练误差小一点

#训练数据集误差逐渐累积。。到了一定程度拟合稳定
#测试数据集 逐渐减小 然后 稳定
#测试误差还是比 训练误差小一点
plt.plot([i for i in range(1,76)],np.sqrt(train_score),label = "train")
plt.plot([i for i in range(1,76)],np.sqrt(test_score),label = "test")
plt.legend()


封装绘制学习曲线代码


#绘制学习曲线代码
def plot_learning_curve(algo,X_train,X_test,y_train,y_test):train_score = []test_score = []for i in range(1,len(X_train)+1):algo.fit(X_train[:i],y_train[:i])y_train_predict = algo.predict(X_train[:i])train_score.append(mean_squared_error(y_train[:i],y_train_predict))y_test_predict = algo.predict(X_test)test_score.append(mean_squared_error(y_test,y_test_predict))plt.plot([i for i in range(1,len(X_train)+1)],np.sqrt(train_score),label = "train")plt.plot([i for i in range(1,len(X_train)+1)],np.sqrt(test_score),label = "test")plt.legend()plt.axis([0,len(X_train)+1,0,4])plt.show()

欠拟合

plot_learning_curve(LinearRegression(),X_train,X_test,y_train,y_test)

#多项式回归学习曲线
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PolynomialFeatures
def PolynomialRegression(degree):return Pipeline([('poly',PolynomialFeatures(degree=degree)),('std_scaler',StandardScaler()),('lin_reg',LinearRegression())])

拟合

poly2_reg = PolynomialRegression(degree=2)
plot_learning_curve(poly2_reg,X_train,X_test,y_train,y_test)
#线性学习曲线误差稳定在1.5 多项式degree = 2时 误差稳定在1.2

过拟合

poly20_reg = PolynomialRegression(degree=20)
plot_learning_curve(poly20_reg,X_train,X_test,y_train,y_test)

Python机器学习:多项式回归与模型泛化005学习曲线相关推荐

  1. 05机器学习--多项式回归与模型泛化及python实现

    目录 ①什么是多项式回归 ②scikit-learn中的多项式回归和Pipelin ③过拟合与欠拟合 ④验证数据集与交叉验证 ⑤回顾网格搜索 ⑥偏差方差权衡 ⑦解决过拟合问题--模型正则化1--岭回归 ...

  2. Python机器学习:多项式回归与模型泛化004为什么需要训练数据集和测试数据集

    泛化能力:由此及彼能力 遇见新的拟合能力差 数据 #数据 import numpy as np import matplotlib.pyplot as plt x = np.random.unifor ...

  3. 多项式回归与模型泛化

    1.多项式回归 线性回归的局限性是只能应用于存在线性关系的数据中,但是在实际生活中,很多数据之间是非线性关系,虽然也可以用线性回归拟合非线性回归,但是效果会变差,这时候就需要对线性回归模型进行改进,使 ...

  4. python机器学习 | 多项式回归和拟合

    多项式回归和拟合.正则化 1 多项式回归 1.1 介绍 1.2 回归实现 2 拟合&正则化 2.1拟合问题 2.1.1 拟合出现的类型 2.2 解决拟合出现的问题 3 正则化 3.1 介绍 3 ...

  5. Python机器学习:多项式回归与模型泛化008模型泛化与岭回归

    岭回归 数据 #数据 import numpy as np import matplotlib.pyplot as plt np.random.seed(42) x = np.random.unifo ...

  6. Python机器学习:多项式回归与模型泛化007偏差方差平衡

    用名字预测成绩则会高偏差.. 高方差,泛化能力差!

  7. Python机器学习:多项式回归与模型泛化003过拟合与欠拟合

    过拟合欠拟合 #过拟合和欠拟合 import numpy as np import matplotlib.pyplot as plt x = np.random.uniform(-3,3,size=1 ...

  8. Python机器学习:多项式回归与模型泛化010L1L2和弹性网络

  9. Python机器学习:多项式回归与模型泛化009LASSO回归

    代码接着008 Lasso #LASSO from sklearn.linear_model import Lasso def LassoRegression(degree,alpha):return ...

最新文章

  1. 忽悠神经网络指南:教你如何把深度学习模型骗得七荤八素
  2. css3 自定义滚动条样式
  3. 【MySQL】深入浅出剖析mysql事务锁机制 - 笔记
  4. 发布dotNetCore程序到Kubernetes
  5. 三极管工作原理_三极管的基本工作原理,这个讲的很全
  6. 四款主流测试工具的测试流程
  7. python tensorflow 智能家居_用GPU加速深度学习: Windows安装CUDA+TensorFlow教程
  8. daemontools安装和使用
  9. 万年历c语言代码单链表,万年历的C语言实现
  10. C4D常用操作——挤压+倒角详解
  11. 两阶段最小二乘法原理_R语言工具变量与两阶段最小二乘法
  12. QPS,TPS,RPS你知道多少?
  13. 抖音 Android 性能优化系列:启动优化之理论和工具篇
  14. Tk/Tkx滚动条的使用
  15. 《最强大脑:魔方墙找茬王郑才千的学神秘笈-郑才千》-读书笔记
  16. [渝粤教育] 西安建筑科技大学 技术经济学 参考 资料
  17. 电压源和电流的关联参考方向_在大学《电路原理》中,电流源和电压源如何判断关联参考方向和非关联参考方向?...
  18. Guava 指南 之「前置条件」
  19. java实现mysql拦截_java分页拦截类实现sql自动分页
  20. 傻白探索Chiplet,关于EPYC Zen2 的一些理解记录(五)

热门文章

  1. java jsp js xml,JSP语法的xml写法
  2. datefromstring 转换不准确_免费的在线OCR工具,将图片内容转换为文本内容
  3. 繁体字_如何简单快速地批量认识繁体字?
  4. 聚焦核心竞争力:自建与外购
  5. 从helloworld回顾程序的编译过程之一
  6. C++开源矩阵计算工具——Eigen的简单用法(二)
  7. 屏幕空间环境光遮蔽(SSAO)算法的实现
  8. 程序员表白代码python_程序员python表白代码
  9. How GPUs Work
  10. leetcode695:DFS 岛屿最大面积(C语言)