caret
This notebook describes an example of using the caret package [1] to conduct hyperparameter tuning for the k-Nearest Neighbour classifier.
The example dataset is the banknote dataframe found in the mclust package [2]. It contains six measurements made on 100 genuine and 100 counterfeit old-Swiss 1000-franc bank notes.
There are six predictor variables (Length, Left, Right, Bottom, Top, Diagonal), with Status being the categorical response or class variable having two levels, namely genuine and counterfeit.
Observe that the dataset is balanced, with 100 observations for each level of Status.
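library(mclust)  # provides the banknote data
library(dplyr)   # provides %>%, group_by() and summarise()

data(banknote)

# Count the notes and average each measurement within each class.
banknote %>%
  group_by(Status) %>%
  summarise(N = n(),
            Mean_Length = mean(Length),
            Mean_Left = mean(Left),
            Mean_Right = mean(Right),
            Mean_Bottom = mean(Bottom),
            Mean_Top = mean(Top),
            Mean_Diagonal = mean(Diagonal),
            .groups = "keep")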
For most of the bank note measurements, aside from Length, genuine and counterfeit notes have quite distinct distributions.
library(tidyr)    # provides pivot_longer()
library(ggplot2)  # provides ggplot() and the geoms used below

# Reshape to one row per note and dimension, then draw one boxplot panel per dimension.
banknote %>%
  mutate(ID = 1:n()) %>%
  pivot_longer(Length:Diagonal,
               names_to = "Dimension",
               values_to = "Size") %>%
  mutate(Dimension = factor(Dimension),
         ID = factor(ID)) %>%
  ggplot() +
  aes(y = Size, fill = Status) +
  facet_wrap(~ Dimension, scales = "free") +
  geom_boxplot() +
  theme(axis.text.x = element_blank(),
        axis.ticks.x = element_blank()) +
  labs(y = "Size (mm)", title = "Comparison of bank note dimensions")
Below is a visualisation of the distribution of the perimeters of the bank notes.
banknote %>%
  mutate(Perimeter = 2*Length + Left + Right) %>%
  ggplot() +
  aes(x = Perimeter, fill = Status) +
  geom_density(alpha = 0.5) +
  labs(x = "Perimeter (mm)", y = "Density", title = "Distribution of banknote perimeters")
Create training and testing datasets, preserving the 50/50 class split in each.
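library(caret)  # provides createDataPartition(), trainControl() and train()

set.seed(1)
# Stratified 80/20 split on Status.
training_index <- createDataPartition(banknote$Status,
                                      p = 0.8,
                                      list = FALSE)
training_set <- banknote[training_index, ]
testing_set <- banknote[-training_index, ]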
We can confirm the class split in the training set:
##
## counterfeit genuine
## 80 80
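To illustrate what preserving the class split means, here is a minimal base-R sketch of a stratified 80/20 split. It samples row indices within each class separately; the `status` vector is a toy stand-in for `banknote$Status`, and this is an illustration of the idea rather than caret's actual implementation.

```r
# Toy class labels standing in for banknote$Status (100 of each class).
status <- factor(rep(c("counterfeit", "genuine"), each = 100))

set.seed(1)
# Sample 80% of the row indices within each class separately.
train_idx <- unlist(lapply(split(seq_along(status), status), function(rows) {
  sample(rows, size = round(0.8 * length(rows)))
}))

# Tabulate the classes in the training and testing portions.
train_counts <- table(status[train_idx])
test_counts  <- table(status[-train_idx])
print(train_counts)
print(test_counts)
```

Because the sampling is done per class, the training portion holds 80 notes of each class and the testing portion holds the remaining 20 of each.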
Set up the cross-validation for hyperparameter tuning, i.e., 10-fold cross-validation repeated 10 times.
The summaryFunction argument determines which metrics are computed to assess the performance of a particular hyperparameter setting. Here we shall use defaultSummary, which calculates the accuracy and the kappa statistic.
training_control <- trainControl(method = "repeatedcv",
                                 summaryFunction = defaultSummary,
                                 classProbs = TRUE,
                                 number = 10,
                                 repeats = 10)
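As a rough illustration of the two metrics, the sketch below computes accuracy and Cohen's kappa by hand from a confusion table. The table is hypothetical: it imagines one held-out fold of 16 notes (8 per class) with a single misclassification, and is not taken from the actual resampling results.

```r
# Hypothetical confusion table for one held-out fold (rows: predicted, columns: reference).
confusion <- matrix(c(8, 1,
                      0, 7),
                    nrow = 2, byrow = TRUE,
                    dimnames = list(Prediction = c("counterfeit", "genuine"),
                                    Reference  = c("counterfeit", "genuine")))

n <- sum(confusion)

# Accuracy: proportion of notes on the diagonal (correctly classified).
accuracy <- sum(diag(confusion)) / n

# Agreement expected by chance, from the row and column totals.
expected <- sum(rowSums(confusion) * colSums(confusion)) / n^2

# Cohen's kappa: agreement beyond chance, rescaled so that 1 is perfect.
kappa_stat <- (accuracy - expected) / (1 - expected)

c(accuracy = accuracy, kappa = kappa_stat)
```

When the reference classes are exactly balanced, the chance agreement is 0.5, so kappa reduces to 2 × accuracy − 1.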
Now use the train() function to perform the model training/tuning of the k hyperparameter.
The range of k is from 11 to 85 in steps of 2, i.e., odd values only.
set.seed(2)
knn_cv <- train(Status ~ .,
                data = training_set,
                method = "knn",
                trControl = training_control,
                metric = "Accuracy",
                tuneGrid = data.frame(k = seq(11, 85, by = 2)))
knn_cv
## k-Nearest Neighbors
##
## 160 samples
## 6 predictor
## 2 classes: 'counterfeit', 'genuine'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 10 times)
## Summary of sample sizes: 144, 144, 144, 144, 144, 144, ...
## Resampling results across tuning parameters:
##
## k Accuracy Kappa
## 11 0.993750 0.98750
## 13 0.993750 0.98750
## 15 0.996875 0.99375
## 17 0.996875 0.99375
## 19 0.995000 0.99000
## 21 0.996875 0.99375
## 23 0.998125 0.99625
## 25 0.998125 0.99625
## 27 1.000000 1.00000
## 29 1.000000 1.00000
## 31 1.000000 1.00000
## 33 1.000000 1.00000
## 35 1.000000 1.00000
## 37 1.000000 1.00000
## 39 1.000000 1.00000
## 41 1.000000 1.00000
## 43 1.000000 1.00000
## 45 1.000000 1.00000
## 47 1.000000 1.00000
## 49 1.000000 1.00000
## 51 1.000000 1.00000
## 53 1.000000 1.00000
## 55 1.000000 1.00000
## 57 1.000000 1.00000
## 59 1.000000 1.00000
## 61 1.000000 1.00000
## 63 1.000000 1.00000
## 65 1.000000 1.00000
## 67 1.000000 1.00000
## 69 1.000000 1.00000
## 71 1.000000 1.00000
## 73 1.000000 1.00000
## 75 0.999375 0.99875
## 77 0.998750 0.99750
## 79 0.998125 0.99625
## 81 0.996875 0.99375
## 83 0.995000 0.99000
## 85 0.991875 0.98375
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was k = 73.
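The resampling lines in the output above (10 fold, repeated 10 times, sample sizes of 144) can be checked with a little arithmetic. The base-R sketch below assigns 160 observations to 10 folds, 10 times over, and counts how many observations remain for fitting in each resample. It uses plain random folds as a simplifying assumption, so it is not a reproduction of how caret builds its resamples.

```r
n_obs   <- 160   # size of the training set
n_folds <- 10
n_reps  <- 10

set.seed(2)
# One column per repeat; each entry is the fold (1..10) in which an observation is held out.
fold_ids <- replicate(n_reps, sample(rep(seq_len(n_folds), length.out = n_obs)))

# Number of observations used for fitting in each of the 10 x 10 resamples.
train_sizes <- apply(fold_ids, 2, function(f) n_obs - tabulate(f, nbins = n_folds))

n_resamples <- length(train_sizes)
n_resamples
range(train_sizes)
```

Each candidate k is therefore fitted 100 times on 144 notes and scored on the 16 held out, and the reported accuracy and kappa are averages over those 100 resamples.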
Cross-validation on the training set has selected k = 73. Note that every k from 27 to 73 achieved a cross-validated accuracy of 1.
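For readers unfamiliar with how k enters the classifier, here is a minimal base-R sketch of a k-nearest-neighbour vote for a single query point. The function `knn_vote` and the toy data are hypothetical and written only for illustration; the model fitted by caret may differ in details such as tie handling.

```r
# Minimal k-nearest-neighbour vote for a single query point (Euclidean distance).
knn_vote <- function(train_x, train_y, query, k) {
  # Distance from the query to every training row.
  dists <- sqrt(rowSums(sweep(train_x, 2, query)^2))
  # Labels of the k closest training rows.
  nearest <- train_y[order(dists)[seq_len(k)]]
  # Class proportions among the neighbours act as predicted probabilities.
  probs <- table(nearest) / k
  list(class = names(probs)[which.max(probs)], prob = probs)
}

# Hypothetical toy data: two well-separated clusters in two dimensions.
train_x <- rbind(c(1, 1), c(1, 2), c(2, 1), c(8, 8), c(8, 9), c(9, 8))
train_y <- factor(c("counterfeit", "counterfeit", "counterfeit",
                    "genuine", "genuine", "genuine"))

result <- knn_vote(train_x, train_y, query = c(2, 2), k = 3)
result$class
result$prob
```

With two classes, an odd k means the vote can never be tied, which is one reason to search over odd values only. The proportion of neighbours voting for a class is also what a class probability means for this kind of model.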
Inspecting the predicted probabilities reveals that a cutoff probability of around 0.5 gives good classification results.
training_set <- training_set %>%
  mutate(Predicted_prob = predict(knn_cv, type = "prob")$genuine)
training_set %>%
  ggplot() +
  aes(x = Predicted_prob, fill = Status) +
  geom_histogram(bins = 20) +
  labs(x = "Probability", y = "Count", title = "Distribution of predicted probabilities")
An ROC curve is another way to visualise the results and identify a good cutoff.
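library(pROC)  # provides roc()

# Plot the ROC curve and mark the best threshold under the Youden criterion.
pROC_train <- roc(training_set$Status, training_set$Predicted_prob,
                  quiet = TRUE,
                  plot = TRUE,
                  percent = TRUE,
                  auc.polygon = TRUE,
                  print.auc = TRUE,
                  print.thres = TRUE,
                  print.thres.best.method = "youden")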
According to the Youden criterion on the training set, the best threshold is 0.5. Choosing this as the cutoff probability returns a perfect classification result on the training data. Be wary, however, of overfitting to the training data.
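The Youden criterion picks the threshold that maximises sensitivity + specificity − 1. The base-R sketch below shows the calculation on a small set of hypothetical probabilities and labels, evaluated on a fixed grid of candidate thresholds. That grid is a simplification, since pROC considers thresholds derived from the observed values.

```r
# Hypothetical predicted probabilities of "genuine" and the corresponding true labels.
prob  <- c(0.05, 0.10, 0.30, 0.45, 0.55, 0.70, 0.90, 0.95)
truth <- c("counterfeit", "counterfeit", "counterfeit", "counterfeit",
           "genuine", "genuine", "genuine", "genuine")

thresholds <- c(0.2, 0.4, 0.5, 0.6, 0.8)

# Youden's J = sensitivity + specificity - 1 at each candidate threshold.
youden <- sapply(thresholds, function(t) {
  predicted_genuine <- prob > t
  sensitivity <- mean(predicted_genuine[truth == "genuine"])
  specificity <- mean(!predicted_genuine[truth == "counterfeit"])
  sensitivity + specificity - 1
})

best_threshold <- thresholds[which.max(youden)]
data.frame(threshold = thresholds, youden = youden)
best_threshold
```

In this toy example the two classes are perfectly separated around 0.5, so that threshold reaches the maximum J of 1, which mirrors the situation seen on the training data.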
Apply the final model, with k = 73 and cutoff = 0.5, to the testing dataset to get an estimate of the true performance of this classifier.
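knn_predictions <- predict(knn_cv, newdata = testing_set, type = "prob") %>%
  select(probability = genuine) %>%
  # Apply the 0.5 cutoff to turn probabilities into class labels.
  mutate(class = ifelse(probability > 0.5, "genuine", "counterfeit")) %>%
  # Use the same levels as Status so both classes are kept even if one is never predicted.
  mutate(class = factor(class, levels = levels(testing_set$Status)))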
The predictions on the testing dataset are evenly split between the two classes, which is a good sign!
##
## counterfeit genuine
## 20 20
Since we have the ground truth data, we can use the confusionMatrix() function to report a full set of performance statistics.
## Confusion Matrix and Statistics
##
## Reference
## Prediction counterfeit genuine
## counterfeit 20 0
## genuine 0 20
##
## Accuracy : 1
## 95% CI : (0.9119, 1)
## No Information Rate : 0.5
## P-Value [Acc > NIR] : 9.095e-13
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Sensitivity : 1.0
## Specificity : 1.0
## Pos Pred Value : 1.0
## Neg Pred Value : 1.0
## Precision : 1.0
## Recall : 1.0
## F1 : 1.0
## Prevalence : 0.5
## Detection Rate : 0.5
## Detection Prevalence : 0.5
## Balanced Accuracy : 1.0
##
## 'Positive' Class : counterfeit
##
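The confidence interval and p-value in the output above are consistent with an exact binomial test on 40 correct predictions out of 40. The base-R sketch below reproduces them under that assumption; it is a check on the arithmetic, not a claim about the internals of confusionMatrix().

```r
correct <- 40   # correctly classified test notes (20 + 20)
total   <- 40   # size of the testing set

# Exact (Clopper-Pearson) 95% confidence interval for the accuracy.
ci <- binom.test(correct, total)$conf.int

# One-sided test of the accuracy against the no-information rate of 0.5.
p_value <- binom.test(correct, total, p = 0.5, alternative = "greater")$p.value

round(ci, 4)
signif(p_value, 4)
```

The lower bound of about 0.91 is a useful reminder that, with only 40 test notes, a perfect score still leaves some uncertainty about the true accuracy.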
Indeed we have achieved perfect classification with this kNN classifier!
[1] Max Kuhn (2020). caret: Classification and Regression Training. R package version 6.0-86. https://CRAN.R-project.org/package=caret
[2] Scrucca L., Fop M., Murphy T. B. and Raftery A. E. (2016). mclust 5: clustering, classification and density estimation using Gaussian finite mixture models. The R Journal 8/1, pp. 289-317.