In many real-world classification problems, we encounter training data with unbalanced classes, meaning that the individual classes do not contain the same number of samples. For example, if we want to build an image-based skin cancer detection system using convolutional neural networks, we might encounter a dataset with about 95% negatives and 5% positives. There is a good reason for this: images associated with a negative diagnosis are far more common than images with a positive diagnosis. Rather than regarding this as a flaw in the dataset, we should leverage the additional information it gives us. This blog post will show you how.
Unbalanced classes create two problems:
- The accuracy (i.e. the ratio of test samples for which we predict the correct class) is no longer a good measure of model performance. A model that just predicts “not cancer” every time will reach 95% accuracy, even though it is a bad (and even dangerous) model that does not yield any insight or scientific advancement, however good “95% accuracy” may sound. In addition, it’s hard to get an intuition for how good a model with 96%, 97% or 98% accuracy really is. (A short sketch after this list makes this concrete.)
- The training process might arrive at a local optimum that always predicts “not cancer”, making it hard to further improve the model.
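To make the first point concrete, here is a minimal sketch with made-up labels that compares accuracy and sensitivity for a model that always predicts “not cancer”:

import numpy as np

# Hypothetical test labels: 95% negatives (0), 5% positives (1)
y_true = np.array([0] * 95 + [1] * 5)

# A model that always predicts "not cancer"
y_pred = np.zeros_like(y_true)

accuracy = np.mean(y_pred == y_true)                                       # 0.95
sensitivity = np.sum((y_pred == 1) & (y_true == 1)) / np.sum(y_true == 1)  # 0.0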
Fortunately, these problems are not so difficult to solve. Here are a few ways to tackle them.
1. Collect more data
If possible, you could collect more data for the underrepresented classes to match the number of samples in the overrepresented classes. This is probably the most rewarding approach, but it is also the hardest and most time-consuming, if not downright impossible. In the cancer example, there is a good reason that we have way more non-cancer samples than cancer samples: These are easier to obtain, since there are more people in the world who haven’t developed cancer.
2. Create copies of training samples
Artificially increase the number of training samples for the underrepresented classes by creating copies. While this is the easiest solution, it wastes time and computing resources. In the cancer example, we would almost have to double the size of the dataset in order to achieve a 50:50 share between the classes, which also doubles training time without adding any new information.
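As an illustration, a minimal sketch of this duplication approach with NumPy could look as follows (x_train and y_train are assumed to be existing arrays, with 1 as the cancer class):

import numpy as np

# Indices of the underrepresented (cancer) class
cancer_idx = np.where(y_train == 1)[0]

# How many extra copies are needed to roughly balance the classes
n_copies = np.sum(y_train == 0) // len(cancer_idx) - 1

# Append plain copies of the cancer samples and shuffle
x_balanced = np.concatenate([x_train] + [x_train[cancer_idx]] * n_copies)
y_balanced = np.concatenate([y_train] + [y_train[cancer_idx]] * n_copies)
perm = np.random.permutation(len(y_balanced))
x_balanced, y_balanced = x_balanced[perm], y_balanced[perm]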
3. Create augmented copies of training samples
Similar to 2, but create augmented copies of the underrepresented classes. For example, in the case of images, create slightly rotated, shifted or flipped versions of the original images. This has the positive side-effect of making the model more robust to unseen examples. However, it only does so for the underrepresented classes. Ideally, you would want to do this for all classes, but then the classes are unbalanced again and we’re back where we started.
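A minimal sketch using Keras’ ImageDataGenerator, applied only to the underrepresented class (the array x_cancer, the variable n_non_cancer_samples and the augmentation parameters are assumptions):

import numpy as np
from keras.preprocessing.image import ImageDataGenerator

# x_cancer: images of the underrepresented class only, shape (n, height, width, channels)
datagen = ImageDataGenerator(rotation_range=15,
                             width_shift_range=0.1,
                             height_shift_range=0.1,
                             horizontal_flip=True)

# Draw augmented copies until we have roughly as many cancer as non-cancer images
n_needed = n_non_cancer_samples - len(x_cancer)
augmented = []
for batch in datagen.flow(x_cancer, batch_size=32, shuffle=False):
    augmented.append(batch)
    if sum(len(b) for b in augmented) >= n_needed:
        break
x_cancer_augmented = np.concatenate(augmented)[:n_needed]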
4. Train for sensitivity and specificity
The sensitivity tells us the probability that we detect cancer, given that the patient really has cancer. It is thus a measure of how good we are at correctly diagnosing people who have cancer.
$$sensitivity = Pr(detect\, cancer \; \vert \; cancer) = \frac{\text{true positives}}{\text{positives}}$$
The specificity tells us the probability that we do not detect cancer, given that the patient doesn’t have cancer. It measures how good we are at not causing people to believe that they have cancer if in fact they do not.
$$specificity = Pr(\lnot \, detect\, cancer \; \vert \; \lnot \, cancer) = \frac{\text{true negatives}}{\text{negatives}}$$
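As a small worked example with made-up numbers: suppose our test set contains 50 patients with cancer and 950 without, and the model finds 40 of the cancer cases while raising a false alarm for 50 healthy patients. Then:

true_positives, false_negatives = 40, 10   # 50 patients with cancer
true_negatives, false_positives = 900, 50  # 950 patients without cancer

sensitivity = true_positives / (true_positives + false_negatives)  # 40 / 50  = 0.8
specificity = true_negatives / (true_negatives + false_positives)  # 900 / 950 ≈ 0.947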
A model that always predicts cancer will have a sensitivity of 1 and a specificity of 0. A model that never predicts cancer will have a sensitivity of 0 and a specificity of 1. An ideal model should have both a sensitivity of 1 and a specificity of 1. In reality, however, this is unlikely to be achievable. Therefore, we should look for a model that achieves a good tradeoff between specificity and sensitivity. So which one of the two is more important? This can’t be said in general. It highly depends on the application.
If you build a photo-based skin cancer detection app, then a high sensitivity is probably more important than a high specificity, since you want to cause people who might have cancer to get themselves checked by a doctor. Specificity is a little less important here, but still, if you detect cancer too often, people might stop using your app since they unnecessarily get annoyed and scared.
Now suppose that our desired tradeoff between sensitivity and specificity is given by a number $t \in [0, 1]$ where $t = 1$ means that we only pay attention to sensitivity, $t = 0$ means we only pay attention to specificity and $t = 0.5$ means that we regard both to be equally important. In order to incorporate the desired tradeoff into the training process, we need the samples of the different classes to have a different contribution to the loss. To achieve this, we can simply multiply the contribution of the cancer samples to the loss by
$$\frac{\text{number of non-cancer samples}}{\text{number of cancer samples}} \cdot t$$
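For example, with (made-up) 950 non-cancer samples, 50 cancer samples and $t = 0.8$, each cancer sample would contribute to the loss with weight
$$\frac{950}{50} \cdot 0.8 = 15.2$$
while each non-cancer sample keeps its weight of 1.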
Keras Implementation
In Keras, the class weights can easily be incorporated into the loss by adding the following parameter to the fit function (assuming that 1 is the cancer class):
class_weight={ 1: n_non_cancer_samples / n_cancer_samples * t }
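As a rough sketch, the complete fit call could then look like this (x_train, y_train, t and the training hyperparameters are placeholders; the non-cancer class 0 is left at weight 1):

t = 0.8  # tradeoff: 1 = only sensitivity matters, 0 = only specificity

model.fit(x_train, y_train,
          epochs=50,
          batch_size=32,
          validation_split=0.2,
          class_weight={0: 1.0,
                        1: n_non_cancer_samples / n_cancer_samples * t})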
Now, while we train, we want to monitor the sensitivity and specificity. Here is how to do this in Keras. In other frameworks, the implementation should be similar (for instance, you could replace all the K calls by numpy calls).
from keras import backend as K

def sensitivity(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())

def specificity(y_true, y_pred):
    true_negatives = K.sum(K.round(K.clip((1 - y_true) * (1 - y_pred), 0, 1)))
    possible_negatives = K.sum(K.round(K.clip(1 - y_true, 0, 1)))
    return true_negatives / (possible_negatives + K.epsilon())
from keras.optimizers import RMSprop

model.compile(
    loss='binary_crossentropy',
    optimizer=RMSprop(0.001),
    metrics=[sensitivity, specificity]
)
Generalizing to more than 2 classes
If we have more than two classes, we can generalize sensitivity and specificity to a “per-class accuracy”:
$$perClassAccuracy(C) = Pr(detect\, C \; \vert \; C)$$
In order to train for maximum per-class accuracy, we have to specify class weights that are inversely proportional to the size of the class:
class_weight={
    0: 1.0 / n_samples_0,
    1: 1.0 / n_samples_1,
    2: 1.0 / n_samples_2,
    ...
}
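A minimal sketch that derives these weights directly from the training labels (assuming y_train is a vector of integer class labels):

import numpy as np

# Count how often each class occurs
counts = np.bincount(y_train)

# One weight per class, inversely proportional to its size
class_weight = {class_id: 1.0 / count for class_id, count in enumerate(counts)}

model.fit(x_train, y_train, epochs=50, class_weight=class_weight)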
Here is a Keras implementation of the per-class accuracy, which I adapted from jdehesa on Stack Overflow.
INTERESTING_CLASS_ID = 0  # Choose the class of interest

def single_class_accuracy(y_true, y_pred):
    class_id_true = K.argmax(y_true, axis=-1)
    class_id_preds = K.argmax(y_pred, axis=-1)
    # Only consider samples whose true class is the class of interest
    accuracy_mask = K.cast(K.equal(class_id_true, INTERESTING_CLASS_ID), 'int32')
    class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
    class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
    return class_acc
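Since the class of interest is hard-coded in a global variable, you can also wrap the metric in a small factory and monitor every class at once. This is just a sketch building on the code above (n_classes is a placeholder):

def single_class_accuracy_for(class_id):
    def metric(y_true, y_pred):
        class_id_true = K.argmax(y_true, axis=-1)
        class_id_preds = K.argmax(y_pred, axis=-1)
        accuracy_mask = K.cast(K.equal(class_id_true, class_id), 'int32')
        class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
        return K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
    metric.__name__ = 'acc_class_{}'.format(class_id)
    return metric

model.compile(
    loss='categorical_crossentropy',
    optimizer=RMSprop(0.001),
    metrics=[single_class_accuracy_for(c) for c in range(n_classes)]
)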
If you have any questions, feel free to leave a comment.
Hi, I am interested in this formula
class_weight={
1: n_non_cancer_samples / n_cancer_samples * t
}
Would you have any reference that discusses how the weights affect the results of the training?
Regarding “Keras Implementation”: the class_weight parameter of Keras’ fit() function needs a weight for each class to be specified.
In my example, class 0 is “cancer” while class 1 is “non cancer” and I want to pay more attention to sensitivity instead of specificity.
I used
t=0.8
class_weights={
0: (n_non_cancer_samples / n_cancer_samples) * t,
1: (n_non_cancer_samples / n_cancer_samples) * (1-t)
}
Is that correct?
Edit: I’m using the HAM10K dataset, in which only about 10% of the images are Melanoma. I first used the Augmentor library to create augmented copies of the original Melanoma images, then I used class weights as above to put more emphasis on sensitivity (n_non_cancer_samples = n_cancer_samples in my case due to the augmentation, so Melanoma is weighted 0.8 and Not Melanoma 0.2). Should I use the above class weights without pre-augmentation, directly on the unbalanced dataset?
Hi! I would recommend that you additionally multiply your class weights by the number of samples. The tradeoff between sensitivity and specificity can then be controlled with your t parameter.
Whenever possible, I would still use augmentation. But the class weights should solve the problem of unbalanced classes.
Regarding part 3: Create augmented copies of training samples.
I’ve been trying to do this myself, but I can’t find any guides on how to use image augmentation on one class only. Most guides on imbalance I’ve seen recommend adjusting class weights. Coincidentally I’m also using a skin cancer dataset.
Any tips you could share or guides you know of for augmenting the minority class?
Thanks
I had the same problem and I used the Augmentor library: https://github.com/mdbloice/Augmentor
Obviously I divided Melanoma and Not Melanoma images in 2 different directories in advance and I used Augmentor only on Melanoma images.
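For reference, a minimal Augmentor pipeline along these lines could look as follows (directory name, parameters and sample count are made up):

import Augmentor

# Pipeline over the directory that contains only the Melanoma images
p = Augmentor.Pipeline("data/melanoma")
p.rotate(probability=0.7, max_left_rotation=10, max_right_rotation=10)
p.flip_left_right(probability=0.5)
p.zoom_random(probability=0.5, percentage_area=0.9)

# Writes the augmented images to data/melanoma/output
p.sample(5000)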
With respect to the definition of true positives, are you sure this is correct? By multiplying y_true * y_pred first, and only then using K.round, you get artificially fewer TPs: e.g. with true = 0.7 and pred = 0.6, their product (0.42) would not be rounded to “true”.
So I think this only works when you’re sure that y_true is already rounded to 0’s and 1’s.
You’re right, but y_true is in fact always either 0 or 1, since it reflects the binary true class from the training data. I agree, though, that it might be more intuitive to round y_pred first and then multiply by y_true. But the result is the same.
Thanks for this great article!
My question is: When is the best time to stop training? I am using a metric which is the mean of the sensitivity and the specificity. If I train the model for 100 epochs and select the one with the best average score, it is usually not the model with the lowest validation loss. So it may be considered that there is some kind of overfitting.
So is it a good approach to select the model with the best average score, but train until the lowest validation loss (with a patience of, say, 5)?