打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
最优模型选择中的交叉验证(Cross validation)方法


很多时候,大家会利用各种方法建立不同的统计模型,诸如普通的cox回归,利用Lasso方法建立的cox回归,或者稳健的cox回归;或者说利用不同的变量建立不同的模型,诸如模型一只考虑了三个因素、模型二考虑了四个因素,最后对上述模型选择(评价)的时候,或者是参数择优的时候,通常传统统计学方法中会用AIC,BIC、拟合优度-2logL,或者预测误差最小等准则来选择最优模型;而最新的文献中都会提到一种叫交叉验证(Cross validation)的方法,或者会用到一种将原始数据按照样本量分为两部分三分之二用来建模,三分之一用来验证的思路(临床上有医生称为内部验证),再或者利用多中心数据,一个中心数据用来建模,另外一个中心数据用来验证(临床上称为外部验证),这些都是什么?总结一下自己最近看的文献和书籍,在这里简单介绍下,仅供参考。

一、交叉验证的概念

交叉验证(Cross validation),有时亦称循环估计,是一种统计学上将数据样本切割成较小子集的实用方法。于是可以先在一个子集上做建模分析,而其它子集则用来做后续对此分析的效果评价及验证。一开始的子集被称为训练集(Train set)。而其它的子集则被称为验证集(Validationset)或测试集(Test set)。交叉验证是一种评估统计分析、机器学习算法对独立于训练数据的数据集的泛化(普遍适用性)能力(Generalize).例如下图文献中,原始数据集中449例观测,文献中将数据集分为了训练集(Primary Cohort)367例,验证集(Validation Cohort)82例。


二、交叉验证的原理及分类

假设利用原始数据可以建立 n 个统计模型,这 n 个模型的集合是M={M1,M2,…,Mn},比如我们想做回归,那么简单线性回归、logistic回归、随机森林、神经网络等模型都包含在M中。目标任务就是要从M中选择最好的模型。

(1)简单交叉验(hold-outcross validation)

假设训练集使用T来表示。如果想使用预测误差最小来度量模型的好坏,那么可以这样来选择模型:

1).使用T来训练每一个M,训练出参数(模型的系数)后,也就可以得到模型方程Fi。(比如,线性模型中得到系数ai后,也就得到了模型Fi(x)=aTx);

2).选择预测误差最小的模型。

遗憾的是这个算法不可行,比如我们需要拟合一些样本点,使用高阶的多项式回归肯定比线性回归错误率要小,偏差小,但是方差却很大,会过度拟合。因此,我们改进算法如下:

1).从全部的训练数据T中随机选择70%的样本作为训练集Ttrain,剩余的30%作为测试集Tcv。

2).在Ttrain上训练每一个M,得到模型Fi。

3).在Tcv上测试每一个Fi,得到相应的预测误差e。

4).选择具有最小预测误差的作为最佳模型。

这种方法称为hold-outcross validation或者称为简单交叉验证。

由于测试集是和训练集中是两个世界的,因此可以认为这里的预测误差接近于真实误差(generalizationerror)。这里测试集的比例一般占全部数据的1/4-1/3。30%是典型值。

还可以对模型作改进,当选出最佳的模型M后,再在全部数据T上做一次训练,显然训练数据越多,模型参数越准确。

简单交叉验证方法的弱点在于得到的最佳模型是在70%的训练数据上选出来的,不代表在全部训练数据上是最佳的。还有当训练数据本来就很少时,再分出测试集后,训练数据就太少了。其实严格意义来说Hold-OutMethod并不能算是CV,因为这种方法没有达到交叉的思想,由于是随机的将原始数据分组,所以最后验证集分类准确率的高低与原始数据的分组有很大的关系,所以这种方法得到的结果其实并不具有说服性.

(2)k-折叠交叉验证(k-fold cross validation)

进一步对简单交叉验证方法再做一次改进,如下:

1).将全部训练集T分成k个不相交的子集,假设T中的训练样例个数为m,那么每一个子集有m/k个训练样例,相应的子集称作{T1,T2,…, Tk}。

2).每次从模型集合M中拿出来一个Mi,然后在训练子集中选择出k-1个{T1,T2,Tj-1,Tj+1…,Tk}(也就是每次只留下一个Tj),使用这k-1个子集训练Mi后,得到假设函数Fij。最后使用剩下的一份Tj作测试,得到预测误差eij。

3).由于我们每次留下一个Tj(j从1到k),因此会得到k个预测误差,那么对于一个Mi,它的预测误差是这k个预测误差的平均。

4).选出平均经验错误率最小的Mi,然后使用全部的T再做一次训练,得到最后的模型Fi。

此方法称为k-fold cross validation(k-折叠交叉验证)说白了,这个方法就是将简单交叉验证的测试集改为1/k,每个模型训练k次,测试k次,预测误差为k次的平均。K一般大于等于2,实际操作时一般从3开始取,只有在原始数据集合数据量小的时候才会尝试取2.K-CV可以有效的避免过学习以及欠学习状态的发生,最后得到的结果也比较具有说服性.一般讲k取值为10。这样数据稀疏时基本上也能进行。显然,缺点就是训练和测试次数过多。

(3)留一交叉验证(leave one out cross validation)

极端情况下k-折叠交叉验证中k可以取值为m,意味着每次留一个样例做测试,这个称为leave-one-outcross validation(LOO-CV)。

如果设原始数据有N个样本,那么LOO-CV就是N-CV,即每个样本单独作为验证集,其余的N-1个样本作为训练集,所以LOO-CV会得到N个模型,用这N个模型最终的验证集的分类准确率的平均数作为此下LOO-CV分类器的性能指标.相比于前面的K-CV,LOO-CV有两个明显的优点:

1).每一回合中几乎所有的样本皆用于训练模型,因此最接近原始样本的分布,这样评估所得的结果比较可靠。

2).实验过程中没有随机因素会影响实验数据,确保实验过程是可以被复制的。

但LOO-CV的缺点则是计算成本高,因为需要建立的模型数量与原始数据样本数量相同,当原始数据样本数量相当多时,LOO-CV在实作上便有困难几乎就是不显示,除非每次训练分类器得到模型的速度很快,或是可以用并行化计算减少计算所需的时间。

最后,介绍一种与交叉验证有些不同但在各种文献中经常见到的模型选择方法,即自助法(Bootstrap法)。自助法的基本思想是从原始数据中用有放回的抽样方法来产生新的样本,例如我们有一大小为N的数据集,我们从中有放回的抽取N个样本,对N个样本进行建模,每个模型一个预测误差,取N各预测误差的平均值。这一方法在前期介绍C-index值是曾介绍过,C-index中的95%CI就是通过这一方法求出来的。这一方法建议在样本量小且单一的情况下使用。

三、注意事项

交叉验证使用中注意的事项:

1).训练集中样本数量要足够多,一般至少大于总样本数的50%。

2).训练集和测试集必须从完整的数据集中均匀取样。均匀取样的目的是希望减少训练集、测试集与原数据集之间的偏差。当样本数量足够多时,通过随机取样,便可以实现均匀取样的效果。(随机取样,可重复性差)


附:CrossValidation的R实现程序

library(ISLR)

set.seed(1)

str(Auto)

n=nrow(Auto)

train=sample(n,n/2)

test=(-train)

lm.fit=lm(mpg~horsepower, data=Auto, subset=train)

mean((Auto[test,'mpg']-predict(lm.fit, newdata=Auto[test,]))^2)

summary(lm.fit)

##

lm.fit2=lm(mpg~poly(horsepower,2), data=Auto, subset=train)

mean((Auto[test,'mpg']-predict(lm.fit2, newdata=Auto[test,]))^2)

summary(lm.fit2)

###

testError=matrix(NA,10,10)

for(seed in 123:132){

set.seed(seed)

train=sample(n,n/2)

test=(-train)

for(degree in 1:10){

lm.fit=lm(mpg~poly(horsepower,degree), data=Auto,subset=train)

testError[seed-122,degree]=mean((Auto[test,'mpg']-predict(lm.fit,newdata=Auto[test,]))^2)

}

}

range=range(testError)

plot(testError[1,],ylim=range,type='l',col=rainbow(10)[1],xlab='degree',ylab='the estimated test MSE')

for(seed in 2:10)points(testError[seed,],type='l',col=rainbow(10)[seed])


每种颜色代表一次全体样本的随机对半划分

###

##Leave-One-Out Cross-Validation

> cv.error=rep(0,5)

> t1=Sys.time()

> for(degree in 1:5){

+ glm.fit=glm(mpg ~ poly(horsepower,degree), data=Auto)

+ cv.error[degree]=cv.glm(Auto,glm.fit)$delta[1]

+ }

> Sys.time()-t1

Time difference of 45.18859 secs

> cv.error

[1] 24.23151 19.24821 19.33498 19.42443 19.03321

> plot(cv.error,type='b',xlab='degree')



## K-Fold Cross Validation

> set.seed(17)

> cv.error.10=rep(0,10)

> t1=Sys.time()

> for(degree in 1:10){

+ glm.fit=glm(mpg ~ poly(horsepower,degree), data=Auto)

+ cv.error.10[degree]=cv.glm(Auto,glm.fit,K=10)$delta[1]

+ }

> Sys.time()-t1

Time difference of 2.411138 secs

> cv.error.10

[1] 24.20520 19.18924 19.30662 19.33799 18.87911 19.02103 18.8960919.71201 18.95140

[10] 19.50196

> plot(cv.error.10,type='b',xlab='degree')



本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
交叉验证(Cross Validation)简介
深入浅出机器学习的基本原理与基础概念
数据挖掘入门指南!!!
评估分类器的性能:保持方法、交叉验证、自助法等
程序员学人工智能:如何设计数据科学算法的流程?
libsvm交叉验证与网格搜索教程
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服