在某些情况下,结果和预测变量之间的真正关系可能不是线性的。
为了捕捉这些非线性效应,扩展线性回归模型(Chapter @ref(linear-regression))有不同的解决方案,其中包括:

  • Polynomial regression: 这是建立非线性关系的简单方法。它将多项式项或二次项(平方、立方体等)添加到回归中。
  • Spline regression: 用一系列多项式段拟合一条光滑曲线。划分spline段的值称为Knots
  • Generalized additive models (GAM): 拟合自动选择knotsspline模型。

在本章中,您将学习如何计算非线性回归模型以及如何比较不同的模型以选择适合您数据的最佳模型。

RMSE和R2指标将用于比较不同的模型(see Chapter @ref(linear regression)).

  • RMSE代表模型预测误差,这是观察到的结果值和预测结果值的平均差异。
  • R2表示观察到的和预测的结果值之间的平方相关性。

最好的模型是最低RMSE和最高R2的模型

Loading Required R packages

  • tidyverse for easy data manipulation and visualization
  • caret for easy machine learning workflow
library(tidyverse)
library(caret)
theme_set(theme_classic())

Preparing the data

我们将使用Boston数据集[in MASS package], 基于预测变量LSTA (percentage of lower status of the population),用于预测波士顿郊区的房屋价值中值(MDEV)
我们将将数据随机分为训练集(用于构建预测模型的80%)和测试集(评估模型的20%)。确保将种子设置为可重复性。

# Load the data
data("Boston", package = "MASS")
# Split the data into training and test set
set.seed(123)
training.samples <- Boston$medv %>%createDataPartition(p = 0.8, list = FALSE)
train.data  <- Boston[training.samples, ]
test.data <- Boston[-training.samples, ]

首先,可视化MEDV与LSTAT变量的散点图如下:

ggplot(train.data, aes(lstat, medv) ) +geom_point() +stat_smooth()

上面的散点图表明两个变量之间存在非线性关系

Linear regression {linear-reg}

标准线性回归模型方程可以写为MEDV = B0 + B1*LSTAT
计算线性回归模型:

# Build the model
model <- lm(medv ~ lstat, data = train.data)
# Make predictions
predictions <- model %>% predict(test.data)
# Model performance
data.frame(RMSE = RMSE(predictions, test.data$medv),R2 = R2(predictions, test.data$medv)
)
##   RMSE    R2
## 1 6.07 0.535

可视化数据:

ggplot(train.data, aes(lstat, medv) ) +geom_point() +stat_smooth(method = lm, formula = y ~ x)

Polynomial regression

多项式回归在回归方程中添加多项式或二次项,如下:
medv=b0+b1∗lstat+b2∗lstat2medv = b0+b1*lstat+b2*lstat^2 medv=b0+b1∗lstat+b2∗lstat2

在r中,要创建一个预测变量x^2,您应该使用函数I(),如下:I(x^2)。把 x 提高到2的幂次方
多项式回归可以在R中计算如下:

lm(medv ~ lstat + I(lstat^2), data = train.data)

另一种简单的解决方案是使用以下方式:

lm(medv ~ poly(lstat, 2, raw = TRUE), data = train.data)
## Call:
## lm(formula = medv ~ poly(lstat, 2, raw = TRUE), data = train.data)
##
## Coefficients:
##                 (Intercept)  poly(lstat, 2, raw = TRUE)1
##                     42.5736                      -2.2673
## poly(lstat, 2, raw = TRUE)2
##                      0.0412

该输出包含与LSTAT相关的两个系数:一个用于线性项 (lstat1),一个用于二次项(lstat2)。

以下示例计算六阶多项式拟合:

lm(medv ~ poly(lstat, 6, raw = TRUE), data = train.data) %>%summary()
# # Call:
# #   lm(formula = medv ~ poly(lstat, 6, raw = TRUE), data = train.data)
# #
# # Residuals:
# #   Min       1Q   Median       3Q      Max
# # -13.1962  -3.1527  -0.7655   2.0404  26.7661
# #
# # Coefficients:
# #   Estimate Std. Error t value Pr(>|t|)
# # (Intercept)                  7.788e+01  6.844e+00  11.379  < 2e-16 ***
# #   poly(lstat, 6, raw = TRUE)1 -1.767e+01  3.569e+00  -4.952 1.08e-06 ***
# #   poly(lstat, 6, raw = TRUE)2  2.417e+00  6.779e-01   3.566 0.000407 ***
# #   poly(lstat, 6, raw = TRUE)3 -1.761e-01  6.105e-02  -2.885 0.004121 **
# #   poly(lstat, 6, raw = TRUE)4  6.845e-03  2.799e-03   2.446 0.014883 *
# #   poly(lstat, 6, raw = TRUE)5 -1.343e-04  6.290e-05  -2.136 0.033323 *
# #   poly(lstat, 6, raw = TRUE)6  1.047e-06  5.481e-07   1.910 0.056910 .
# # ---
# #   Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
# #
# # Residual standard error: 5.188 on 400 degrees of freedom
# # Multiple R-squared:  0.6845,    Adjusted R-squared:  0.6798
# # F-statistic: 144.6 on 6 and 400 DF,  p-value: < 2.2e-16

从上面的输出可以看出,超出第五阶以上的多项式项并不重要。因此,只需创建第五个多项式回归模型如下:

# Build the model
model <- lm(medv ~ poly(lstat, 5, raw = TRUE), data = train.data)
# Make predictions
predictions <- model %>% predict(test.data)
# Model performance
data.frame(RMSE = RMSE(predictions, test.data$medv),R2 = R2(predictions, test.data$medv)
)
##       RMSE        R2
## 1 5.270374 0.6829474

可视化第五多项式回归线,如下:

ggplot(train.data, aes(lstat, medv) ) +geom_point() +stat_smooth(method = lm, formula = y ~ poly(x, 5, raw = TRUE))

Log transformation

当您有非线性关系时,您也可以尝试对预测变量的对数转换:

# Build the model
model <- lm(medv ~ log(lstat), data = train.data)
# Make predictions
predictions <- model %>% predict(test.data)
# Model performance
data.frame(RMSE = RMSE(predictions, test.data$medv),R2 = R2(predictions, test.data$medv)
)
##      RMSE        R2
## 1 5.467124 0.6570091

可视化数据:

ggplot(train.data, aes(lstat, medv) ) +geom_point() +stat_smooth(method = lm, formula = y ~ log(x))

Spline regression

多项式回归仅在非线性关系中捕获一定数量的曲率。建模非线性关系的一种替代方法是使用splines (P. Bruce and Bruce 2017).
Splines提供一种在固定点之间平稳插值的方法,称为knots。多项式回归是在knots之间计算的。换句话说,splines是一系列多项式段串在一起,加入knots (P. Bruce and Bruce 2017)。

R软件包splines包括用于在回归模型中创建b-spline项的函数bs
您需要指定两个参数:the degree of the polynomialthe location of the knots。在我们的示例中,我们将knots放在下四分位数,中值四分位数和上四分位数。

knots <- quantile(train.data$lstat, p = c(0.25, 0.5, 0.75))

我们将使用立方spline(degree= 3)创建模型:

library(splines)
# Build the model
knots <- quantile(train.data$lstat, p = c(0.25, 0.5, 0.75))
model <- lm (medv ~ bs(lstat, knots = knots), data = train.data)
# Make predictions
predictions <- model %>% predict(test.data)
# Model performance
data.frame(RMSE = RMSE(predictions, test.data$medv),R2 = R2(predictions, test.data$medv)
)
##   RMSE    R2
## 1 4.97 0.688

请注意,spline术语的系数是不可解释的。
将三次spline曲线可视化如下:

ggplot(train.data, aes(lstat, medv) ) +geom_point() +stat_smooth(method = lm, formula = y ~ splines::bs(x, df = 3))

Generalized additive models

一旦您发现数据中的非线性关系,多项式项可能不足以捕获这种关系,并且spline项需要指定knots
Generalized additive models(GAM)是一种自动拟合spline回归的技术。这可以使用mgcv R package:

library(mgcv)
# Build the model
model <- gam(medv ~ s(lstat), data = train.data)
# Make predictions
predictions <- model %>% predict(test.data)
# Model performance
data.frame(RMSE = RMSE(predictions, test.data$medv),R2 = R2(predictions, test.data$medv)
)
##   RMSE    R2
## 1 5.02 0.684

s(lstat) 告诉gam() 函数,以找到spline的“最佳”knots
可视化数据:

ggplot(train.data, aes(lstat, medv) ) +geom_point() +stat_smooth(method = gam, formula = y ~ s(x))

Comparing the models

从分析不同模型的RMSE和R2指标,可以看出,多项式回归,spline回归和generalized additive models 的表现优于线性回归模型和对数转换方法。

reference

  • http://www.sthda.com/english/articles/40-regression-analysis/162-nonlinear-regression-essentials-in-r-polynomial-and-spline-regression-models
  • Bruce, Peter, and Andrew Bruce. 2017. Practical Statistics for Data Scientists. O’Reilly Media.

R语言实现非线性回归相关推荐

  1. R语言之非线性回归xt9.5

    第9章 非线性回归 9.5 表9-13数据中GDP和投资额K都是用定基居民消费价格指数CPI缩减后的值,1978年的价格指数为100.(C-D生产函数y=AKαLβ) (1)用线性化的乘性误差项模型拟 ...

  2. R语言处理非线性回归模型C-D方程,使用R语言进行多项式回归、非线性回归模型曲线拟合...

    对于线性关系,我们可以进行简单的线性回归.对于其他关系,我们可以尝试拟合一条曲线.曲线拟合是构建一条曲线或数学函数的过程,它对一系列数据点具有最佳的拟合效果. 使用示例数据集#我们将使Y成为因变量,X ...

  3. r语言做断轴_R语言用nls做非线性回归以及函数模型的参数估计

    非线性回归是在对变量的非线性关系有一定认识前提下,对非线性函数的参数进行最优化的过程,最优化后的参数会使得模型的RSS(残差平方和)达到最小.在R语言中最为常用的非线性回归建模函数是nls,下面以ca ...

  4. 利用R语言进行线性/非线性回归拟合实例(1)

    利用R语言进行线性/非线性回归拟合实例(1) 1. 生成一组数据 vector<float>xxvec; vector<float>yyvec; ofstreamfout(&q ...

  5. R语言曲线回归:多项式回归、多项式样条回归、非线性回归数据分析

    最近我们被客户要求撰写关于曲线回归的研究报告,包括一些图形和统计输出.本文将使用三种方法使模型适合曲线数据:1)多项式回归:2)用多项式样条进行B样条回归:3) 进行非线性回归.在此示例中,这三个中的 ...

  6. R 回归 虚拟变量na_工具amp;方法 | R语言机器学习包大全(共45个包)

    机器学习,是一门多学科交叉的人工智能领域的分析技术,它使用算法解析数据,从中学习,然后对世界上的某件事情做出决定或预测. 目前,常见机器学习的研究方向主要包括决策树.随机森林.神经网络.贝叶斯学习和支 ...

  7. r语言在java中的实现_R语言在现实中的应用

    R语言在现实中的应用有哪些?主要有以下几种 - 1.数据科学 "哈佛商业评论"将数据科学家命名为"21世纪最性感的工作". Glassdoor将其命名为2016 ...

  8. r语言 plot_R和Python的特点对比,这样你就知道该怎么选择了

    数据科学界有三大宝: Python.SAS和R,不过像SAS这种高端物种,不是我们这些平民能供养得起的啊. 根据 IEEE Spectrum的最新排名,R和Python仍然是最热门的数据科学编程语言. ...

  9. 数据科学r语言_您应该为数据科学学习哪些语言?

    数据科学r语言 Data science is an exciting field to work in, combining advanced statistical and quantitativ ...

  10. R语言学习笔记(六)回归分析

    文章目录 写在前面 普通最小二乘(OLS)回归法 正态假设 简单线性回归 多项式回归 多元线性回归 有交互项的多元线性回归 小结 回归诊断 标准方法 综合验证方法 多重共线性 广义线性回归--Logi ...

最新文章

  1. windows下安装awstats来分析apache的访问日志
  2. varnish安装及简单配置
  3. [简单题]Counting Duplicates( Python 实现)
  4. php反射机制详解,PHP反射机制实现插件的可插拔设计
  5. LeetCode 1670. 设计前中后队列(deque)
  6. 系统集成项目管理工程师考试复习-Part3
  7. 最新谷歌本地搜索api
  8. 利用selenium获取接口数据
  9. 提高电火炉使用安全,微波雷达模组感应控制,雷达感应技术方案
  10. B站在​港交所双重主要上市 陈睿:将扩大我们投资者基础
  11. 微软家庭服务器,微软下一代Windows家庭服务器Vail初印象
  12. antd vue表单验证_解决antd 表单设置默认值initialValue后验证失效的问题
  13. 【论文阅读】3D-CVF: Generating Joint Camera and LiDAR Features Using Cross-View Spatial Feature Fusion for
  14. 【网络工程】计算机网络专业术语概论全面整理
  15. 周志华机器学习(一)
  16. GIT修改账号密码重新登录和保存密码
  17. 华硕笔记本计算机名称,华硕NB是如何命名的?5招教你看清楚
  18. 【代码复用之】登录注册原生代码
  19. java特种兵 怎么样_Java特种兵(上册)
  20. 2022.1.17 学习笔记 (SPN中业务是如何传输的,主要是业务切片的调度编排)

热门文章

  1. 为什么要用代理服务器?
  2. 以太网芯片MAC和PHY的关系
  3. sai 绘图软件快捷键
  4. Elasticsearch创建索引别名
  5. 11 - JavaScript原型对象
  6. 我看韩寒-话题2010读后
  7. echo和narcissus寓意_【故事】三毛的英文名Echo,有什么含义?
  8. 英文论文评审意见_sci英文论文审稿意见怎么写(7)
  9. TSDB在高速公路大数据平台的应用
  10. 论文篇-----高速公路交通流数据质量控制及评价方法