非线性回归

目标

  1. 区分线性回归和非线性回归
  2. 用py实现非线性回归

如果数据表现出一个曲线的趋势,那么相比于非线性回归,线性回归就不会产生一个非常精确的结果,因为线性回归假设数据是线性的。就让我们通过一个例子学习一下非线性回归。在这篇博客中我们对中国1960年到2014年的GDP拟合了一个非线性模型。

导入相关库

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

虽然线性回归在有的数据集上表现的很好,但是它不能应用到所有数据集。首先我们回想一下线性回归,它是拟合因变量y和自变量x,方程是很简单的,比如y=2x+3y = 2x + 3y=2x+3。

下面用代码来看看。

x = np.arange(-5.0, 5.0, 0.1)##You can adjust the slope and intercept to verify the changes in the graph
y = 2*(x) + 3
y_noise = 2 * np.random.normal(size=x.size)
ydata = y + y_noise
#plt.figure(figsize=(8,6))
plt.plot(x, ydata,  'bo')
plt.plot(x,y, 'r')
plt.ylabel('Dependent Variable')
plt.xlabel('Independent Variable')
plt.show()

非线性回归是用一种拟合自变量 xxx 和因变量yyy之间的非线性关系的一种方法。

比如下面是一个多项式。

y=ax3+bx2+cx+d\ y = a x^3 + b x^2 + c x + d \  y=ax3+bx2+cx+d

非线性函数有指数、对数、分数等元素

y=log⁡(x)y = \log(x)y=log(x)

也可以是这种复杂的形式
y=log⁡(ax3+bx2+cx+d)y = \log(a x^3 + b x^2 + c x + d)y=log(ax3+bx2+cx+d)

让我们看看这个三次函数

x = np.arange(-5.0, 5.0, 0.1)##You can adjust the slope and intercept to verify the changes in the graph
y = 1*(x**3) + 1*(x**2) + 1*x + 3
# 加点噪音,也就是随机数
y_noise = 20 * np.random.normal(size=x.size)
ydata = y + y_noise
plt.plot(x, ydata,  'bo')
plt.plot(x,y, 'r')
plt.ylabel('Dependent Variable')
plt.xlabel('Independent Variable')
plt.show()

还有一些其他的形式

二次函数

Y=X2Y = X^2 Y=X2

x = np.arange(-5.0, 5.0, 0.1)##You can adjust the slope and intercept to verify the changes in the graphy = np.power(x,2)
y_noise = 2 * np.random.normal(size=x.size)
ydata = y + y_noise
plt.plot(x, ydata,  'bo')
plt.plot(x,y, 'r')
plt.ylabel('Dependent Variable')
plt.xlabel('Independent Variable')
plt.show()

指数函数

Y=a+bcXY = a + b c^XY=a+bcX其中b≠0,c>0,c≠1其中b ≠0, c > 0 , c ≠1其中b​=0,c>0,c​=1

X = np.arange(-5.0, 5.0, 0.1)##You can adjust the slope and intercept to verify the changes in the graphY= np.exp(X)plt.plot(X,Y)
plt.ylabel('Dependent Variable')
plt.xlabel('Independent Variable')
plt.show()

对数

y=log⁡(x)y = \log(x)y=log(x)

X = np.arange(-5.0, 5.0, 0.1)Y = np.log(X)plt.plot(X,Y)
plt.ylabel('Dependent Variable')
plt.xlabel('Independent Variable')
plt.show()

Sigmoidal/Logistic

Y=a+b1+c(X−d)Y = a + \frac{b}{1+ c^{(X-d)}}Y=a+1+c(X−d)b​

X = np.arange(-5.0, 5.0, 0.1)Y = 1-4/(1+np.power(3, X-2))plt.plot(X,Y)
plt.ylabel('Dependent Variable')
plt.xlabel('Independent Variable')
plt.show()

非线性回归的例子

我们将要拟合中国从1960年到2014年的GDP数据。我们下载的数据有两列,第一列是年份,从1960到2014,第二列是对应年份的国内生产总值(美元)。

gdp数据(点我下载(❁´◡`❁))

import numpy as np
import pandas as pd
df = pd.read_csv("china_gdp.csv")
df.head(10)
Year Value
0 1960 5.918412e+10
1 1961 4.955705e+10
2 1962 4.668518e+10
3 1963 5.009730e+10
4 1964 5.906225e+10
5 1965 6.970915e+10
6 1966 7.587943e+10
7 1967 7.205703e+10
8 1968 6.999350e+10
9 1969 7.871882e+10

数据可视化

数据有点像logistic或者指数函数,一开始增长得特别慢,从2005年开始,增长速度就非常显著了,在2010年代略微减速。

plt.figure(figsize=(8,5))
x_data, y_data = (df["Year"].values, df["Value"].values)
plt.plot(x_data, y_data, 'ro')
# plt.stem(x_data, y_data)
plt.ylabel('GDP')
plt.xlabel('Year')
plt.show()

选择模型

从第一眼看这个散点图,我就感觉logistic函数会不错,因为一开始增长很慢、中间增长很快、最后又慢了下来

就像下面这样:

X = np.arange(-5.0, 5.0, 0.1)
Y = 1.0 / (1.0 + np.exp(-X))plt.plot(X,Y)
plt.ylabel('Dependent Variable')
plt.xlabel('Independent Variable')
plt.show()

logsitic函数的方程如下

Y^=11+e−β_1(X−β_2)\hat{Y} = \frac1{1+e^{-\beta\_1(X-\beta\_2)}}Y^=1+e−β_1(X−β_2)1​

β_1\beta\_1β_1: 控制曲线的陡度,

β_2\beta\_2β_2: x轴上平移

构建模型

现在,让我们构建我们的回归模型并且初始化参数

def sigmoid(x, Beta_1, Beta_2):y = 1 / (1 + np.exp(-Beta_1*(x-Beta_2)))return y

首先随便搞两个参数

beta_1 = 0.10
beta_2 = 1990.0#logistic function
Y_pred = sigmoid(x_data, beta_1 , beta_2)#plot initial prediction against datapoints
plt.plot(x_data, Y_pred*15000000000000.)
plt.plot(x_data, y_data, 'ro')

我们的目标是找到最好的参数。

第一步把x和y都标准化一下。

# Lets normalize our data
xdata =x_data/max(x_data)
ydata =y_data/max(y_data)

如何找到拟合曲线最好的参数?

我们可以使用curve_fit,它使用非线性最小二乘来拟合我们的sigmoid函数。 优化参数值,使sigmoid(xdata, *popt) - ydata的残差平方和最小化。

Popt是我们的优化参数。

from scipy.optimize import curve_fit
popt, pcov = curve_fit(sigmoid, xdata, ydata)
#print the final parameters
print(" beta_1 = %f, beta_2 = %f" % (popt[0], popt[1]))
 beta_1 = 690.451711, beta_2 = 0.997207

现在画一下我们的回归模型。

x = np.linspace(1960, 2015, 55)
x = x/max(x)
plt.figure(figsize=(8,5))
y = sigmoid(x, *popt)
plt.plot(xdata, ydata, 'ro', label='data')
plt.plot(x,y, linewidth=3.0, label='fit')
plt.legend(loc='best')
plt.ylabel('GDP')
plt.xlabel('Year')
plt.show()

评估模型

虽然看上去不错,但是在运行过程中R2R^2R2有时竟是负的,而且就是R2R^2R2很大的时候在测试集上实际效果也不是很好,所以还是不很靠谱。

# 把数据分为训练集和测试集
msk=np.random.rand(len(df))<0.8
# print(msk)
train_x=xdata[msk]
test_x=xdata[~msk]
train_y=ydata[msk]
test_y=ydata[~msk]# 用训练集建立一个模型
popt,pcov=curve_fit(sigmoid,train_x,train_y)yyy = sigmoid(train_x, *popt)
plt.plot(train_x, train_y, 'ro', label='data')
plt.plot(train_x,yyy, linewidth=3.0, label='fit')
plt.plot(test_x, test_y, 'go', label='data')
# 在测试集上预测
y_hat=sigmoid(test_x,*popt)
print("test_x:",test_x*sum(df["Year"]))
print("test_y:",test_y*sum(df['Value']))
print("y_hat:",y_hat*sum(df['Value']))
# 评估
print("Mean absolute error: %.2f" % np.mean(np.absolute(y_hat - test_y)))
print("Residual sum of squares (MSE): %.2f" % np.mean((y_hat - test_y) ** 2))
from sklearn.metrics import r2_score
print("R2-score: %.2f" % r2_score(test_y,y_hat) )
test_x: [106463.34160874 106788.91757696 107005.96822244 107331.54419067107657.12015889 107765.64548163 108091.22144985 108254.00943396108308.27209533 108579.58540218 108688.11072493 109122.21201589]
test_y: [3.56342870e+11 5.34252726e+11 8.56103610e+11 1.13258445e+121.96991285e+12 2.28075199e+12 3.24347532e+12 5.58752067e+126.57072900e+12 1.01688021e+13 1.25937257e+13 5.71889161e+13]
y_hat: [3.13221434e+06 2.82948697e+07 1.22728310e+08 1.10865316e+091.00138972e+10 2.08527024e+10 1.87974610e+11 5.62290529e+118.08915481e+11 4.80484704e+12 9.38891854e+12 5.66862461e+13]
Mean absolute error: 0.03
Residual sum of squares (MSE): 0.00
R2-score: 0.96

【机器学习】python实现非线性回归(以中国1960-2014GDP为例)相关推荐

  1. python是中国的吗-使用Python实现画一个中国地图

    为什么是Python 先来聊聊为什么做数据分析一定要用Python或R语言.编程语言这么多种,Java, PHP都很成熟,但是为什么在最近热火的数据分析领域,很多人选择用Python语言? 数据分析只 ...

  2. 几行代码搞定ML模型,低代码机器学习Python库正式开源

    公众号关注 "视学算法" 设为 "星标",消息即可送达! 机器之心报道 机器之心编辑部 PyCaret 库支持在「低代码」环境中训练和部署有监督以及无监督的机器 ...

  3. python第三章上机实践_《机器学习Python实践》读书笔记-第三章

    <机器学习Python实践>,第三章,第一个机器学习项目 以往目录:橘猫吃不胖:<机器学习Python实践>读书笔记-第一章​zhuanlan.zhihu.com 书中介绍了一 ...

  4. anaconda3卸载python_机器学习Python编程环境:VSCode+Anaconda

    机器学习Python编程环境(Windows):VSCode+Anaconda 安装顺序:Anaconda ->VSCode (不必下载Python)->机器学习常用Python包 为什么 ...

  5. python的开发环境有哪些系统_Win10下配置机器学习python开发环境

    近期计划写一写机器学习微信小程序的开发教程,但微信开发工具只提供了Windows和Mac OS版本,作为一名长期使用Linux系统的开发人员,虽然始终认为Linux系统才是对开发者最友好的,但微信团队 ...

  6. python灰色模型代码_几行代码搞定ML模型,低代码机器学习Python库正式开源

    机器之心报道 机器之心编辑部 PyCaret 库支持在「低代码」环境中训练和部署有监督以及无监督的机器学习模型,提升机器学习实验的效率. 想提高机器学习实验的效率,把更多精力放在解决业务问题而不是写代 ...

  7. python低代码_几行代码搞定ML模型,低代码机器学习Python库正式开源

    PyCaret 库支持在「低代码」环境中训练和部署有监督以及无监督的机器学习模型,提升机器学习实验的效率. 想提高机器学习实验的效率,把更多精力放在解决业务问题而不是写代码上?低代码平台或许是个不错的 ...

  8. 代码实现tan graph model for classification_几行代码搞定ML模型,低代码机器学习Python库正式开源...

    PyCaret 库支持在「低代码」环境中训练和部署有监督以及无监督的机器学习模型,提升机器学习实验的效率. 想提高机器学习实验的效率,把更多精力放在解决业务问题而不是写代码上?低代码平台或许是个不错的 ...

  9. 机器学习 python 库_Python机器学习库

    机器学习 python 库 什么是机器学习? (What is Machine Learning?) As the web is immensely growing with each day, an ...

  10. 教你怎么用Python和Qt5编写中国象棋AI版——规则模块

    提示:该模块用于实现规则模块 教你怎么用Python和Qt5编写中国象棋AI版--规则模块 前言 一.中国象棋大致规则? 二.各棋子规则实现思路 1.兵 注意事项 过河兵合法偏移 未过河兵合法偏移 2 ...

最新文章

  1. Openstack组建部署 — Environment of Controller Node
  2. UA OPTI570 量子力学17 创生算符与湮灭算符
  3. Java Class Loader Retrospect
  4. Lua require 相对路径
  5. 面向对象的Oracle用法
  6. 数据结构c语言严4pdf,数据结构(C语言)严蔚敏 吴伟明 编著 04.pdf
  7. 大数据最核心的价值是什么?
  8. excel文件的工作表保护密码忘记了
  9. 5角星画法 android,Android Canvas绘制正多边形和正多角星
  10. 数据猿专访诸葛io孔淼:数据与业务“动态”结合才能发挥最大威力
  11. 苹果手机电池怎么保养_手机电池损耗检测,电池修复软件
  12. HTML 几种特别分割线特效 详细出处参考:http://www.jb51.net/web/28414.html
  13. vue项目实现大屏展示 自适应问题
  14. 计算机主板上一般带有高速缓冲存储器cache,它是与什么之间的缓存,计算机微机原理与应用(一)...
  15. 关于 simulink 的 1/z 模块是什么的问题
  16. 多线程 - 三种实现
  17. 初学C语言常见的错误
  18. python浮点数比较大小_浮点数的相等比较
  19. 计算机cpu风扇的结构,“电脑专家”教你如何拆cpu风扇【图文教程】
  20. Educoder计算机数据表示实验(HUST)-汉字国标码转区位码实验偶校验编码设计logisim

热门文章

  1. cad延伸命令怎么用_CAD缩放怎么用,CAD缩放图文教程
  2. 12张大数据图看看2016年世界各地发生大事件!
  3. ANC降噪耳机量产测试方案
  4. R语言:判断身份证号码真伪的函数编写
  5. Vue/ElementUI上传文件检验
  6. 发票识别 表格票据识别
  7. linux源代码安装apr,linux APR安装 APR-UTIL 安装 源码安装
  8. 计算机图片怎么截图快捷键,电脑系统截图快捷键(电脑怎么截图)
  9. java中,HashMap为什么每次扩容的倍数是2,而不是1.5或者2.5?
  10. 【USACO】 录制唱片