R语言利用caret包比较ROC曲线的操作

时间:2021-05-20

说明

我们之前探讨了多种算法,每种算法都有优缺点,因而当我们针对具体问题去判断选择那种算法时,必须对不同的预测模型进行重做评估。

为了简化这个过程,我们使用caret包来生成并比较不同的模型与性能。

操作

加载对应的包与将训练控制算法设置为10折交叉验证,重复次数为3:

library(ROCR)library(e1071)library("pROC")library(caret)library("pROC")control = trainControl(method = "repaetedcv", number = 10, repeats =3, classProbs = TRUE, summaryFunction = twoClassSummary)

使用glm在训练数据集上训练一个分类器

glm.model = train(churn ~ ., data= trainset, method = "glm", metric = "ROC", trControl = control)

使用svm在训练数据集上训练一个分类器

svm.model = train(churn ~ ., data= trainset, method = "svmRadial", metric = "ROC", trControl = control)

使用rpart函数查看rpart在训练数据集上的运行情况

rpart.model = train(churn ~ ., data = trainset, method = "svmRadial", metric = "ROC", trControl = control)

使用不同的已经训练好的数据分类预测:

glm.probs = predict(glm.model,testset[,!names(testset) %in% c("churn")],type = "prob")svm.probs = predict(svm.model,testset[,!names(testset) %in% c("churn")],type = "prob")rpart.probs = predict(rpart.model,testset[,!names(testset) %in% c("churn")],type = "prob")

生成每个模型的ROC曲线,将它们绘制在一个图中:

glm.ROC = roc(response = testset[,c("churn")], predictor = glm.probs$yes, levels = levels(testset[,c("churn")]))plot(glm.ROC,type = "S",col = "red")svm.ROC = roc(response = testset[,c("churn")], predictor = svm.probs$yes, levels = levels(testset[,c("churn")]))plot(svm.ROC,add = TRUE,col = "green")rpart.ROC = roc(response = testset[,c("churn")], predictor = rpart.probs$yes, levels = levels(testset[,c("churn")]))plot(rpart.ROC,add = TRUE,col = "blue")

三种分类器的ROC曲线

说明

将不同的分类模型的ROC曲线绘制在同一个图中进行比较,设置训练过程的控制参数为重复三次的10折交叉验证,模型性能的评估参数为twoClassSummary,然后在使用glm,svm,rpart,三种不同的方法建立分类模型。

从图中可以看出,svm对训练集的预测结果(未调优)是三种分类算法里最好的。

补充:R语言利用caret包比较模型性能差异

说明

我们可以通过重采样的方法得对每一个匹配模型的统计信息,包括ROC曲线,灵敏度与特异度,然后基于这些统计信息来比较不同模型的性能差异。

操作

利用上节的信息,准备好glm分类模型,svm分类模型,rpart分类模型,并存放在glm.model,svm.model,rpart.model。

cv.values = resamples(list(glm = glm.model,svm =svm.model,rpart = rpart.model))> summary(cv.values)Call:summary.resamples(object = cv.values)Models: glm, svm, rpart Number of resamples: 30 ROC Min. 1st Qu. Median Mean 3rd Qu. Max. NA'sglm 0.7597790 0.7927740 0.8040455 0.8106454 0.8347961 0.8760824 0svm 0.8191998 0.8786439 0.8945208 0.8947360 0.9196775 0.9562556 0rpart 0.6064540 0.7150320 0.7608241 0.7556544 0.8086731 0.8554750 0Sens Min. 1st Qu. Median Mean 3rd Qu. Max. NA'sglm 0.08823529 0.1764706 0.2058824 0.2124930 0.2516807 0.3235294 0svm 0.44117647 0.5294118 0.5882353 0.5956863 0.6470588 0.7941176 0rpart 0.20000000 0.4117647 0.4705882 0.4787955 0.5514706 0.7352941 0Spec Min. 1st Qu. Median Mean 3rd Qu. Max. NA'sglm 0.9393939 0.9645119 0.9721581 0.9702721 0.9796954 0.9898477 0svm 0.9494949 0.9695431 0.9771574 0.9755004 0.9847716 0.9898990 0rpart 0.9492386 0.9746193 0.9796954 0.9780359 0.9848485 1.0000000 0

使用dotplot函数绘制重采样在ROC曲线度量中的结果:

dotplot(cv.values,metric = "ROC")

使用箱线图绘制重采样结果:

bwplot(cv.values,layout=c(3,1))

重采样结果箱线图

说明

我们使用resample函数生成各个模型的统计信息,再调用summary函数输出三个模型在ROC、灵敏度及特异性上的统计信息。

使用dotplot方法处理重采样结果来观测不同模型ROC差异,最后,采用箱线图在同一张图上对ROC、灵敏度及特异方面的差别进行比较。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持。如有错误或未考虑完全的地方,望不吝赐教。

声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。

相关文章