— Neural Networks, Deep Learning, TensorFlow, Machine Learning, Python — 6 min read
TL;DR Learn how to handle imbalanced data using TensorFlow 2, Keras and scikit-learn
Datasets in the wild will throw a variety of problems towards you. What are the most common ones?
The data might have too few examples, be too large to fit into RAM, contain multiple missing values, lack enough predictive power to make correct predictions, or be imbalanced.
In this guide, we’ll try out different approaches to solving the imbalance issue for classification tasks. That isn’t the only issue on our hands. Our dataset is real, and we’ll have to deal with multiple problems - imputing missing data and handling categorical features.
Before going any deeper, you might want to consider far simpler solutions to the imbalanced dataset problem, such as collecting more data or using a tree-based model.
Here’s what you’ll learn: how to handle missing data and categorical features, why accuracy is misleading on imbalanced data, and how class weights, resampling and SMOTE can help.
Run the complete code in your browser
Naturally, our data should be imbalanced. Kaggle has the perfect dataset for us - Porto Seguro’s Safe Driver Prediction. The objective is to predict whether a driver will file an insurance claim. How many drivers do that?
Let’s start with installing TensorFlow and setting up the environment:
!pip install tensorflow-gpu
!pip install gdown
import numpy as np
import tensorflow as tf
from tensorflow import keras
import pandas as pd

RANDOM_SEED = 42

np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)
We’ll use gdown to get the data from Google Drive:
!gdown --id 18gwvNkMs6t0jL0APl9iWPrhr5GVg082S --output insurance_claim_prediction.csv
Let’s load the data in Pandas and have a look:
df = pd.read_csv('insurance_claim_prediction.csv')
print(df.shape)
(595212, 59)
Loads of data. What features does it have?
print(df.columns)
Index(['id', 'target', 'ps_ind_01', 'ps_ind_02_cat', 'ps_ind_03',
       'ps_ind_04_cat', 'ps_ind_05_cat', 'ps_ind_06_bin', 'ps_ind_07_bin',
       'ps_ind_08_bin', 'ps_ind_09_bin', 'ps_ind_10_bin', 'ps_ind_11_bin',
       'ps_ind_12_bin', 'ps_ind_13_bin', 'ps_ind_14', 'ps_ind_15',
       'ps_ind_16_bin', 'ps_ind_17_bin', 'ps_ind_18_bin', 'ps_reg_01',
       'ps_reg_02', 'ps_reg_03', 'ps_car_01_cat', 'ps_car_02_cat',
       'ps_car_03_cat', 'ps_car_04_cat', 'ps_car_05_cat', 'ps_car_06_cat',
       'ps_car_07_cat', 'ps_car_08_cat', 'ps_car_09_cat', 'ps_car_10_cat',
       'ps_car_11_cat', 'ps_car_11', 'ps_car_12', 'ps_car_13', 'ps_car_14',
       'ps_car_15', 'ps_calc_01', 'ps_calc_02', 'ps_calc_03', 'ps_calc_04',
       'ps_calc_05', 'ps_calc_06', 'ps_calc_07', 'ps_calc_08', 'ps_calc_09',
       'ps_calc_10', 'ps_calc_11', 'ps_calc_12', 'ps_calc_13', 'ps_calc_14',
       'ps_calc_15_bin', 'ps_calc_16_bin', 'ps_calc_17_bin', 'ps_calc_18_bin',
       'ps_calc_19_bin', 'ps_calc_20_bin'],
      dtype='object')
Those seem somewhat cryptic. Here is the data description:
features that belong to similar groupings are tagged as such in the feature names (e.g., ind, reg, car, calc). In addition, feature names include the postfix bin to indicate binary features and cat to indicate categorical features. Features without these designations are either continuous or ordinal. Values of -1 indicate that the feature was missing from the observation. The target columns signifies whether or not a claim was filed for that policy holder.
What is the proportion of each target class?
def percentage(part, whole):
    # small helper used throughout the post (its definition isn't shown in the original snippet)
    return 100 * part / whole

no_claim, claim = df.target.value_counts()
print(f'No claim {no_claim}')
print(f'Claim {claim}')
print(f'Claim proportion {round(percentage(claim, claim + no_claim), 2)}%')
No claim 573518
Claim 21694
Claim proportion 3.64%
Good, we have an imbalanced dataset on our hands. Let’s look at a graphical representation of the imbalance:
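The original post shows the class distribution as a plot. Here is a minimal sketch of how you could reproduce it, assuming seaborn and matplotlib are installed:

import matplotlib.pyplot as plt
import seaborn as sns

# Bar chart of the class counts - the skew towards "no claim" is obvious at a glance
sns.countplot(x='target', data=df)
plt.title('No claim (0) vs claim (1)')
plt.show()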
You got the visual proof right there. But how good of a model can you build using this dataset?
You might’ve noticed something in the data description. Missing data points have a value of -1. What should we do before training our model?
Let’s check how many rows/columns contain missing data:
row_count = df.shape[0]

for c in df.columns:
    m_count = df[df[c] == -1][c].count()
    if m_count > 0:
        print(f'{c} - {m_count} ({round(percentage(m_count, row_count), 3)}%) rows missing')
ps_ind_02_cat - 216 (0.036%) rows missing
ps_ind_04_cat - 83 (0.014%) rows missing
ps_ind_05_cat - 5809 (0.976%) rows missing
ps_reg_03 - 107772 (18.106%) rows missing
ps_car_01_cat - 107 (0.018%) rows missing
ps_car_02_cat - 5 (0.001%) rows missing
ps_car_03_cat - 411231 (69.09%) rows missing
ps_car_05_cat - 266551 (44.783%) rows missing
ps_car_07_cat - 11489 (1.93%) rows missing
ps_car_09_cat - 569 (0.096%) rows missing
ps_car_11 - 5 (0.001%) rows missing
ps_car_12 - 1 (0.0%) rows missing
ps_car_14 - 42620 (7.16%) rows missing
ps_car_03_cat, ps_car_05_cat and ps_reg_03 have too many missing rows for our comfort. We’ll get rid of them. Note that this is not the best strategy, but it will do in our case.
df.drop(
    ["ps_car_03_cat", "ps_car_05_cat", "ps_reg_03"],
    inplace=True,
    axis=1
)
What about the other features? We’ll use the SimpleImputer from scikit-learn to replace the missing values:
from sklearn.impute import SimpleImputer

cat_columns = [
    'ps_ind_02_cat', 'ps_ind_04_cat', 'ps_ind_05_cat',
    'ps_car_01_cat', 'ps_car_02_cat', 'ps_car_07_cat',
    'ps_car_09_cat'
]
num_columns = ['ps_car_11', 'ps_car_12', 'ps_car_14']

mean_imp = SimpleImputer(missing_values=-1, strategy='mean')
cat_imp = SimpleImputer(missing_values=-1, strategy='most_frequent')

for c in cat_columns:
    df[c] = cat_imp.fit_transform(df[[c]]).ravel()

for c in num_columns:
    df[c] = mean_imp.fit_transform(df[[c]]).ravel()
We impute categorical features with their most frequent value and numerical features with the column mean.
Pandas get_dummies() uses one-hot encoding to represent categorical features. Perfect! Let’s use it:
df = pd.get_dummies(df, columns=cat_columns)
Now that there are no missing values left (you can double-check that, as in the sketch below) and the categorical features are encoded, we can try to predict insurance claims. What accuracy can we get?
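A quick way to verify that no -1 (missing) markers remain after dropping and imputing:

# Should print 0 if every missing marker has been handled
print((df == -1).sum().sum())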
We’ll start by splitting the data into train and test datasets:
from sklearn.model_selection import train_test_split

labels = df.columns[2:]

X = df[labels]
y = df['target']

X_train, X_test, y_train, y_test = \
    train_test_split(X, y, test_size=0.05, random_state=RANDOM_SEED)
Our binary classification model is a Neural Network with batch normalization and dropout layers:
def build_model(train_data, metrics=["accuracy"]):
    model = keras.Sequential([
        keras.layers.Dense(
            units=36,
            activation='relu',
            input_shape=(train_data.shape[-1],)
        ),
        keras.layers.BatchNormalization(),
        keras.layers.Dropout(0.25),
        keras.layers.Dense(units=1, activation='sigmoid'),
    ])

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=metrics
    )

    return model
You should be familiar with the training procedure:
BATCH_SIZE = 2048

model = build_model(X_train)
history = model.fit(
    X_train,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_split=0.05,
    shuffle=True,
    verbose=0
)
In general, you should strive for a small batch size (e.g. 32). Our case is a bit specific - the data is highly imbalanced, so we use a large batch size to give each batch a fair chance of containing some insurance claim data points.
The validation accuracy seems quite good. Let’s evaluate the performance of our model:
model.evaluate(X_test, y_test, batch_size=BATCH_SIZE)
119043/119043 - loss: 0.1575 - accuracy: 0.9632
That’s pretty good. It seems like our model is pretty awesome. Or is it?
def awesome_model_predict(features):
    return np.full((features.shape[0], ), 0)

y_pred = awesome_model_predict(X_test)
This amazing model predicts that there will be no claim, no matter the features. What accuracy does it get?
from sklearn.metrics import accuracy_score

accuracy_score(y_test, y_pred)
0.9632
Sweet! Wait. What? This is as good as our complex model. Is there something wrong with our approach?
Not really. We’re just using the wrong metric to evaluate our model. This is a well-known problem. The Accuracy paradox suggests accuracy might not be the correct metric when the dataset is imbalanced. What can you do?
One way to understand the performance of our model is to use a confusion matrix. It shows us how well our model predicts for each class:
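The confusion matrix itself appears as a plot in the original post. A minimal sketch of how it can be computed and visualized, assuming seaborn is available and a 0.5 decision threshold:

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Turn the predicted probabilities into hard 0/1 predictions
y_pred = (model.predict(X_test) >= 0.5).astype(int).ravel()

cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()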
When the model is predicting everything perfectly, all values are on the main diagonal. That’s not the case here. So sad! Our complex model seems just as dumb as our awesome model.
Good, now we know that our model is very bad at predicting insurance claims. Can we somehow tune it to do better?
We can use a wide range of other metrics to measure our performance better:
Precision - the fraction of predicted claims that are actual claims. Low precision indicates a high number of false positives.
Recall - the fraction of actual claims that the model finds. Low recall indicates a high number of false negatives.
ROC curve - A curve of True Positive Rate vs. False Positive Rate at different classification thresholds. It starts at (0,0) and ends at (1,1). A good model produces a curve that goes quickly from 0 to 1.
AUC (Area under the ROC curve) - Summarizes the ROC curve with a single number. The best value is 1.0, while 0.5 is the worst.
Different combinations of precision and recall give you a better understanding of how well your model is performing for a given class:
Luckily, Keras can calculate most of those metrics for you:
METRICS = [
    keras.metrics.TruePositives(name='tp'),
    keras.metrics.FalsePositives(name='fp'),
    keras.metrics.TrueNegatives(name='tn'),
    keras.metrics.FalseNegatives(name='fn'),
    keras.metrics.BinaryAccuracy(name='accuracy'),
    keras.metrics.Precision(name='precision'),
    keras.metrics.Recall(name='recall'),
    keras.metrics.AUC(name='auc'),
]
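The printouts below also include an F1-score, which isn’t one of the built-in Keras metrics above. The model is presumably rebuilt with metrics=METRICS, retrained as before, and the F1-score computed from the test-set predictions - a sketch, assuming a 0.5 decision threshold:

from sklearn.metrics import f1_score

model = build_model(X_train, metrics=METRICS)
model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=20,
          validation_split=0.05, shuffle=True, verbose=0)

y_pred = (model.predict(X_test) >= 0.5).astype(int).ravel()
print('f1 score:', f1_score(y_test, y_pred))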
And here are the results:
loss : 0.1557253243213323
tp : 0.0
fp : 1.0
tn : 57302.0
fn : 2219.0
accuracy : 0.9627029
precision : 0.0
recall : 0.0
auc : 0.62021655
f1 score: 0.0
Here is the ROC:
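The ROC plot itself isn’t reproduced here; a minimal sketch of how to draw it from the model’s predicted probabilities, assuming matplotlib is available:

from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt

y_prob = model.predict(X_test).ravel()
fpr, tpr, _ = roc_curve(y_test, y_prob)

plt.plot(fpr, tpr, label='our model')
plt.plot([0, 1], [0, 1], linestyle='--', label='random guessing')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.show()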
Our model is complete garbage. And we can measure how much garbage it is. Can we do better?
We have many more examples of no insurance claims compared to those claimed. Let’s force our model to pay attention to the underrepresented class. We can do that by passing weights for each class. First, we need to calculate those:
no_claim_count, claim_count = np.bincount(df.target)
total_count = len(df.target)

weight_no_claim = (1 / no_claim_count) * (total_count) / 2.0
weight_claim = (1 / claim_count) * (total_count) / 2.0

class_weights = {0: weight_no_claim, 1: weight_claim}
Now, let’s use the weights when training our model:
model = build_model(X_train, metrics=METRICS)

history = model.fit(
    X_train,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=20,
    validation_split=0.05,
    shuffle=True,
    verbose=0,
    class_weight=class_weights
)
Let’s begin with the confusion matrix:
Things are a lot different now. We have a lot of correctly predicted insurance claims. The bad news is that we also have a lot of predicted claims that were actually no claims. What can our metrics tell us?
loss : 0.6694403463347913
tp : 642.0
fp : 11170.0
tn : 17470.0
fn : 479.0
accuracy : 0.6085817
precision : 0.05435151
recall : 0.57270294
auc : 0.63104653
f1 score: 0.09928090930178612
The recall has jumped significantly, while the precision has bumped up only slightly. The F1-score is still pretty low, though! Overall, our model has improved somewhat, especially considering the minimal effort on our part. How can we do better?
Resampling methods try to “correct” the balance in your data, either by adding copies of the underrepresented class (oversampling), by removing examples from the overrepresented class (undersampling), or both.
Naturally, a classifier trained on the “rebalanced” data will not know the original proportions. It is expected to have (much) lower accuracy since true proportions play a role in making a prediction.
You must think long and hard (that’s what she said) before using resampling methods. It can be a perfectly good approach or complete nonsense.
Let’s start by separating the classes:
X = pd.concat([X_train, y_train], axis=1)

no_claim = X[X.target == 0]
claim = X[X.target == 1]
We’ll start by adding more copies of the “insurance claim” class. This can be a good option when the amount of data is limited. Either way, you might want to evaluate all approaches using your metrics.
We’ll use the resample() utility from scikit-learn:
from sklearn.utils import resample

claim_upsampled = resample(claim,
                           replace=True,
                           n_samples=len(no_claim),
                           random_state=RANDOM_SEED)
Here is the new distribution of no claim vs claim:
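The original post shows this distribution as a plot. Here is a sketch of how the upsampled data could be checked, put back together and used for training, assuming the model is rebuilt and trained exactly as before (the retraining code isn’t shown in the original):

# Rebuild a balanced training set from the upsampled minority class
upsampled = pd.concat([no_claim, claim_upsampled])
print(upsampled.target.value_counts())  # both classes now have the same count

X_train_up = upsampled[labels]
y_train_up = upsampled['target']

model = build_model(X_train_up, metrics=METRICS)
model.fit(X_train_up, y_train_up, batch_size=BATCH_SIZE, epochs=20,
          validation_split=0.05, shuffle=True, verbose=0)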
Our new model performs like this:
loss : 0.6123614118771424
tp : 530.0
fp : 8754.0
tn : 19886.0
fn : 591.0
accuracy : 0.68599844
precision : 0.057087462
recall : 0.47279215
auc : 0.6274258
f1 score: 0.10187409899086977
The performance of our model is similar to the weighted one. Can undersampling do better?
We’ll remove samples from the no claim class and balance the data this way. This can be a good option when your dataset is large, but keep in mind that removing data can lead to underfitting.
no_claim_downsampled = resample(no_claim,
                                replace=False,
                                n_samples=len(claim),
                                random_state=RANDOM_SEED)
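As with upsampling, the downsampled majority class presumably gets combined with the claim examples and the model is retrained the same way; a sketch:

# Rebuild a (much smaller) balanced training set from the downsampled majority class
downsampled = pd.concat([no_claim_downsampled, claim])

X_train_down = downsampled[labels]
y_train_down = downsampled['target']

model = build_model(X_train_down, metrics=METRICS)
model.fit(X_train_down, y_train_down, batch_size=BATCH_SIZE, epochs=20,
          validation_split=0.05, shuffle=True, verbose=0)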
loss : 0.6377013992475753
tp : 544.0
fp : 8969.0
tn : 19671.0
fn : 577.0
accuracy : 0.67924464
precision : 0.057184905
recall : 0.485281
auc : 0.6206339
f1 score: 0.1023133345871732
Again, the results aren’t that impressive, but we’re doing better than the baseline model.
Let’s try to simulate the data generation process by creating synthetic samples. We’ll use the imbalanced-learn library to do that.
One over-sampling method to generate synthetic data is the Synthetic Minority Oversampling Technique (SMOTE). It uses the k-nearest neighbors (KNN) algorithm to generate new data samples.
from imblearn.over_sampling import SMOTE

# Note: newer versions of imbalanced-learn use sampling_strategy instead of ratio
# and fit_resample() instead of fit_sample()
sm = SMOTE(random_state=RANDOM_SEED, ratio=1.0)
X_train, y_train = sm.fit_sample(X_train, y_train)
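The resampled training set can then be used to retrain the model, while the evaluation still happens on the untouched test set - a sketch, assuming the same training setup as before:

model = build_model(X_train, metrics=METRICS)
model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=20,
          validation_split=0.05, shuffle=True, verbose=0)
model.evaluate(X_test, y_test, batch_size=BATCH_SIZE)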
loss : 0.26040001417683606
tp : 84.0
fp : 1028.0
tn : 27612.0
fn : 1037.0
accuracy : 0.9306139
precision : 0.07553957
recall : 0.0749331
auc : 0.5611229
f1 score: 0.07523510971786834
We have high accuracy but very low precision and recall. Not a useful approach for our dataset.
There are a lot of ways to handle imbalanced datasets. You should always start with something simple (like collecting more data or using a Tree-based model) and evaluate your model with the appropriate metrics. If all else fails, come back to this guide and try the more advanced approaches.
You learned how to impute missing data, encode categorical features, evaluate an imbalanced classifier with the right metrics, and handle the imbalance with class weights, resampling and SMOTE.
Run the complete code in your browser
Remember that the best approach is almost always specific to the problem at hand (context is king). And sometimes, you can restate the problem as outlier/anomaly detection ;)