作者:量化小白一枚,上财研究生在读,偏向数据分析与量化投资

个人公众号:量化小白上分记

本文主要是对机器学习算法误差的分解,全文包括理论推导和模拟两部分。

1. 理论推导

如何评价机器学习算法的性能,是一个非常重要的问题,目前已有很多方法,基本思路都是用样本误差去估计泛化误差,简单的有将样本分为测试集和训练集两部分,复杂的包括交叉验证和Boostrap等方法,这其中一个很重要的思想是,避免测试样本在训练样本中出现,否则得到的会是一个偏向乐观的结果。

本文不过多论述这方面的内容,而是阐述另一个话题,误差的来源和分解,通过偏差-方差分解的办法。

这里我们使用西瓜书中的符号说明,

学习算法的期望预测可以表示为

学习算法总误差(期望泛化误差)用均方误差衡量时可以表示为

所谓偏差,是指真实值与预测值之间的差别,它刻画的是算法本身的拟合能力,用均方误差衡量如下

所谓方差,是指用不同的训练集进行训练,对同一测试集进行测试时,得到结果中误差序列的方差,这些训练集都来自同一个分布,即整体,它刻画的是数据扰动对于结果的影响

此外,模型还存在一定的噪声,可以表示为

通过简单推导可以得出如下结论

也就是泛化误差可以分解为偏差、方差和噪声的和

一般来说,方差和偏差是有冲突的,他们之间存在trade-off,无法减小方差的同时减小偏差,他们的关系可以表示为下图

这其中,训练程度是可以通过调参,增加特征进行控制的。 训练程度不足时,模型的拟合程度不够,偏差较大,训练数据的变化不会导致结果出现显著的变化,这时偏差主导了泛化误差率,随着训练程度加深,模型拟合程度加大,偏差逐渐减小,但此时对于数据的依赖性变强,数据的任一微小变动都可能导致结果发生巨大变化,可能出现过拟合现象。此时,方差主导的泛化误差率。

2. 模拟

首先说明,模拟部分使用的软件是R语言,不是PYTHON

实证部分我们尝试复制上面图中的偏差、方差关系示意图,案例来自ESL,先放上书中的标准图,毕竟这个看上去比较完美,我复制出来的结果没有这个好。

首先解释下这个图,图中浅红色线为用来自同一分布的不同训练集训练的模型对同一测试集预测结果的误差,浅蓝色线为对用于训练的训练集预测结果的误差,浅红色线和浅蓝色线分别是100条,即有100个训练集。深红色深蓝色线为浅色线的平均,用均值作为期望的估计量。

横轴表示的是训练程度,这里实际上是自变量的个数,数据中与因变量相关的自变量共有35个,每次训练分别使用1到35个变量进行训练,得到不同自变量数下的误差,就可以得到一条误差曲线。

可以看出,(测试集)红色线存在明显的Bias Variance Trade-Off,训练集(蓝色线)随着自变量个数增加,误差不断减小,最后实际上出现过了拟合,也就是之前说到的乐观结果。

再说下数据的生成

训练集:100个训练集,每个训练集中设置50个样本,每个样本有35个自变量,自变量均来自标准正态分布,因变量取值为:如果所有自变量加起来大于0,就是1,不大于0,就是0

测试集:1个测试集,50个样本,规则与训练集相同,代码如下

1set.seed(1)2test_size <- 503sigma <- 0.54test_x <- matrix(rnorm(test_size*35,0,1),test_size)5test_y <- ifelse(apply(test_x,1,sum)>0 , 1 , 0)

训练用的模型是Lasso,这个没什么需要说明的,R语言的glmnet包可以直接做。

 1pnum <- 35 2train_error_all <- matrix(NaN,100,pnum) 3test_error_all <- matrix(0,100,pnum) 4for( i in (1:100)){ 5 6  flag = 1 7  while (flag  == 1){ 8    train_x <- matrix(rnorm(50*35,0,1),nrow = 50) 9    train_y <- ifelse(apply(train_x,1,sum)>0 , 1 , 0)1011    lasso <- glmnet(train_x,train_y,alpha = 1, nlambda = 10000, family = 'gaussian',pmax = pnum)12    lambdas <- data.frame(df = lasso$df,lambda = lasso$lambda)13    ld <- aggregate(lambdas,by = list(lambdas$df),mean)1415    ld <-  ld[-1,]16    if(dim(ld)[1] == pnum){17      flag = 018    }19  }20212223  lasso1 <- glmnet(train_x,train_y,alpha = 1, lambda = ld$lambda, family = 'gaussian')24  result_train <- predict(lasso1, newx = train_x,type = 'response',s = ld$lambda)25  result_test <- predict(lasso1, newx = test_x,type = 'response',s = ld$lambda)262728  train_error <- apply(abs(result_train - train_y),2,mean)29  test_error <- apply(abs(result_test - test_y),2,mean)303132  train_error_all[i,ld$df] <- train_error33  test_error_all[i,ld$df] <- test_error3435print(i)36}373839train_error_mean <- apply(train_error_all,2,mean) 40test_error_mean <- apply(test_error_all,2,mean)  

算出来之后作图,最终效果图如下

动图是用animation、ggplot包做的,也是折腾了很久,感觉以后有时间可以专门写篇文章怎么用r语言做动图了。

 1g1 <- ggplot()  2saveGIF({ 3  for (i in 1:100){ 4    print(i) 5    train_data <- data.frame(train_error_all[i,]) 6    train_data$num <- 1:35 7    names(train_data) <- c('error','num') 8    train_data$type <- 'train_error' 91011    test_data <- data.frame(test_error_all[i,])12    test_data$num <- 1:3513    names(test_data) <- c('error','num')14    test_data$type <- 'test_error'1516    train_all <- data.frame(train_error)17    train_all$num <- 1:3518    names(train_all) <- c('error','num')19    train_all$type <- 'average(train_error)'2021    test_all <- data.frame(test_error)22    test_all$num <- 1:3523    names(test_all) <- c('error','num')24    test_all$type <- 'average(test_error)'25262728    g1 <- g1  + geom_line(data = train_data,aes(x=num,y=error),lwd = 1,colour = 'lightblue') +29                geom_line(data = test_data,aes(x=num,y=error),lwd = 1,colour = 'lightpink') +30                geom_line(data = train_all,aes(x=num,y=error),lwd = 2,colour = 'blue') +31                geom_line(data = test_all,aes(x=num,y=error),lwd = 2,colour = 'red') 3233    print(g1)34  }35},movie.name='Bias-Variance-Trade-Off.gif',interval=0.5,ani.width=700,ani.height=600)

再来一个静态版的,这个就简单多了

 1for (i in 1:100){ 2  plot(1:pnum,train_error_all[i,],xlab = '',ylab = '',xlim = c(0,pnum),ylim = c(0,0.6), 3       type = 'l',col = 'lightblue') 4  par(new = T) 5  plot(1:pnum,test_error_all[i,],xlab = '',ylab = '',xlim = c(0,pnum),ylim = c(0,0.6), 6       type = 'l',col = 'lightpink') 7 8  par(new = T) 9}1011plot(ld$df,train_error,xlab = '',ylab = '',xlim = c(0,pnum),ylim = c(0,0.6),type = 'l',col = 'blue',lwd = 2)12par(new = T)13plot(ld$df,test_error,xlab = 'Model Complexity (df)',ylab = 'Prediction Error',xlim = c(0,pnum),ylim = c(0,0.6),14     type = 'l',col = 'red',lwd = 2)15box()

整体来看,跟书上的趋势是差不多的,但还是有一些细微的差别,比如书上图最左侧y轴在1,我做的在0.5,因为ESL里没有具体说明因变量是怎么定义的,我是按照后面一个例子的方式定义的,所以有差别,但不影响理解。

最后需要说明两点:

1. 虽然之前提到了方差-偏差分解,但模拟过程中其实并没有用到,计算的是总的误差,只是为了分析方便,下一篇会通过方差偏差分解来更细致分析误差。

2.之前提到的方差-偏差分解并不一定成立,只有在用均方误差度量模型误差时才成立,如果使用0-1误差等其他方法,就不再成立。

参考文献

1. Ruppert D. The Elements of Statistical Learning: Data Mining, Inference, and Prediction[J]. Journal of the Royal Statistical Society, 2010, 99(466):567-567.

2. 周志华. 机器学习 : = Machine learning[M]. 清华大学出版社, 2016.


公众号后台回复关键字即可学习

回复 爬虫            爬虫三大案例实战  
回复 Python       1小时破冰入门

回复 数据挖掘     R语言入门及数据挖掘
回复 人工智能     三个月入门人工智能
回复 数据分析师  数据分析师成长之路 
回复 机器学习      机器学习的商业应用
回复 数据科学      数据科学实战
回复 常用算法      常用数据挖掘算法

R语言模拟:Bias Variance Trade-Off相关推荐

  1. R语言模拟疫情传播-gganimate包

    本文用gganimate包展示模拟疫情数据 本文篇幅较长,分为以下几个部分: 前言 效果展示 小结 附录:代码 前言 前文<R语言模拟疫情传播-RVirusBroadcast>已经介绍了一 ...

  2. R语言-模拟产生统计专业学生的成绩

    现在Mayuyu会以一个例子来说明R语言在统计学中的应用.模拟一个高中学生语数外三科的成绩单. 首先认识两个重要的函数,source()和print(),source()函数是用来运行R脚本的,一个R ...

  3. R语言中文社区2018年终文章整理(作者篇)

    欢迎关注天善智能,我们是专注于商业智能BI,人工智能AI,大数据分析与挖掘领域的垂直社区,学习,问答.求职一站式搞定! 对商业智能BI.大数据分析挖掘.机器学习,python,R等数据领域感兴趣的同学 ...

  4. r语言数据变量分段_R数据分析:用R语言做meta分析

    这里以我的一篇meta分析为例,详细描述meta分析的一般步骤,该例子实现的是效应量β的合并 R包:metafor或meta包,第一个例子以metafor包为例. 1.准备数据集 2.异质性检验 in ...

  5. 语言模拟蒲丰问题_R语言小数定律的保险业应用:泊松分布模拟索赔次数

    原文链接: 拓端数据科技 / Welcome to tecdat​tecdat.cn 在保险业中,由于分散投资,通常会在合法的大型投资组合中提及大数定律.在一定时期内,损失"可预测" ...

  6. 多元线性回归公式推导及R语言实现

    多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...

  7. 多元线性回归分析c语言,多元线性回归公式推导及R语言实现

    多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...

  8. R语言数据统计分析的基本函数

    转载于http://blog.csdn.net/yujunbeta/article/details/8547171 在面对大规模数据时,对数据预处理,获取基本信息是十分必要的.今天分享的就是数据预处理 ...

  9. R语言与数据的预处理

    在面对大规模数据时,对数据预处理,获取基本信息是十分必要的.今天分享的就是数据预处理的一些东西. 一.获取重要数据 在导入大规模数据时,我们通常需要知道数据中的关键内容:最值,均值,离差,分位数,原点 ...

  10. R语言实用案例分析-1

    在日常生活和实际应用当中,我们经常会用到统计方面的知识,比如求最大值,求平均值等等.R语言是一门统计学语言,他可以方便的完成统计相关的计算,下面我们就来看一个相关案例. 1. 背景 最近西安交大大数据 ...

最新文章

  1. windows 软件安装事件_苹果安装windows,报windows支持软件未能存储到所选驱动器
  2. ASIA TODAY 英文版
  3. java tabpanel_java cs tab点击切换标签的实现 panel
  4. 风变编程python第一关脸黑怪我喽_风变编程:Python适合编程初学者学习吗?
  5. centos压缩和解压缩
  6. c语言每个整数占9列,c语言 第五章 数据类型和表达式.ppt
  7. java servlet练习测试
  8. (转)Asp.net 中 Get和Post 的用法
  9. Hive 分区表操作 创建、删除
  10. mysql导入工具 行提交_使用命令行工具mysqlimport导入数据
  11. HDU 5691 ——Sitting in Line——————【状压动规】
  12. ibatis mysql_mysql +ibatis
  13. GIL锁,线程锁(互斥锁)和递归锁
  14. 干货 | 语音识别类产品细分及其应用场景
  15. c语言编程发展史详细介绍,一张图让你了解编程语言发展史
  16. 钉钉总裁不穷:周末最烦写周报还有被人钉
  17. JavaScript学习手册十一:JSON
  18. QPushButton中clicked消息参数一直返回false问题解决方法
  19. win10设置护眼模式
  20. 707-详解32位Linux系统内存地址映射

热门文章

  1. 支付宝“跑路”,一亿用户服务彻底关停!
  2. 我的一个低级错误,导致数据库崩溃半小时!!
  3. 经过一年的煎熬,我们还是决定把系统升级成基于 Spring Cloud 的微服务架构
  4. 我敢说,你的登录接口肯定不安全
  5. 来自微信官方:微信支付跨平台软件架构首次曝光
  6. 详解淘宝大秒杀系统设计,首次公开
  7. 浅谈web架构之架构设计
  8. 印度首颗 CPU 横空出世:软件开发已开动
  9. python编译安装没有c扩展_python – 为什么我在安装simplejson时得到“C扩展无法编译”?...
  10. pytorch自带网络_一篇长文学懂 pytorch