Cross-Validation in R with Example

What Does Cross-Validation Mean?

Cross-validation is a statistical approach for determining how well the results of a statistical investigation generalize to a different data set.

Cross-validation is commonly employed in situations where the goal is prediction and the accuracy of a predictive model’s performance must be estimated.

In a previous article we explored different stepwise regressions and came up with several candidate models. Now let's see how cross-validation can help us choose the best one.

Which model is the most accurate at forecasting?

To begin, we load the packages we need and take a look at the built-in mtcars dataset:

library(purrr)
library(dplyr)
head(mtcars)
                   mpg cyl disp  hp drat    wt  qsec vs am gear carb
Mazda RX4         21.0   6  160 110 3.90 2.620 16.46  0  1    4    4
Mazda RX4 Wag     21.0   6  160 110 3.90 2.875 17.02  0  1    4    4
Datsun 710        22.8   4  108  93 3.85 2.320 18.61  1  1    4    1
Hornet 4 Drive    21.4   6  258 110 3.08 3.215 19.44  1  0    3    1
Hornet Sportabout 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2
Valiant           18.1   6  225 105 2.76 3.460 20.22  1  0    3    1

There are several ways to do this, but we'll use the modelr package.

First, we split the data into training and test sets.

K-Fold Cross-Validation in R

library(modelr)
cv  <- crossv_kfold(mtcars, k = 5)
cv
  train                test                .id  
  <named list>         <named list>        <chr>
1 <resample [25 x 11]> <resample [7 x 11]> 1    
2 <resample [25 x 11]> <resample [7 x 11]> 2    
3 <resample [26 x 11]> <resample [6 x 11]> 3    
4 <resample [26 x 11]> <resample [6 x 11]> 4    
5 <resample [26 x 11]> <resample [6 x 11]> 5    

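Our data has been divided into five folds, each with a training set and a test set. Every row lands in exactly one test set, so each run trains on 25 or 26 rows and tests on the remaining 7 or 6.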
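Concretely, "five folds" means every row receives one fold label from 1 to 5 at random, and fold j serves as the test set for run j. The following is a minimal base-R sketch of that idea, not modelr's source code; the names folds, train_idx and test_idx are hypothetical, and the seed is arbitrary.

```r
# Base-R sketch of k-fold assignment (illustration only, not modelr's code).
set.seed(42)                      # arbitrary seed so the shuffle is repeatable
k <- 5
n <- nrow(mtcars)                 # 32 rows
folds <- sample(rep(seq_len(k), length.out = n))  # shuffled labels 1..k

# Run 1: fold 1 is held out for testing, all other rows are used for training.
test_idx  <- which(folds == 1)
train_idx <- which(folds != 1)

table(folds)                      # rows per fold: two folds of 7, three of 6
length(train_idx)                 # fold 1 has 7 test rows, leaving 25 to train on
```

Because 32 rows don't divide evenly by five, two folds get 7 rows and three get 6, which is the same 25/7 and 26/6 pattern printed by crossv_kfold().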

We now use map() from purrr to fit a model to each training set. Each of our three candidate models is fitted separately, so we end up with five fitted models per formula.

Model Fitting

models1  <- map(cv$train, ~lm(mpg ~ wt + cyl + hp, data = .))
models2  <- map(cv$train, ~lm(mpg ~ wt + qsec + am, data = .))
models3  <- map(cv$train, ~lm(mpg ~ wt + qsec + hp, data = .))

Now it's time to make some predictions. To do this, I wrote a small function that takes a model and its test data and returns the test data with the predictions added in a pred column. Note that I use as.data.frame() to convert each resample object into an ordinary data frame first.

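# Return the test rows with a "pred" column holding the model's predictions
get_pred  <- function(model, test_data){
  data  <- as.data.frame(test_data)      # resample object -> ordinary data frame
  pred  <- add_predictions(data, model)  # modelr adds the "pred" column
  return(pred)
}
# Pair each fitted model with its test fold; .id = "Run" records the fold number
pred1  <- map2_df(models1, cv$test, get_pred, .id = "Run")
pred2  <- map2_df(models2, cv$test, get_pred, .id = "Run")
pred3  <- map2_df(models3, cv$test, get_pred, .id = "Run")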
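add_predictions() appends a column called pred with the model's prediction for each row. If you would rather avoid the modelr dependency for this step, a base-R helper could use predict() instead. This is a sketch under that assumption; get_pred_base is a hypothetical name, and the 25/7 split below is only for illustration.

```r
# Hypothetical base-R stand-in for get_pred(): predict() on the held-out rows,
# stored in a column named "pred" to mirror add_predictions().
get_pred_base <- function(model, test_data) {
  data <- as.data.frame(test_data)            # ensure an ordinary data frame
  data$pred <- predict(model, newdata = data)
  data
}

# Example: train on the first 25 cars, predict the remaining 7.
fit  <- lm(mpg ~ wt + cyl + hp, data = mtcars[1:25, ])
pred <- get_pred_base(fit, mtcars[26:32, ])
head(pred[, c("mpg", "pred")])
```

The output has the same shape as get_pred(): the original test columns plus pred, so the MSE code that follows works on either.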

Next, we calculate the mean squared error (MSE) on the test set of each fold:
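The per-fold MSE computed by the summarise() calls is, for fold j with test set T_j of n_j rows (7 or 6 here), the average squared gap between the observed mpg and the prediction pred. T_j and n_j are just notation introduced for this formula:

```latex
\mathrm{MSE}_j = \frac{1}{n_j}\sum_{i \in T_j}\left(y_i - \hat{y}_i\right)^2
```

Here y_i is mpg and \hat{y}_i is pred, which is exactly mean((mpg - pred)^2) within each Run group.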

MSE1  <- pred1 %>% group_by(Run) %>%
  summarise(MSE = mean( (mpg - pred)^2))
MSE1
  Run     MSE
  <chr> <dbl>
1 1      7.36
2 2      1.27
3 3      5.31
4 4      8.84
5 5     13.8 
MSE2  <- pred2 %>% group_by(Run) %>%
  summarise(MSE = mean( (mpg - pred)^2))
MSE2
  Run     MSE
  <chr> <dbl>
1 1      6.45
2 2      2.27
3 3      7.71
4 4      9.56
5 5     15.4 
MSE3  <- pred3 %>% group_by(Run) %>%
  summarise(MSE = mean( (mpg - pred)^2))
MSE3

Because crossv_kfold() assigns rows to folds at random, your folds, and therefore your numbers, may differ somewhat from mine. Call set.seed() before crossv_kfold() if you need the split to be reproducible.
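To check that fixing the seed really does fix the folds, you can build the split twice with the same seed and compare the row indices. This sketch assumes modelr is installed; the seed 123 is an arbitrary choice, so it will not reproduce the exact tables in this post, and test_rows and same_folds are hypothetical names.

```r
library(modelr)

# Fix the seed before each call so the random fold assignment repeats.
set.seed(123)
cv_a <- crossv_kfold(mtcars, k = 5)
set.seed(123)
cv_b <- crossv_kfold(mtcars, k = 5)

# Each resample object stores the row numbers it uses in its idx element;
# identical seeds should give identical test rows in every fold.
test_rows  <- function(cv) lapply(cv$test, function(r) r$idx)
same_folds <- identical(test_rows(cv_a), test_rows(cv_b))
same_folds
```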

Finally, we compare the three models by averaging each model's MSE across the five folds:

mean(MSE1$MSE)
[1] 7.31312
mean(MSE2$MSE)
[1] 8.277929
mean(MSE3$MSE)
[1] 9.333679
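Each of these numbers is the k-fold cross-validation estimate: the unweighted mean of the k per-fold MSEs. With k = 5 and the per-fold values printed for model 1:

```latex
\mathrm{CV}_{(k)} = \frac{1}{k}\sum_{j=1}^{k}\mathrm{MSE}_j,
\qquad
\mathrm{CV}_{(5)}^{\text{model 1}} = \frac{7.36 + 1.27 + 5.31 + 8.84 + 13.8}{5} = \frac{36.58}{5} \approx 7.316
```

The small gap from the printed 7.31312 comes from the per-fold MSEs in the table being rounded. Because the folds hold 7 or 6 test rows, this unweighted average can also differ slightly from a single MSE pooled over all 32 predictions.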

The values are fairly close, but model 1 has the lowest average test MSE (7.31, versus 8.28 for model 2 and 9.33 for model 3), so it appears to be the best of the three.
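To see every step in one place, here is a compact base-R sketch of the same 5-fold comparison. It assumes only the built-in mtcars data and the three formulas used above, averages the fold MSEs without weighting as the article does, and uses an arbitrary seed, so its numbers will not match the tables exactly. The names formulas, cv_mse and cv_results are hypothetical.

```r
# Base-R sketch of 5-fold CV for three lm formulas (no modelr, purrr or dplyr).
set.seed(1)                                   # arbitrary seed for repeatable folds
k     <- 5
folds <- sample(rep(seq_len(k), length.out = nrow(mtcars)))

formulas <- list(
  model1 = mpg ~ wt + cyl + hp,
  model2 = mpg ~ wt + qsec + am,
  model3 = mpg ~ wt + qsec + hp
)

# For one formula: fit on the other k - 1 folds, compute the test MSE on the
# held-out fold, then average the k per-fold MSEs. The response is mpg for
# all three formulas, so it is referenced directly.
cv_mse <- function(formula, data, folds) {
  fold_mse <- sapply(sort(unique(folds)), function(j) {
    train <- data[folds != j, ]
    test  <- data[folds == j, ]
    fit   <- lm(formula, data = train)
    mean((test$mpg - predict(fit, newdata = test))^2)
  })
  mean(fold_mse)
}

cv_results <- sapply(formulas, cv_mse, data = mtcars, folds = folds)
cv_results
names(which.min(cv_results))                  # model with the lowest CV error
```

All three formulas share the same fold assignment, which keeps the comparison fair: each model is judged on exactly the same held-out rows.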
