作者:量化小白一枚,上财研究生在读,偏向数据分析与量化投资

个人公众号:量化小白上分记

接上一篇《R语言模拟:Bias-Variance trade-off》,本文通过模拟分析算法的泛化误差、偏差、方差和噪声之间的关系,是《element statistical learning》第七章的一个案例。

上一篇通过模拟给出了在均方误差度量下,测试集上存在的偏差方差Trade-Off的现象,随着模型复杂度(变量个数)增加,训练集上的误差不断减小,最终最终导致过拟合,而测试集的误差则先减小后增大。

模拟方法说明

本文通过对泛化误差的分解来说明训练集误差变化的原因,我们做如下模拟实验:

样本1::训练集和测试集均为20个自变量,80个样本,自变量服从[0,1]均匀分布,因变量定义为:

Y = ifelse(X1>1/2,1,0)

样本2 : 训练集和测试集均为20个自变量,80个样本,自变量服从[0,1]均匀分布,因变量定义为:

Y = ifelse(X1+X2+...+X10>5,1,0)

通过两类模型、两种误差度量方式共四种方法进行建模,分析误差,模型为knnbest subset linear model

knn根据距离样本最近的k个样本的Y值预测样本的Y值,knn模型用于样本1,R语言中可通过函数knnreg实现。

best subset linear model 对于输入的样本,获取最优的自变量组合建立线性模型进行预测,best subset model用于样本2,R语言中可通过函数regsubsets实现。

误差度量分为均方误差(squared error)和0-1误差(0-1 Loss)两种,均方误差可以视为回归模型(regression),0-1误差可以视为分类模型(classification)。

结果说明

每种方法模拟100次,在每个模型中计算偏差、方差和预测误差并作图分析结果,最终得到结果如下:

其中,红色线表示预测误差,蓝色线表示方差,绿色线表示偏差平方,对比书上的结果

结果分析:

  1. 从数值上看,0-1 Loss 和Squared  error度量的模型具有不同特征,0-1 Loss满足预测误差 = 方差 +偏差平方的关系式,Squared error不满足这一关系;

  2. 方差都是随着模型中包含变量个数增加而减小,偏差的变化非线性。

代码

语言:r

knn model

  1# bais variance trade-off  regression  2  3# knn   4  5library(caret)  6  7# get bais variance  8# k:knn中的k值或best subset中的k值  9# num:模拟次数 10# sigma:随机误差的标准差 11# test_id 用于计算偏差误差的训练集样本编号,1-80中任一整数 12# regtype:knn或best sub 13# seeds:随机数种子 14# 返回方差偏差误差等值 15 16getError <- function(k,num,modeltype,seeds,n_test){ 17  set.seed(seeds) 18 19 20  testset <- as.data.frame(matrix(runif(n_test*21,0,1),n_test)) 21 22  Allfx_hat <- matrix(0,n_test,num) 23  Ally <- matrix(0,n_test,num) 24  Allfx <- matrix(0,n_test,num) 25 26  # 模拟 num次  27 28 29 30  for (i in 1:num){ 31    trainset <- as.data.frame(matrix(runif(80*21,0,1),80)) 32 33 34    fx_train <- ifelse(trainset[,1]>0.5,1,0) 35    trainset[,21] <- fx_train 36 37    fx_test <- ifelse(testset[,1]>0.5,1,0) 38    testset[,21] <- fx_test  39 40 41    # knn model 42    knnmodel <- knnreg(trainset[,1:20],trainset[,21],k = k) 43    probs <- predict(knnmodel, newdata = testset[,1:20]) 44 45 46    Allfx_hat[,i] <- probs 47    Ally[,i] <- testset[,21] 48    Allfx[,i] <- fx_test 49 50 51 52  }  53  # 计算方差、偏差等 54 55  # irreducible <- sigma^2 56 57  irreducible  <- mean(apply( Allfx - Ally  ,1,var)) 58  SquareBais  <- mean(apply((Allfx_hat - Allfx)^2,1,mean)) 59  Variance <- mean(apply(Allfx_hat,1,var)) 60 61  # 回归或分类两种情况 62  if (modeltype == 'reg'){ 63 64    PredictError  <- irreducible + SquareBais + Variance  65 66  }else{ 67 68    PredictError  <- mean(ifelse(Allfx_hat>=0.5,1,0)!=Allfx) 69  } 70 71 72 73  result <- data.frame(k,irreducible,SquareBais,Variance,PredictError) 74 75  return(result) 76} 77 78# ----------------   plot square error  knn ---------------------------- 79 80 81 82 83# k:knn中的k值或best subset中的k值 84# num:模拟次数 85# test_id 用于计算偏差误差的训练集样本编号,1-80中任一整数 86# regtype:knn或best sub 87# seeds:随机数种子 88 89n_test <- 100 90modeltype <- 'reg' 91num <- 100 92 93seeds <- 1 94 95result <- getError(2,num,modeltype,seeds,n_test) 96result <- rbind(result,getError(5,num,modeltype,seeds,n_test)) 97result <- rbind(result,getError(7,num,modeltype,seeds,n_test)) 98for (i in seq(10,50,10)){ 99  result <- rbind(result,getError(i,num,modeltype,seeds,n_test))100101}102103104png(file = "k-NN - Regression_large_testset.png")105106plot(-result$k,result$PredictError,type = 'o',col = 'red',107     xlim = c(-50,0),ylim = c(0,0.4),xlab = '', ylab ='', lwd = 2)108par(new = T)109plot(-result$k,result$SquareBais,type = 'o',col = 'green',110     xlim = c(-50,0),ylim = c(0,0.4),xlab = '', ylab ='', lwd = 2)111par(new = T) 112plot(-result$k,result$Variance,type = 'o',col = 'blue',113     xlim = c(-50,0),ylim = c(0,0.4),xlab = 'Number of Neighbors k', ylab ='', lwd = 2,114     main = 'k-NN - Regression')115dev.off()116117# ----------------------  plot 0-1 loss knn -------------------------118modeltype <- 'classification'119num <- 100120n_test <- 100121seeds <- 1122123result <- getError(2,num,modeltype,seeds,n_test)124result <- rbind(result,getError(5,num,modeltype,seeds,n_test))125result <- rbind(result,getError(7,num,modeltype,seeds,n_test))126for (i in seq(10,50,10)){127  result <- rbind(result,getError(i,num,modeltype,seeds,n_test))128129}130131132png(file = "k-NN - Classification_large_testset.png")133134plot(-result$k,result$PredictError,type = 'o',col = 'red',135     xlim = c(-50,0),ylim = c(0,0.4),xlab = '', ylab ='', lwd = 2)136par(new = T)137plot(-result$k,result$SquareBais,type = 'o',col = 'green',138     xlim = c(-50,0),ylim = c(0,0.4),xlab = '', ylab ='', lwd = 2)139par(new = T) 140plot(-result$k,result$Variance,type = 'o',col = 'blue',141     xlim = c(-50,0),ylim = c(0,0.4),xlab = 'Number of Neighbors k', ylab ='', lwd = 2,142     main = 'k-NN - Classification')143dev.off()

best subset model 

  1library(leaps)   2lm.BestSubSet<- function(trainset,k){  3  lm.sub <- regsubsets(V21~.,trainset,nbest =1,nvmax = 20)  4  summary(lm.sub)  5  coef_lm <- coef(lm.sub,k)  6  strings_coef_lm <- coef_lm  7  x <- paste(names(coef_lm)[2:length(coef_lm)], collapse ='+')  8  formulas <- as.formula(paste('V21~',x,collapse=''))  9  return(formulas) 10} 11 12getError <- function(k,num,modeltype,seeds,n_test){ 13  set.seed(seeds) 14  testset <- as.data.frame(matrix(runif(n_test*21,0,1),n_test)) 15 16  Allfx_hat <- matrix(0,n_test,num) 17  Ally <- matrix(0,n_test,num) 18  Allfx <- matrix(0,n_test,num) 19 20 21  # 模拟 num次 22 23 24 25  for (i in 1:num){ 26    trainset <- as.data.frame(matrix(runif(80*21,0,1),80)) 27    fx_train <- ifelse(trainset[,1] +trainset[,2] +trainset[,3] +trainset[,4] +trainset[,5]+ 28                         trainset[,6] +trainset[,7] +trainset[,8] +trainset[,9] +trainset[,10]>5,1,0) 29 30    trainset[,21] <- fx_train 31 32    fx_test <- ifelse(testset[,1] +testset[,2] +testset[,3] +testset[,4] +testset[,5]+ 33                        testset[,6] +testset[,7] +testset[,8] +testset[,9] +testset[,10]>5,1,0) 34 35    testset[,21] <- fx_test  36 37 38    # best subset 39    lm.sub <- lm(formula = lm.BestSubSet(trainset,k),trainset) 40    probs <- predict(lm.sub,testset[,1:20], type = 'response') 41 42 43    Allfx_hat[,i] <- probs 44    Ally[,i] <- testset[,21] 45    Allfx[,i] <- fx_test 46 47  }  48  # 计算方差、偏差等 49 50  # irreducible <- sigma^2 51 52  irreducible  <- mean(apply( Allfx - Ally  ,1,var)) 53  SquareBais  <- mean(apply((Allfx_hat - Allfx)^2,1,mean)) 54  Variance <- mean(apply(Allfx_hat,1,var)) 55 56  # 回归或分类两种情况 57  if (modeltype == 'reg'){ 58    PredictError <- irreducible + SquareBais + Variance  59  }else{ 60    PredictError <- mean(ifelse(Allfx_hat>=0.5,1,0)!=Allfx) 61  } 62  result <- data.frame(k,irreducible,SquareBais,Variance,PredictError) 63  return(result) 64} 65 66 67 68# ----------------   plot square error Best Subset Regression ---------------------------- 69 70 71modeltype <- 'reg' 72num <- 100 73n_test <- 1000 74 75seeds <- 4 76all_p <- seq(2,20,3) 77result <- getError(1,num,modeltype,seeds,n_test) 78for (i in all_p){ 79  result <- rbind(result,getError(i,num,modeltype,seeds,n_test)) 80 81} 82 83png(file = "Linear Model - Regression_large_testset.png") 84 85plot(result$k,result$PredictError,type = 'o',col = 'red', 86     xlim = c(0,20),ylim = c(0,0.4),xlab = '', ylab ='', lwd = 2) 87par(new = T) 88plot(result$k,result$SquareBais,type = 'o',col = 'green', 89     xlim = c(0,20),ylim = c(0,0.4),xlab = '', ylab ='', lwd = 2) 90par(new = T)  91plot(result$k,result$Variance,type = 'o',col = 'blue', 92     xlim = c(0,20),ylim = c(0,0.4),xlab = 'Subset Size p', ylab ='', lwd = 2, 93     main = 'Linear Model - Regression') 94dev.off() 95 96# ----------------------  plot 0-1 loss Best Subset Classification ------------------------- 97 98modeltype <- 'classification' 99num <- 100100n_test <- 1000101seeds <- 4102103104all_p <- seq(2,20,3)105result <- getError(1,num,modeltype,seeds,n_test)106for (i in all_p){107  result <- rbind(result,getError(i,num,modeltype,seeds,n_test))108109}110111png(file = "Linear Model - Classification_large_testset.png")112113114plot(result$k,result$PredictError,type = 'o',col = 'red',115     xlim = c(0,20),ylim = c(0,0.4),xlab = '', ylab ='', lwd = 2)116par(new = T)117plot(result$k,result$SquareBais,type = 'o',col = 'green',118     xlim = c(0,20),ylim = c(0,0.4),xlab = '', ylab ='', lwd = 2)119par(new = T) 120plot(result$k,result$Variance,type = 'o',col = 'blue',121     xlim = c(0,20),ylim = c(0,0.4),xlab = 'Subset Size p', ylab ='', lwd = 2,122     main = 'Linear Model - Classification')123# 124dev.off()

参考文献

1. Ruppert D. The Elements of Statistical Learning: Data Mining, Inference, and Prediction[J]. Journal of the Royal Statistical Society, 2010, 99(466):567-567.


公众号后台回复关键字即可学习

回复 爬虫            爬虫三大案例实战  
回复 Python       1小时破冰入门

回复 数据挖掘     R语言入门及数据挖掘
回复 人工智能     三个月入门人工智能
回复 数据分析师  数据分析师成长之路 
回复 机器学习      机器学习的商业应用
回复 数据科学      数据科学实战
回复 常用算法      常用数据挖掘算法

R语言模拟:Bias Variance Decomposition相关推荐

  1. R语言模拟疫情传播-gganimate包

    本文用gganimate包展示模拟疫情数据 本文篇幅较长,分为以下几个部分: 前言 效果展示 小结 附录:代码 前言 前文<R语言模拟疫情传播-RVirusBroadcast>已经介绍了一 ...

  2. R语言-模拟产生统计专业学生的成绩

    现在Mayuyu会以一个例子来说明R语言在统计学中的应用.模拟一个高中学生语数外三科的成绩单. 首先认识两个重要的函数,source()和print(),source()函数是用来运行R脚本的,一个R ...

  3. Bias - Variance Decomposition

    偏差-方差分解定理 解释了训练的数据和调控因子lamda(惩罚项里的)的作用 因为机器学习的真实目标是期望风险最小化(Expected Generalization Loss),其可以分解为三个部分 ...

  4. R语言中文社区2018年终文章整理(作者篇)

    欢迎关注天善智能,我们是专注于商业智能BI,人工智能AI,大数据分析与挖掘领域的垂直社区,学习,问答.求职一站式搞定! 对商业智能BI.大数据分析挖掘.机器学习,python,R等数据领域感兴趣的同学 ...

  5. r语言数据变量分段_R数据分析:用R语言做meta分析

    这里以我的一篇meta分析为例,详细描述meta分析的一般步骤,该例子实现的是效应量β的合并 R包:metafor或meta包,第一个例子以metafor包为例. 1.准备数据集 2.异质性检验 in ...

  6. 语言模拟蒲丰问题_R语言小数定律的保险业应用:泊松分布模拟索赔次数

    原文链接: 拓端数据科技 / Welcome to tecdat​tecdat.cn 在保险业中,由于分散投资,通常会在合法的大型投资组合中提及大数定律.在一定时期内,损失"可预测" ...

  7. 多元线性回归公式推导及R语言实现

    多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...

  8. 多元线性回归分析c语言,多元线性回归公式推导及R语言实现

    多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...

  9. R语言数据统计分析的基本函数

    转载于http://blog.csdn.net/yujunbeta/article/details/8547171 在面对大规模数据时,对数据预处理,获取基本信息是十分必要的.今天分享的就是数据预处理 ...

  10. R语言与数据的预处理

    在面对大规模数据时,对数据预处理,获取基本信息是十分必要的.今天分享的就是数据预处理的一些东西. 一.获取重要数据 在导入大规模数据时,我们通常需要知道数据中的关键内容:最值,均值,离差,分位数,原点 ...

最新文章

  1. 将用户添加至sudoers列表
  2. struct2 开发环境搭建 问题
  3. POS开发问题 - 多个弹出框的实现
  4. 实现商城类APP的筛选项效果
  5. C#开发WPF/Silverlight动画及游戏系列教程(Game Tutorial):(一)让物体动起来①
  6. 在Ubuntu中为root用户启用界面登录
  7. 用C语言调用.bat批处理命令
  8. VNCTF2021[WEB]
  9. 《东周列国志》第四十九回 公子鲍厚施买国 齐懿公竹池遇变
  10. MAC系统/虚拟机中的chm打不开
  11. 6502精品仿真软件(联锁)
  12. web端前端自定义提示语信息
  13. MySQL入门系列:视图
  14. AVL树简单实现及原理
  15. 【BC260Y】 AT指令接入移动oneNet平台流程
  16. oracle stdevp函数,ORACLE和SQL语法区别归纳整理.doc
  17. oracle漏洞pdf,Oracle DBA手记 4 数据安全警示录 pdf完整扫描版版
  18. 7-6 最热门的职业
  19. C++_MFC读视频文件
  20. 我见众生皆无意,唯有见你动了情(表白日记分享篇)

热门文章

  1. 真分布式SolrCloud+Zookeeper+tomcat搭建、索引Mysql数据库、IK中文分词器配置以及web项目中solr的应用(1)
  2. CentOS 7.2下安装lamp环境
  3. 新手理解Navigator的教程
  4. 入职 6 个月,被裁员。。。
  5. MySQL 8.0 新特性:哈希连接(Hash Join)
  6. 如果计算机是中国人发明的,那编程代码很可能就应该这样写!
  7. 到底工资要多少合适?
  8. android动态注册广播权限,记动态注册广播权限问题
  9. jmeter constant timer 如何添加_性能测试-Jmeter——软件测试圈-软件测试文章
  10. 每日算法之三十五:Wildcard Matching