Anusha Mohan, Data Scientist
Cross-validation is like a sharp knife. It cuts well and lets you do more with less effort, but if you are not careful, you can injure yourself. This blog post details a common mistake that data scientists can make when using cross-validation to improve the performance of a model trained on an imbalanced dataset.
Why Use Cross-Validation?
Cross-validation is a popular technique that data scientists use to validate the stability and generalizability of a machine learning model. In K-fold cross-validation, the data is partitioned into K subsets. We then iteratively train on each set of (K-1) subsets and validate on the remaining subset, allowing each data point to be used as training data (K-1) times and validation data 1 time. The model score is then computed as the average validation score across all K trials.
This helps a data scientist check that their model is not overfitting and will generalize to unseen data, since the model score is an average across multiple validation scores rather than the result of a single split. A data scientist must ensure that the model is learning the patterns in the data and not fitting to the noise. In other words, the model should have both low bias and low variance.
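The K-fold partitioning described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the author's code: `kfold_indices` is a hypothetical helper name, and in practice scikit-learn's `KFold` (or `StratifiedKFold`, which preserves class proportions) does the same job.

```python
import numpy as np

def kfold_indices(n_samples, k, seed=0):
    """Yield (train_idx, val_idx) pairs for K-fold cross-validation.

    The shuffled row indices are cut into k nearly equal folds. Each fold is
    used once for validation while the remaining k - 1 folds form the
    training set, so every row is validated once and trained on k - 1 times.
    """
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), k)
    for i in range(k):
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, folds[i]
```

The model score is then the mean of the K per-fold validation scores.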
An imbalanced dataset is one where a positive signal occurs in only a small proportion of the total dataset. Often, the minority class in such a dataset will carry an extreme risk if it is not properly detected. For example, fraudulent credit card transactions account for only 1–2% of all transactions, but the risk associated with not catching fraudulent activity is very high.
There are several disadvantages when building a model with imbalanced data:
- Bias — The classification output is biased as the classifiers are more sensitive to detecting the majority class and less sensitive to the minority class.
- Optimization Metrics — Normal optimization metrics, such as accuracy, may not be indicative of true performance, especially when there is increased risk associated with false-negative or false-positive predictions.
- Difficulty Getting More Data — Datasets are often imbalanced because of real-world factors, as in the credit card example above: the minority class is simply a rare event, which makes it very difficult to find enough additional data to balance the class distribution.
Because of these disadvantages, there can be caveats when trying to leverage cross-validation with an imbalanced dataset. These caveats are perhaps best illustrated with an example.
An Illustrated Example
The dataset I have picked is a breast cancer prediction dataset from the UCI machine learning repository. This dataset consists of a sample of patients who have been diagnosed with either a malignant tumor or a benign tumor, and 9 features that each contain information on the detected tumor.
Further, I made this dataset severely imbalanced by keeping only a sample of the positive instances (malignant tumors) from the original dataset. This pre-processed dataset contains 456 patients, of whom 2.6% have a malignant tumor (12 patients) and 97.4% have a benign tumor (444 patients). Finally, I randomly sampled 30% of the pre-processed dataset to hold out as a pure test set, so that I could evaluate the model on truly unseen data.
All code and data can be found at this GitHub repository.
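The pre-processing steps above (keep every benign case, keep only a handful of malignant cases, hold out 30%) might look roughly like the sketch below. It is hypothetical: the helper names `make_imbalanced` and `holdout_split` are illustrative, it assumes label 1 is malignant and label 0 is benign, and it does not reproduce the author's exact sampling from the repository.

```python
import numpy as np

def make_imbalanced(X, y, n_pos, seed=0):
    """Keep every negative row (y == 0) and a random sample of n_pos positive rows (y == 1)."""
    rng = np.random.default_rng(seed)
    neg = np.flatnonzero(y == 0)
    pos = rng.choice(np.flatnonzero(y == 1), size=n_pos, replace=False)
    keep = rng.permutation(np.concatenate([neg, pos]))  # shuffle so classes are interleaved
    return X[keep], y[keep]

def holdout_split(X, y, test_frac=0.3, seed=0):
    """Randomly hold out test_frac of the rows as a pure test set (not stratified)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(y))
    n_test = int(round(test_frac * len(y)))
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]
```

With only 12 positives, a purely random split can leave very few malignant cases in either part; scikit-learn's `train_test_split(X, y, test_size=0.3, stratify=y)` is one way to keep the class proportions equal in both sets.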
Model Optimization Metrics
When choosing an optimization metric, we must consider the risks associated with the results of our predictions. For example, accuracy would not be the right metric for a model trained on this dataset. A naive model would assign every sample in the test set the majority class label, which in our case is label 0 (benign tumor). The accuracy of this naive model would be very high, but every patient with a malignant tumor would be misdiagnosed. In healthcare, the worst-case scenario is a false-negative diagnosis: we want to identify as many breast cancer patients as possible so that they receive the care and treatment they deserve. Therefore, this naive model would be unacceptable.
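The gap between accuracy and recall for the naive model can be checked with a few lines of plain Python. The test-set size used here (137 patients, 4 malignant) is an assumption, inferred from a 30% hold-out of 456 patients and from the 4 false negatives reported later in the post.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def recall(y_true, y_pred):
    """TP / (TP + FN): the share of actual positives (malignant) that were caught."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    return tp / (tp + fn) if tp + fn else 0.0

# Hypothetical test set: 137 patients, 4 of them malignant (label 1).
y_test = [1] * 4 + [0] * 133
naive_pred = [0] * len(y_test)  # always predict the majority class (benign)

print(f"accuracy = {accuracy(y_test, naive_pred):.3f}")  # accuracy = 0.971
print(f"recall   = {recall(y_test, naive_pred):.3f}")    # recall   = 0.000
```

Roughly 97% accuracy with zero malignant cases caught is exactly the failure mode described above.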
Since we want to penalize our model as much as possible for false-negative predictions, recall is the right metric to evaluate our model.
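For reference, these are the standard definitions, taking a malignant tumor as the positive class (TP, FP, TN and FN are true/false positives and negatives):

```latex
\text{Recall} = \frac{TP}{TP + FN}, \qquad
\text{Precision} = \frac{TP}{TP + FP}, \qquad
\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}
```

False negatives appear only in the denominator of recall among the first two, so every missed malignant case lowers recall directly, while accuracy can stay high whenever true negatives dominate.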
When a random forest model was trained on this imbalanced dataset and applied to the test set, it achieved 97% accuracy, with precision and recall scores of 0%. We would like to improve upon this baseline test-set recall, as our objective is to identify as many patients with malignant tumors as possible.
How Do We Improve Performance?
The recall and precision scores were 0% because there was not enough data for the model to pick up the signal for malignant tumors. No observations in the test set were classified as 1 (malignant), so there were no predicted positives (neither true nor false positives) and 4 false negatives. Recall is therefore 0/4 = 0%; precision is technically undefined (0/0) and is conventionally reported as 0.
One thing we can do to improve our performance is to balance the dataset. We have two options to do this:
- Undersampling the majority class — Undersampling involves reducing the number of samples from the majority class by randomly selecting a subset of data points from that class to use for training. One of the major disadvantages of performing undersampling is that useful data or information might be thrown away.
- Oversampling the minority class — Oversampling involves increasing the number of minority-class samples in the training dataset. The common method is to add copies of minority-class data points, which enlarges the minority class's decision region and can improve evaluation metrics. The main disadvantage of this method is that it might result in overfitting. Another oversampling method, which reduces this issue, is SMOTE (Synthetic Minority Oversampling Technique). SMOTE creates synthetic samples by interpolating between a minority-class point and one of its nearest minority-class neighbors in feature space.
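The two simple resampling options above can be sketched with NumPy as follows. This is a minimal illustration under the assumption that the minority class is labelled 1 and the majority class 0; the function names are hypothetical, and imbalanced-learn's `RandomUnderSampler` and `RandomOverSampler` provide tested versions of both.

```python
import numpy as np

def random_undersample(X, y, seed=0):
    """Randomly drop majority-class (y == 0) rows until both classes have the same count."""
    rng = np.random.default_rng(seed)
    pos = np.flatnonzero(y == 1)
    neg = rng.choice(np.flatnonzero(y == 0), size=len(pos), replace=False)
    keep = np.concatenate([neg, pos])
    return X[keep], y[keep]

def random_oversample(X, y, seed=0):
    """Duplicate randomly chosen minority-class (y == 1) rows until both classes have the same count."""
    rng = np.random.default_rng(seed)
    pos = np.flatnonzero(y == 1)
    n_extra = int((y == 0).sum()) - len(pos)
    extra = rng.choice(pos, size=n_extra, replace=True)  # exact copies of existing rows
    keep = np.concatenate([np.arange(len(y)), extra])
    return X[keep], y[keep]
```

Undersampling discards most of the benign rows, while oversampling only repeats existing malignant rows, which is why it can encourage a model to memorize them.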
Since our sample size is already small, we will use oversampling to try to improve test set performance.
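To make the SMOTE idea concrete, here is a simplified from-scratch sketch in NumPy: each synthetic point is placed at a random position on the line segment between a minority sample and one of its k nearest minority neighbors. It assumes class 1 is the minority class, and the names `smote` and `smote_balance` are hypothetical. It is not the implementation the author used; in practice imbalanced-learn's `SMOTE(sampling_strategy=1.0, k_neighbors=5).fit_resample(X, y)` is the usual choice.

```python
import numpy as np

def smote(X_min, n_new, k=5, seed=0):
    """Generate n_new synthetic minority samples (simplified SMOTE).

    Each synthetic point is x + u * (x_nn - x), where x is a random minority
    sample, x_nn is one of its k nearest minority neighbours, and u ~ U(0, 1).
    """
    rng = np.random.default_rng(seed)
    X_min = np.asarray(X_min, dtype=float)
    n = len(X_min)
    k = min(k, n - 1)  # cannot have more neighbours than other minority points
    # Pairwise squared Euclidean distances between minority samples.
    d = ((X_min[:, None, :] - X_min[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d, np.inf)  # a point is not its own neighbour
    nn = np.argsort(d, axis=1)[:, :k]  # indices of the k nearest neighbours
    base = rng.integers(0, n, size=n_new)  # random starting point for each new sample
    neigh = nn[base, rng.integers(0, k, size=n_new)]  # one random neighbour each
    gap = rng.random((n_new, 1))  # interpolation factor in [0, 1)
    return X_min[base] + gap * (X_min[neigh] - X_min[base])

def smote_balance(X, y, k=5, seed=0):
    """Append SMOTE samples of class 1 until classes 0 and 1 are 1:1."""
    n_new = int((y == 0).sum() - (y == 1).sum())
    X_syn = smote(X[y == 1], n_new, k=k, seed=seed)
    X_bal = np.vstack([np.asarray(X, dtype=float), X_syn])
    y_bal = np.concatenate([y, np.ones(n_new, dtype=y.dtype)])
    return X_bal, y_bal
```

Because every synthetic point is an interpolation of real minority points, it always lies inside their bounding box and close to the points it was built from.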
Cross-Validation on Oversampled Data (using SMOTE)
The minority class in the training set was oversampled to a class ratio of 1 to improve the quality of model predictions by enhancing the decision boundary. 5-fold cross-validation was then performed using this balanced training set in order to get an estimate of test set performance. Taking the average recall score across all 5 iterations, we get a model score of 100%!
A 100% cross-validated recall score looks great, especially without any feature engineering or hyperparameter tuning. The only change was oversampling the minority class in the training set to a 1:1 ratio. It seems almost too good to be true…
And it is. When a model trained on the balanced training set is applied to the test set, the test recall score is only 50% (2 of the 4 malignant cases).
So what happened? Cross-validation was supposed to estimate test-set performance, but it didn’t in this case. Did my model overfit?
First we can check the feature distributions in the training set. When we plot the kernel densities, we see that only 2 out of 9 features show high discriminative power between class 0 and class 1.
This makes us even more suspicious of our cross-validation process, since it is hard for a model to overfit as heavily as the gap between the cross-validated and test recall suggests when only two features have high predictive power.
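The post judges discriminative power by eye from per-class kernel density plots. As a hypothetical numeric stand-in (not the author's method), each feature can be scored by the overlap coefficient of its two class-conditional KDEs, computed here with SciPy's `gaussian_kde`: a value near 1 means the classes look the same on that feature, a value near 0 means they are well separated. The helper names are illustrative, and with only a handful of malignant training samples the estimate is rough; `gaussian_kde` also raises an error on a feature that is constant within a class.

```python
import numpy as np
from scipy.stats import gaussian_kde

def kde_overlap(x0, x1, grid_size=1000):
    """Overlap coefficient (integral of min(f0, f1)) of two 1-D kernel density estimates."""
    x0 = np.asarray(x0, dtype=float)
    x1 = np.asarray(x1, dtype=float)
    lo = min(x0.min(), x1.min())
    hi = max(x0.max(), x1.max())
    pad = 0.25 * (hi - lo)  # extend the grid so most of each KDE's tails are covered
    grid = np.linspace(lo - pad, hi + pad, grid_size)
    f0 = gaussian_kde(x0)(grid)  # density of the feature for class 0
    f1 = gaussian_kde(x1)(grid)  # density of the feature for class 1
    # Riemann-sum approximation of the integral of the pointwise minimum.
    return float(np.minimum(f0, f1).sum() * (grid[1] - grid[0]))

def rank_features_by_overlap(X, y):
    """Return (feature_index, overlap) pairs, most discriminative (lowest overlap) first."""
    scores = [(j, kde_overlap(X[y == 0, j], X[y == 1, j])) for j in range(X.shape[1])]
    return sorted(scores, key=lambda s: s[1])
```

Features at the top of the ranking correspond to the well-separated density curves in the kind of plot described above.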
The real reason our test recall score was far lower than the cross-validated score was information leakage from the validation set to the training set in each iteration. Let’s think for a moment about how the oversampling was done. I first split the data into training and test sets, then oversampled the minority class in the training data using SMOTE, and finally passed that balanced dataset through cross-validation. During cross-validation, the already oversampled training data was further divided into training and validation sets. This causes two problems:
- Synthetic observations could end up in both the training and validation sets during the same iteration. Because SMOTE builds each synthetic point by interpolating between original minority-class points, validation samples often have near-duplicates in the training folds. A flexible model like a random forest can effectively memorize these points and predict them correctly on validation, inflating the recall score.
- The model training and evaluation process is not representative of the testing environment. In each iteration of cross-validation, the model is both trained and evaluated on balanced data, while in reality it will be trained on a balanced dataset and applied to an imbalanced one.
The right way to conduct this experiment is to oversample the training folds inside each cross-validation iteration, not before the process begins. This prevents leakage from the validation set into the training set during cross-validation, and reflects how a model trained on a balanced training set will perform when applied to an imbalanced, unseen test set.
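The two orderings can be compared side by side with a sketch like the one below. It uses scikit-learn's `make_classification` (9 features, 2 informative, roughly 2.5% positives) as a synthetic stand-in for the breast cancer data, a compact re-implementation of SMOTE rather than imbalanced-learn's, and hypothetical helper names (`smote`, `cv_recall`). Scores depend on the data and seed, so none are shown. With imbalanced-learn, placing `SMOTE` and the classifier in an `imblearn.pipeline.Pipeline` and passing it to `cross_val_score` gives the same fold-wise behavior, because the pipeline only resamples during fitting.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import recall_score
from sklearn.model_selection import StratifiedKFold

def smote(X, y, k=5, seed=0):
    """Compact SMOTE: append interpolated class-1 points until classes are 1:1."""
    rng = np.random.default_rng(seed)
    X_min = X[y == 1]
    n_new = int((y == 0).sum() - (y == 1).sum())
    k = min(k, len(X_min) - 1)
    d = ((X_min[:, None, :] - X_min[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d, np.inf)
    nn = np.argsort(d, axis=1)[:, :k]  # k nearest minority neighbours
    base = rng.integers(0, len(X_min), size=n_new)
    neigh = nn[base, rng.integers(0, k, size=n_new)]
    X_new = X_min[base] + rng.random((n_new, 1)) * (X_min[neigh] - X_min[base])
    return np.vstack([X, X_new]), np.concatenate([y, np.ones(n_new, dtype=y.dtype)])

def cv_recall(X, y, oversample_inside_folds, k=5, seed=0):
    """Mean validation recall over k stratified folds."""
    if not oversample_inside_folds:
        X, y = smote(X, y, seed=seed)  # leaky: SMOTE runs before the split
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
    scores = []
    for train_idx, val_idx in skf.split(X, y):
        X_tr, y_tr = X[train_idx], y[train_idx]
        if oversample_inside_folds:
            # Correct: only the training folds are resampled; the validation
            # fold keeps its original, imbalanced class distribution.
            X_tr, y_tr = smote(X_tr, y_tr, seed=seed)
        model = RandomForestClassifier(random_state=seed).fit(X_tr, y_tr)
        scores.append(recall_score(y[val_idx], model.predict(X[val_idx])))
    return float(np.mean(scores))

# Synthetic stand-in for the breast cancer data: 9 features, 2 informative, ~2.5% positives.
X, y = make_classification(n_samples=320, n_features=9, n_informative=2,
                           n_redundant=0, weights=[0.975], flip_y=0, random_state=0)
leaky = cv_recall(X, y, oversample_inside_folds=False)
proper = cv_recall(X, y, oversample_inside_folds=True)
print(f"SMOTE before CV:    {leaky:.2f}")
print(f"SMOTE inside folds: {proper:.2f}")
```

Only the second number reflects a model trained on balanced data and scored on imbalanced data it has never seen in any form.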
Now that we can rely on our cross-validation process to be indicative of test set performance, we can continue to try to improve model performance by feature engineering and hyper-parameter tuning.
One of the main challenges data scientists face is working with imbalanced datasets, where predicting the minority class is often difficult. Techniques like oversampling and SMOTE help improve the performance of models trained on these datasets, but to ensure that cross-validation helps us produce the best model that will generalize to unseen data, it is essential to follow these guidelines:
- In each iteration, exclude one fold of the data for validation. The excluded data should not be used for feature selection, oversampling, or model building.
- Oversample the minority class only in the remaining training data, after the validation fold has been set aside.
- Repeat K times, where K is the number of folds.
This makes each iteration of cross-validation representative of the real-world environment: the model is trained on balanced data and applied to an imbalanced validation set.
1. O. L. Mangasarian and W. H. Wolberg: “Cancer diagnosis via linear programming”, SIAM News, Volume 23, Number 5, September 1990, pp. 1 & 18.
2. W. H. Wolberg and O. L. Mangasarian: “Multisurface method of pattern separation for medical diagnosis applied to breast cytology”, Proceedings of the National Academy of Sciences, U.S.A., Volume 87, December 1990, pp. 9193–9196.