# 线性回归、Ridge、LASSO、ElasticNet回归
import numpy as np
import pandas as pd
import warnings
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.linear_model import LinearRegression, RidgeCV, LassoCV, ElasticNetCV# 设置随机数种子
np.random.seed(0)# 构造数据
def makedata():N = 9  # 9个点x = np.linspace(0, 6, N) + np.random.randn(N)  # 0~6等间隔数加上高斯噪声x = np.sort(x)  # 排序y = x ** 2 - 4 * x - 3 + np.random.randn(N)# 设置成列向量x = x.reshape(-1, 1)y = y.reshape(-1, 1)return x, y# 计算xss
def xss(y, y_hat):# 将数组展平y = y.ravel()y_hat = y_hat.ravel()# R2tss = np.sum(np.power(y - np.average(y, axis=0), 2))rss = np.sum(np.power(y_hat - y, 2))ess = np.sum(np.power(y_hat - np.average(y, axis=0), 2))R2 = (tss - rss) / tss# 添加到列表tss_list.append(tss)rss_list.append(rss)ess_list.append(ess)ess_rss_list.append(rss + ess)# 皮尔逊相关系数corr_coef = np.corrcoef(y, y_hat)[0, 1]return R2, corr_coefif __name__ == '__main__':# 消除警告warnings.filterwarnings(action='ignore')# 设置输出样式——精度(不用科学计数法,用小数点来显示)、显示宽度np.set_printoptions(suppress=True, linewidth=1000)# 获取数据N = 9X, y = makedata()# 模型models = [Pipeline([('poly', PolynomialFeatures()),('linear', LinearRegression(fit_intercept=False))]),Pipeline([('poly', PolynomialFeatures()),('linear', RidgeCV(alphas=np.logspace(-3, 2, 10), normalize=[True, False]))]),Pipeline([('poly', PolynomialFeatures()),('linear', LassoCV(alphas=np.logspace(-3, 2, 10), normalize=[True, False]))]),Pipeline([('poly', PolynomialFeatures()),('linear', ElasticNetCV(alphas=np.logspace(-3, 2, 10), l1_ratio=np.linspace(0.1, 1, 10)))])]# 画图mpl.rcParams['font.sans-serif'] = [u'simHei']mpl.rcParams['axes.unicode_minus'] = Falsefig = plt.figure(figsize=(18, 12), facecolor='w')# 阶数的数组d_pool = np.arange(1, 9, 1)m = d_pool.size# 设置渐变色clrs = []for i in np.linspace(16711680, 255, m):c = int(i)clrs.append('#%06x' % c)# 设置线宽line_width = np.linspace(5, 2, m)# 设置标题titles = u'线性回归', u'Ridge回归', u'LASSO回归', u'ElasticNet回归'tss_list = []rss_list = []ess_list = []ess_rss_list = []# 迭代画4个图for t in range(4):model = models[t]plt.subplot(2, 2, t + 1)plt.plot(X, y, 'ro', ms=10,zorder =N)for i, d in enumerate(d_pool):# 设置参数model.set_params(poly__degree=d)# 训练model.fit(X, y.ravel())# 获取参数linear = model.get_params('linesr')['linear']output = '%s : %d阶,系数为' % (titles[t], d)# 判断linear中是否有这个属性if hasattr(linear, 'alpha_'):idx = output.find('系数')output = output[:idx] + ('alpha = %.6f' % linear.alpha_) + output[idx:]if hasattr(linear, 'l1_ratio_'):idx = output.find('系数')output = output[:idx] + ('l1_ratio = %.6f' % linear.l1_ratio_) + output[idx:]print(output, linear.coef_.ravel())# 预测x_hat = np.linspace(X.min(), X.max(), 100).reshape(-1, 1)y_hat = model.predict(x_hat)s = model.score(X, y)r2, corr_coef = xss(y, model.predict(X))# print('r2和相关系数:', r2, corr_coef)if d == 2:z = N - 1else:z = 0label = '%d阶,$R^2$=%.3f'%(d,r2)if hasattr(linear,'l1_ratio_'):label += 'l1_ratio = %.2f'%(linear.l1_ratio_)# 画图plt.plot(x_hat,y_hat,color=clrs[i],label=label,lw = line_width[i],zorder=z)plt.legend(loc='best')plt.grid(True)plt.title(titles[t],fontsize=18)plt.xlabel('X',fontsize=15)plt.ylabel('Y',fontsize=15)plt.tight_layout()plt.suptitle('多项式曲线拟合比较',fontsize=22)plt.show()

TSS >= RSS + ESS

y_max = max(max(tss_list), max(ess_rss_list)) * 1.05plt.figure(facecolor='w', figsize=(9, 7))t = np.arange(len(tss_list))  # 样本编号plt.plot(t, tss_list, 'ro-', lw=2, label='TSS')plt.plot(t, ess_list, 'mo-', lw=1, label='ESS')plt.plot(t, rss_list, 'bo-', lw=1, label='RSS')plt.plot(t, ess_rss_list, 'go-', lw=2, label='ESS + RSS')plt.legend(loc='best')plt.title('TSS >= ESS + RSS')plt.xlabel('样本编号')plt.ylabel('XSS的值', fontsize=15)plt.grid(True)plt.show()

1.5 案例:多项式曲线拟合的比较相关推荐

  1. Apache Commons Math3学习笔记(2) - 多项式曲线拟合(转)

    多项式曲线拟合:org.apache.commons.math3.fitting.PolynomialCurveFitter类. 用法示例代码: [java] view plain copy   // ...

  2. 最小二乘法多项式曲线拟合原理与实现--转

    原文地址:http://blog.csdn.net/jairuschan/article/details/7517773/ 概念 最小二乘法多项式曲线拟合,根据给定的m个点,并不要求这条曲线精确地经过 ...

  3. 今天开始学模式识别与机器学习Pattern Recognition and Machine Learning 书,章节1.1,多项式曲线拟合(Polynomial Curve Fitting)

    转载自:http://blog.csdn.net/xbinworld/article/details/8834155 Pattern Recognition and Machine Learning ...

  4. 多项式曲线拟合最小二乘法

    对给定的试验数据点(xi,yi)(i=1,2,--,n),可以构造m次多项式 数据拟合的最简单的做法就是使误差p(xi)-yi的平方和最小 当前任务就是求一个P(x)使得 从几何意义上讲就是寻求给与定 ...

  5. PRML(1)--绪论(上)多项式曲线拟合、概率论

    PRML绪论 1.1 多项式曲线拟合 1.1.1 问题描述 1.1.2 最小化平方和误差 1.1.3 多项式阶数确定 1.1.4 有趣问题--高阶模型为什么效果不好 1.1.4 数据集规模对模型的影响 ...

  6. Python04 直线拟合 多项式曲线拟合 指数曲线拟合(附代码)

    1. 实验结果 (1)在定义的类中设置已知的函数值列表为: (2)在 test.py 中选择直线拟合: 输出:拟合的直线函数及图像: (3)选择多项式曲线拟合: 输入:多项式拟合函数的次数: 输出:拟 ...

  7. 最小二乘法多项式曲线拟合数学原理及其C++实现

    目录 0 前言 1 最小二乘法概述 2 最小二乘法求解多项式曲线系数向量的数学推导 2.1 代数法 2.2 矩阵法 3 代码实现 4 总结 参考 0 前言 自动驾驶开发中经常涉及到多项式曲线拟合,本文 ...

  8. 多项式曲线拟合之最小二乘法推导

    1.多项式曲线拟合之最小二乘法 1.1 问题来源 1801年,意大利天文学家朱赛普·皮亚齐发现了第一颗小行星谷神星.经过40天的跟踪观测后,由于谷神星运行至太阳背后,使得皮亚齐失去了谷神星的位置.随后 ...

  9. Matlab 多项式曲线拟合polyfit

    polyfit                多项式曲线拟合 常见语法                 a = polyfit ( x, y, n) 说明               a=polyfi ...

  10. PRML 1.1 多项式曲线拟合

    PRML 1.1 多项式曲线拟合 输入 训练集 x ≡ ( x 1 , . . . , x N ) T x\equiv (x_1,...,x_N)^T x≡(x1​,...,xN​)T t ≡ ( t ...

最新文章

  1. 断点续传---多线程下载进阶(一)
  2. 织梦最新版后台一键更新网站、更新文档HTML卡死的解决方法
  3. 自学python需要安装什么软件-零基础入门Python怎么学习?老男孩python用什么软件...
  4. mac怎么用python2和3_Mac同时安装python2和python3
  5. python提取文章中的中文数字
  6. Java SSM篇3——Mybatis
  7. JavaFX 新WebService客户端脚本语言
  8. SpringBoot2.1.5 (24): @SpringBootTest单元测试
  9. 用C++获取屏幕上某点的颜色
  10. 微信小程序富文本三种方法+0.1rich-text+0.2插件+0.3webview
  11. C# List 深复制
  12. c语言中变量后减号大于号,大于等于运算符.ppt
  13. 计算机学科 集体备课记录,信息技术学科组集体备课活动记录
  14. Carhart四因子模型实用攻略
  15. 字节跳动员工晒出税后工资,网友:怀疑你是日薪
  16. Semi-Supervised Semantic Image Segmentation with Self-correcting Networks:基于自校正网络的半监督语义图像分割
  17. 可以修饰的基团有:氨基类,NHBOC类,Fmoc类不等,DSPE-PEG7-Mal
  18. 说说程序员不解风情的瞬间
  19. Unity3D游戏开发案例学习——Tanks!(基本完结)
  20. 雷军演讲刷屏,我对项目经理人的发展又有了2点想法……

热门文章

  1. php学习分享心得吧
  2. PA 2011 Round 3 prz题解
  3. 树组件:主要配置项、属性、方法
  4. 使用SVG绘制湖南地图
  5. LoadRunner去除事物中的程序的执行时间
  6. sdut Message Flood(c++ map)
  7. Java String的内存机制
  8. 使用 jQuery 和 KnockoutJS 开发在线股票走势图应用
  9. [.NET] : 使用自定义对象当作报表数据源
  10. input submit标签的高度和宽度与input text的差异