I have prepared this post as documentation for a talk I will give on November 12th with my colleagues of Grupo-R madRid. In our previous meeting Jesús Herranz gave us a good introduction to survival models, but he saved the best stuff for his workshop on random forests for survival, held at our recent VII R-hispano users group congress (maybe the best event about R in Spain).
I had recently prepared an introduction to survival models focused on business settings (as a contribution to the applicability of these wonderful models outside the more frequent biomedical field) as my first techie blog post, in Spanish (!). In the meantime I also found these absolutely gReat tutorials on survival for business by Dayne Batten, and now the idea is to quickly apply part of the lessons learnt at Jesús's workshop to the churn dataset.
Let's go into the code. It should work on any R environment, but please pay close attention to the installation hints for the libraries below. If you find any problem, please write a comment!
You can download all this content and its code from an Rmarkdown file in my github.
Set up your R environment
Be sure to change your working directory:
Load (and install if necessary) the required libraries (but see the special requirement for the randomForestSRC library below).
list.of.packages <- c("survival", "caret", "glmnet", "rms", "doParallel", "risksetROC")
new.packages <- list.of.packages[!(list.of.packages %in% installed.packages()[,"Package"])]
if(length(new.packages)) install.packages(new.packages)

library(survival, quietly = TRUE)
library(caret, quietly = TRUE)
library(glmnet, quietly = TRUE)
library(rms, quietly = TRUE)
library(risksetROC, quietly = TRUE)
library(doParallel, quietly = TRUE)

registerDoParallel(detectCores() - 1)  ## registerDoMC(detectCores() - 1) on Linux
detectCores()
options(rf.cores = detectCores() - 1, mc.cores = detectCores() - 1)  ## cores for parallel processing
Read more about doParallel.
By default, doParallel uses multicore functionality on Unix-like systems and snow functionality on Windows. Note that the multicore functionality only runs tasks on a single computer, not a cluster of computers.
The randomForestSRC package allows parallelization, but the library binaries differ between Windows and Linux, so you must go to Hemant Ishwaran's rfsrc page, download the zip file for your platform and install it as source.
My system is a Windows 7 machine, so I am using one version of this zip. Use the appropriate one for your platform, and please verify that it loads on your system.
install.packages("http://www.ccs.miami.edu/~hishwaran/rfsrc/randomForestSRC_1.6.0.zip",
                 repos = NULL, type = "source")
library(randomForestSRC)
And only after this can you install ggRandomForests, which provides very useful plotting functions for the random forest objects created with randomForestSRC.
install.packages("ggRandomForests", repos = "http://cran.us.r-project.org")  # repos must be set again, since we installed from source before
library(ggRandomForests)
Load and explore the churn data
I found this churn dataset (artificial, but based on claims similar to the real world) suggested in a stackoverflow question. The data are part of the UCI machine learning training sets, and are also quickly found at http://www.sgi.com/tech/mlc/db/.
nm <- read.csv("http://www.sgi.com/tech/mlc/db/churn.names",
               skip = 4, colClasses = c("character", "NULL"),
               header = FALSE, sep = ":")[[1]]  # keep the first column as a character vector of names

dat <- read.csv("http://www.sgi.com/tech/mlc/db/churn.data",
                header = FALSE, col.names = c(nm, "Churn"),
                colClasses = c("factor", "numeric", "factor", "character",
                               rep("factor", 2), rep("numeric", 14), "factor"))

# test data
test <- read.csv("http://www.sgi.com/tech/mlc/db/churn.test",
                 header = FALSE, col.names = c(nm, "Churn"),
                 colClasses = c("factor", "numeric", "factor", "character",
                                rep("factor", 2), rep("numeric", 14), "factor"))
This is a quick exploration of the training dataset. There are 3333 unique customer ids (phone numbers); account.length is the customer age (the time dimension), which seems to be in months, although I am not totally sure; and there are 15% drop-outs, which is quite a high churn rate, but then consider we have a span of more than 10 years.
dim(dat)
summary(dat)
length(unique(dat$phone.number))
hist(dat$account.length)
table(dat$Churn) / nrow(dat) * 100
## [1] 3333   21
##      state     account.length  area.code   phone.number
##  WV     : 106  Min.   :  1    408: 838    Length:3333
##  MN     :  84  1st Qu.: 74    415:1655    Class :character
##  NY     :  83  Median :101    510: 840    Mode  :character
##  AL     :  80  Mean   :101
##  OH     :  78  3rd Qu.:127
##  OR     :  78  Max.   :243
##  (Other):2824
##  international.plan  voice.mail.plan  number.vmail.messages
##  no :3010            no :2411         Min.   : 0.0
##  yes: 323            yes: 922         1st Qu.: 0.0
##                                       Median : 0.0
##                                       Mean   : 8.1
##                                       3rd Qu.:20.0
##                                       Max.   :51.0
##
##  total.day.minutes  total.day.calls  total.day.charge  total.eve.minutes
##  Min.   :  0        Min.   :  0      Min.   : 0.0      Min.   :  0
##  1st Qu.:144        1st Qu.: 87      1st Qu.:24.4      1st Qu.:167
##  Median :179        Median :101      Median :30.5      Median :201
##  Mean   :180        Mean   :100      Mean   :30.6      Mean   :201
##  3rd Qu.:216        3rd Qu.:114      3rd Qu.:36.8      3rd Qu.:235
##  Max.   :351        Max.   :165      Max.   :59.6      Max.   :364
##
##  total.eve.calls  total.eve.charge  total.night.minutes  total.night.calls
##  Min.   :  0      Min.   : 0.0      Min.   : 23.2        Min.   : 33
##  1st Qu.: 87      1st Qu.:14.2      1st Qu.:167.0        1st Qu.: 87
##  Median :100      Median :17.1      Median :201.2        Median :100
##  Mean   :100      Mean   :17.1      Mean   :200.9        Mean   :100
##  3rd Qu.:114      3rd Qu.:20.0      3rd Qu.:235.3        3rd Qu.:113
##  Max.   :170      Max.   :30.9      Max.   :395.0        Max.   :175
##
##  total.night.charge  total.intl.minutes  total.intl.calls  total.intl.charge
##  Min.   : 1.04       Min.   : 0.0        Min.   : 0.00     Min.   :0.00
##  1st Qu.: 7.52       1st Qu.: 8.5        1st Qu.: 3.00     1st Qu.:2.30
##  Median : 9.05       Median :10.3        Median : 4.00     Median :2.78
##  Mean   : 9.04       Mean   :10.2        Mean   : 4.48     Mean   :2.77
##  3rd Qu.:10.59       3rd Qu.:12.1        3rd Qu.: 6.00     3rd Qu.:3.27
##  Max.   :17.77       Max.   :20.0        Max.   :20.00     Max.   :5.40
##
##  number.customer.service.calls    Churn
##  Min.   :0.00                   False.:2850
##  1st Qu.:1.00                   True. : 483
##  Median :1.00
##  Mean   :1.56
##  3rd Qu.:2.00
##  Max.   :9.00
## [1] 3333
And about the test set: it has exactly 1667 rows, about half the size of the training set.
Random Forests for Survival
Random Forests (RF) is a machine learning technique that builds a large number of decision trees in which:
- each tree is based on a bootstrap sample, i.e. a random sample with replacement of all observations.
- each tree split is based on a random sample of the predictors.
- there is no pruning: trees are grown as deep as possible; they are not "cut".
When building each RF tree, part of the observations is not used (approx. 37%). This is called the out-of-bag (OOB) sample and is used for an honest estimate of the model's predictive ability.
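That ~37% figure follows from the bootstrap itself: each of the n draws misses a given observation with probability 1 - 1/n, so the chance of never being drawn in n tries is (1 - 1/n)^n, which converges to exp(-1) ≈ 0.368. A quick sanity check in R:

```r
n <- 3333                 # size of our training set
p_oob <- (1 - 1/n)^n      # probability that an observation is out-of-bag
round(p_oob, 4)           # 0.3678, very close to the limit...
round(exp(-1), 4)         # ...exp(-1) = 0.3679
```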
Random Survival Forest (RSF) is a class of survival prediction models, i.e. models that use data on the life history of subjects (the response) and their characteristics (the predictor variables). It extends the RF algorithm to a target that is not a class or a number but a survival curve. The library is actually so clever that, given a particular target, it automatically selects the relevant algorithm. There are four families of random forests:
- regression forests for continuous responses
- classification forests for factor responses
- survival forests for right-censored survival settings
- competing risk survival forests for competing risk scenarios
RF is now a standard for effectively analyzing a large number of variables of many different types, with no previous variable selection process. It is non-parametric and, in particular for a survival target, it does not rely on the proportional hazards assumption.
rfsrc requires all data to be either numeric or factors, so you must filter out character, date and other variable types. The survival object requires a numeric (0/1) target, and in my R environment I have had problems feeding factors into the analysis, so I quickly drop those variables (in previous analyses they were not relevant anyway). We then quickly (and dirtily) convert two relevant factors and the target (Churn) into 0/1 dummies.
dat$phone.number <- NULL
dat$state <- NULL
dat$area.code <- NULL
dat$international.plan <- as.numeric(dat$international.plan) - 1
dat$voice.mail.plan <- as.numeric(dat$voice.mail.plan) - 1
dat$Churn <- as.numeric(dat$Churn) - 1
summary(dat)
You use rfsrc() to build a RF model with the following parameters:
- formula: response variable and predictors.
- data: data frame containing the data.
- ntree: total number of trees.
- mtry: number of variables considered as candidates at each split. By default sqrt(p) in classification and survival, and p/3 in regression.
- nodesize: minimum number of observations in a terminal node (3 for survival).
- nsplit: number of random split points to explore in continuous predictors.
- importance = TRUE: computes the variable importance ranking (otherwise use importance = "none").
- proximity = TRUE: computes the proximity matrix.
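As a quick sanity check on the mtry default (this little arithmetic assumes the rounding-up the package documents; the variable names are just for illustration): after the cleanup above we are left with 16 candidate predictors, so sqrt(p) rounded up gives 4, matching the "No. of variables tried at each split" reported in the model printouts.

```r
p <- 16                                 # predictors remaining after dropping id/state/area.code
mtry_survival   <- ceiling(sqrt(p))     # default for classification and survival
mtry_regression <- max(floor(p / 3), 1) # default for regression
mtry_survival                           # 4, as in the rfsrc printouts below
```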
Handling factors is not that easy. See the “Allowable data types and issues related to factors” part in rfsrc documentation.
Let's try a simple RF model with 50 trees and nsplit = 2.
out.rsf.1 <- rfsrc(Surv(account.length, Churn) ~ ., data = dat,
                   ntree = 50, nsplit = 2)
out.rsf.1
## Sample size: 3333
## Number of deaths: 483
## Number of trees: 50
## Minimum terminal node size: 3
## Average no. of terminal nodes: 352.3
## No. of variables tried at each split: 4
## Total no. of variables: 16
## Analysis: RSF
## Family: surv
## Splitting rule: logrank *random*
## Number of random split points: 2
## Error rate: 13.75%
Computation is intensive (even though we have requested a very simple model and we do not have that much data): rfsrc() permutes all values of all variables in all trees. However, at least for me, the parallelization implementation works flawlessly and produces output very quickly.
The $importance element contains the variable importances, in the same order as in the input dataset. We sort it to show a ranking and use ggRandomForests to plot this object. This library uses ggplot2, so you can easily retouch these plots.
imp.rsf.1 <- sort(out.rsf.1$importance, decreasing = TRUE)
imp.rsf.1
plot(gg_vimp(out.rsf.1))
##            international.plan              total.day.charge
##                     0.1192592                     0.0820343
##             total.day.minutes number.customer.service.calls
##                     0.0696993                     0.0639611
##               voice.mail.plan             total.eve.minutes
##                     0.0288920                     0.0222369
##              total.eve.charge         number.vmail.messages
##                     0.0214636                     0.0106841
##            total.intl.minutes              total.intl.calls
##                     0.0106587                     0.0093293
##             total.intl.charge            total.night.charge
##                     0.0086888                     0.0046463
##           total.night.minutes               total.eve.calls
##                     0.0026219                     0.0009676
##               total.day.calls             total.night.calls
##                     0.0006847                    -0.0006018
This is the variable importance plot
Predictive ability of RSF
To compute a numeric prediction per case, we sum the estimated cumulative hazard over all time points. This is equivalent to a risk score: higher values correspond to observations with higher observed risk, i.e. lower survival. These predictions can be based on all trees or only on the OOB samples.
In RSF, the error rate is defined as 1 - C-index. The C-index measures the concordance between predictions and outcomes: the higher the survival, the lower the risk score should be. Let's check the C-index.
length(out.rsf.1$predicted.oob)
## [1] 3333
head(out.rsf.1$predicted.oob)
## [1] 24.412  6.364  9.577 58.879 52.422 31.528
sum.chf.oob = apply(out.rsf.1$chf.oob, 1, sum)
head(sum.chf.oob)
## [1] 24.412  6.364  9.577 58.879 52.422 31.528
rcorr.cens(out.rsf.1$predicted.oob, Surv(dat$account.length, dat$Churn))["C Index"]
## C Index
##  0.1371
err.rate.rsf = out.rsf.1$err.rate[out.rsf.1$ntree]
err.rate.rsf
## [1] 0.1375
rcorr.cens(-out.rsf.1$predicted.oob, Surv(dat$account.length, dat$Churn))["C Index"]
## C Index
##  0.8629
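To make the C-index concrete, here is a hand-rolled version on a tiny made-up sample (all names and numbers below are illustrative, not from the churn data). A pair of observations is usable when the earlier of the two times is an observed event; the pair is concordant when that earlier failure received the higher risk score. This also shows why rcorr.cens needs the negated score: it assumes larger marker values predict longer survival.

```r
# Toy data (hypothetical): 1 = churned, 0 = censored
time   <- c(5, 3, 8, 1)
status <- c(1, 1, 0, 1)
risk   <- c(2, 4, 3, 9)   # higher score = predicted higher risk

pairs <- 0; concordant <- 0
for (i in seq_along(time)) {
  for (j in seq_along(time)) {
    # usable pair: i failed, and did so before j's (event or censoring) time
    if (status[i] == 1 && time[i] < time[j]) {
      pairs <- pairs + 1
      if (risk[i] > risk[j]) concordant <- concordant + 1
    }
  }
}
concordant / pairs   # 5/6: one usable pair is ranked the wrong way round
```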
Towards an optimal RSF
An essential parameter is the number of trees. To find the best number of trees, set importance = "none" (to avoid unnecessary computation) and a sufficiently high number of trees; as you could see from the default values, this depends on the number of predictor variables. In our example, with just a few variables, a few hundred trees looks like enough. We then use gg_error() from ggRandomForests to plot the error rate across the number of trees, and choose the point (the minimum number of trees) where the plot converges to a minimum. If it does not converge, try a higher number of trees.
out.rsf.3 <- rfsrc(Surv(account.length, Churn) ~ ., data = dat,
                   ntree = 200, importance = "none", nsplit = 1)
out.rsf.3
plot(gg_error(out.rsf.3))
## Sample size: 3333
## Number of deaths: 483
## Number of trees: 200
## Minimum terminal node size: 3
## Average no. of terminal nodes: 365.4
## No. of variables tried at each split: 4
## Total no. of variables: 16
## Analysis: RSF
## Family: surv
## Splitting rule: logrank *random*
## Number of random split points: 1
## Error rate: 13.97%
Predictive ability applied to test data: C Index
Let’s do as usual: apply our model to the test data and check for predictive ability.
First of all, remember to apply to the test set the same modifications we made to the training set!
test$phone.number <- NULL
test$state <- NULL
test$area.code <- NULL
test$international.plan <- as.numeric(test$international.plan) - 1
test$voice.mail.plan <- as.numeric(test$voice.mail.plan) - 1
test$Churn <- as.numeric(test$Churn) - 1
summary(test)
We apply the fitted model to the test set using, as usual, predict(). We then check the C-index as before.
pred.test.fin = predict(out.rsf.3, newdata = test, importance = "none")
rcorr.cens(-pred.test.fin$predicted, Surv(test$account.length, test$Churn))["C Index"]

## C Index
##  0.8596
Predictive ability: ROC generalizations
The risksetROC library provides functions to compute the equivalent of a ROC curve and its associated Area Under the Curve (AUC) in a time-dependent context.
This is precisely its greatest advantage: you can compute predictive ability at specific time points or intervals. Let's see the predictive ability of the OOB samples (in the training set) at the median time. We must be careful with the method argument since, depending on the assumptions, it should differ (see the documentation), but let's assume we meet Cox's proportional hazards assumption.
w.ROC = risksetROC(Stime = dat$account.length, status = dat$Churn,
                   marker = out.rsf.3$predicted.oob,
                   predict.time = median(dat$account.length), method = "Cox",
                   main = paste("OOB Survival ROC Curve at t=", median(dat$account.length)),
                   lwd = 3, col = "red")
w.ROC$AUC

## [1] 0.8095
For risksetROC to compute the AUC over an interval, you use risksetAUC with tmax (the maximum time). You get a very nice plot of AUC across time. This is still on the OOB samples.
w.ROC = risksetAUC(Stime = dat$account.length, status = dat$Churn,
                   marker = out.rsf.3$predicted.oob, tmax = 250)
Let’s do the same for test data.
w.ROC = risksetAUC(Stime = test$account.length, status = test$Churn,
                   marker = pred.test.fin$predicted, tmax = 220, method = "Cox")
And with a plot, at a good local-maximum prediction time, t = 190.
w.ROC = risksetROC(Stime = test$account.length, status = test$Churn,
                   marker = pred.test.fin$predicted,
                   predict.time = 190, method = "Cox",
                   main = paste("OOB Survival ROC Curve at t=190"),
                   lwd = 3, col = "red")
w.ROC$AUC

## [1] 0.8269
And with a plot, at what is maybe the best prediction time, t = 220 (?).
w.ROC = risksetROC(Stime = test$account.length, status = test$Churn,
                   marker = pred.test.fin$predicted,
                   predict.time = 220, method = "Cox",
                   main = paste("OOB Survival ROC Curve at t=220"),
                   lwd = 3, col = "red")
w.ROC$AUC

## [1] 0.8138
- randomForestSRC is a fabulous library for all tasks related to Random Forest computation, and it has a wondeRful implementation for survival targets.
- I have found some issues with handling factors that I shall explore, so that you can input factors directly into the algorithm (without converting them into dummies).
- RSF gives important advantages over traditional (classification) RF when you have survival (time-dependent) targets. In particular, you can check predictive capability and how it changes across time, among many other useful results for extracting knowledge from your analysis. All this while maintaining the predictive ability of the traditional classification models.
I hope this encourages you to try these models and libraries for your analyses. Feel free to comment with your results, or with any problems you find!