Original link: amaarora.github.io/2020/06/29/…

By Aman Arora

Focal Loss is a loss function commonly used in object detection. I recently came across this blog post and took the opportunity to study it, translate it, and share it with you.

In this blog post, you’ll understand what Focal Loss is and when to use it. We’ll also take a closer look at the mathematics behind it and the PyTorch implementation.

  1. What is Focal Loss and what is it for?
  2. Why does Focal Loss work and how does it work?
  3. Alpha and Gamma?
  4. How do you implement it in code?
  5. Credits

What is Focal Loss and what is it for?

Before we look at what Focal Loss is and all the details about it, let's take a quick and intuitive look at what it does. Focal Loss was first introduced by Lin et al. in their paper Focal Loss for Dense Object Detection.

Before this paper was published, object detection was considered a very difficult problem, especially detecting small objects in images. See the example below: the motorcycle is relatively small in the image, so the model does a poor job of predicting that it is there.

fig-1: prediction of a model trained with BCE loss

In the figure above, the model fails to predict the motorcycle because it was trained with Binary Cross Entropy (BCE) loss, a training objective that requires the model to be really confident in its predictions. What Focal Loss does is make the model more "relaxed" about its predictions, so it does not have to be 80-100% sure that an object is "something". In short, it gives the model more freedom to take some risks when making predictions. This is especially important when dealing with highly imbalanced datasets, because in some cases (such as cancer detection) we really need the model to take a risk and predict something, even if the prediction turns out to be a false positive.

This is why Focal Loss is particularly useful when the classes are imbalanced. In object detection, for example, most pixels usually belong to the background and only a few pixels in the image contain objects of interest.

Below is the prediction for the same image by the same model after training with Focal Loss.

fig-2: prediction of the same model trained with Focal Loss

It is worth comparing the two images and noting the differences; this helps build an intuition for what Focal Loss does.

So why does Focal Loss work, and how does it work?

Now that we’ve seen an example of what “Focal Loss” can do, let’s try to understand why it works. Here’s the most important picture to understand Focal Loss:

fig-3: Focal Loss (FL) vs. Cross Entropy (CE)

In the figure above, the blue line is the cross entropy loss. The x-axis is the probability the model assigns to the true label (called p_t for short). For example, suppose the model predicts with probability 0.6 that an object is a bicycle, and it really is a bicycle; then p_t = 0.6. If, in the same situation, the object is not a bicycle, then p_t = 0.4, because the true label is 0 and the predicted probability that the object is not a bicycle is 1 − 0.6 = 0.4.
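The definition of p_t fits in a tiny helper. This is a minimal sketch; the function name `p_t` and the convention that `p` is the predicted probability of the positive class (y = 1) are assumptions made for illustration.

```python
def p_t(p: float, y: int) -> float:
    """Probability the model assigns to the true label.

    p: predicted probability of the positive class (e.g. "bicycle")
    y: true label, 1 for positive, 0 for negative
    """
    return p if y == 1 else 1.0 - p


# The bicycle example: the model predicts 0.6 for "bicycle"
print(p_t(0.6, 1))  # it really is a bicycle → 0.6
print(p_t(0.6, 0))  # it is not a bicycle → 0.4
```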

The y-axis is the value of the Focal Loss and of the CE loss for a given p_t.

As the figure shows, when the model predicts the true label with a probability of about 0.6, the cross entropy loss is still about 0.5. Therefore, to reduce the loss during training, the model has to predict the true label with a much higher probability. In other words, cross entropy loss requires the model to be very confident in its predictions, but this can also hurt model performance.
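To check the value read off fig-3, we can evaluate the cross entropy −log(p_t) directly. This sketch assumes the natural logarithm, which matches the "about 0.5" quoted above; the helper name is illustrative.

```python
import math


def cross_entropy(pt: float) -> float:
    """Cross entropy for an example whose true label gets probability pt."""
    return -math.log(pt)


for pt in (0.6, 0.9, 0.99):
    print(f"p_t = {pt:.2f}  CE = {cross_entropy(pt):.3f}")
```

At p_t = 0.6 the loss is about 0.51, and it only drops to about 0.1 once p_t reaches 0.9, which is why CE keeps pushing the model toward very confident predictions.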

Deep learning models can become overconfident, which reduces their ability to generalize.

The problem of model overconfidence is also highlighted in another excellent paper, Beyond Temperature Scaling: Obtaining Well-Calibrated Multi-class Probabilities with Dirichlet Calibration.

In addition, label smoothing, introduced in the paper Rethinking the Inception Architecture for Computer Vision, is another approach to this problem.

Focal Loss takes a different approach from the solutions above. As the figure comparing Focal Loss with Cross Entropy shows, with γ > 0 Focal Loss reduces the training loss of "well-classified" samples, i.e. samples the model predicts correctly with high probability, while for "hard" samples, such as those whose true label is predicted with probability below 0.5, it reduces the loss much less. As a result, when the classes are imbalanced, the model focuses on the rare classes, because their samples are few and hard to classify.
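The effect can be quantified with the ratio FL/CE, which (ignoring α) is just the factor (1 − p_t)^γ. A minimal sketch with γ = 2, the default used in the implementation later in this post; the helper names are illustrative.

```python
import math


def cross_entropy(pt: float) -> float:
    return -math.log(pt)


def focal_loss(pt: float, gamma: float = 2.0) -> float:
    """Focal Loss without the alpha weight: -(1 - pt)**gamma * log(pt)."""
    return -((1.0 - pt) ** gamma) * math.log(pt)


# Hard examples (low p_t) keep most of their loss; easy ones lose almost all of it.
for pt in (0.1, 0.3, 0.5, 0.9, 0.99):
    ce, fl = cross_entropy(pt), focal_loss(pt)
    print(f"p_t={pt:.2f}  CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.4f}")
```

With γ = 2, an example at p_t = 0.9 keeps only 1% of its CE loss, while one at p_t = 0.1 keeps 81%.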

Focal Loss is mathematically defined as follows:

$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$

Alpha and Gamma?

So what are alpha and gamma in Focal Loss? From here on, we write alpha as α and gamma as γ.

We can read fig-3 as follows:

γ controls the shape of the curve. The larger γ is, the smaller the loss assigned to well-classified samples, which shifts the model's attention toward samples that are hard to classify. A larger γ also widens the range of p_t over which a sample receives a small loss.
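The claim that a larger γ widens the "small loss" range can be checked numerically. In this sketch the 0.1 cut-off and the grid search are arbitrary choices for illustration: for several values of γ it finds the smallest p_t at which the loss (without α) falls below 0.1.

```python
import math


def focal_loss(pt: float, gamma: float) -> float:
    """Focal Loss without alpha; gamma = 0 gives plain cross entropy."""
    return -((1.0 - pt) ** gamma) * math.log(pt)


def small_loss_threshold(gamma: float, max_loss: float = 0.1, steps: int = 10_000) -> float:
    """Smallest p_t on a grid at which the loss drops below max_loss.

    Focal Loss decreases monotonically in p_t, so every p_t above this
    threshold also gets a loss below max_loss.
    """
    for i in range(1, steps):
        pt = i / steps
        if focal_loss(pt, gamma) < max_loss:
            return pt
    return 1.0


for gamma in (0, 1, 2, 5):
    print(f"gamma={gamma}: loss < 0.1 for p_t >= {small_loss_threshold(gamma):.3f}")
```

The threshold moves to lower p_t as γ grows (roughly 0.9 for γ = 0 and below 0.4 for γ = 5), so more and more samples count as "easy".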

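Also, when γ = 0 (and α_t = 1), this expression reduces to the familiar Cross Entropy loss:

$$\mathrm{CE}(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \\ -\log(1 - p) & \text{otherwise} \end{cases}$$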

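where p is the model's predicted probability for the class y = 1. Define p_t as follows:

$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}$$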

Combining these two expressions, the Cross Entropy loss becomes:

$$\mathrm{CE}(p, y) = \mathrm{CE}(p_t) = -\log(p_t)$$
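The substitution can be checked numerically against the familiar single-formula form of binary cross entropy, −[y log p + (1 − y) log(1 − p)], which is equivalent to the case-by-case definition for y ∈ {0, 1}. A minimal sketch; the function names are illustrative.

```python
import math


def bce(p: float, y: int) -> float:
    """Binary cross entropy in its usual single-formula form."""
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def ce_pt(p: float, y: int) -> float:
    """The same loss written through p_t: -log(p_t)."""
    pt = p if y == 1 else 1.0 - p
    return -math.log(pt)


# The two forms agree for every tested prediction/label pair
for p in (0.1, 0.4, 0.6, 0.95):
    for y in (0, 1):
        assert math.isclose(bce(p, y), ce_pt(p, y))
print("BCE(p, y) == -log(p_t) on all tested cases")
```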

Now that we know what gamma does, what does alpha do?

Besides Focal Loss, another way to deal with class imbalance is to introduce weights: give high weights to rare classes and low weights to dominant or common classes. These weights are denoted by α.

With these weights, we get the α-balanced Cross Entropy loss:

$$\mathrm{CE}(p_t) = -\alpha_t \log(p_t)$$
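In this sketch α_t is defined the same way as p_t: α for the positive class and 1 − α for the negative class. That follows the Focal Loss paper's convention and is an assumption here, since this post does not spell it out. The value α = 0.75 below is purely illustrative.

```python
import math


def alpha_t(y: int, alpha: float) -> float:
    """Class weight for the true label: alpha for positives, 1 - alpha for negatives."""
    return alpha if y == 1 else 1.0 - alpha


def balanced_ce(p: float, y: int, alpha: float) -> float:
    """Alpha-balanced cross entropy: -alpha_t * log(p_t)."""
    pt = p if y == 1 else 1.0 - p
    return -alpha_t(y, alpha) * math.log(pt)


# Same p_t = 0.6 for a (rare) positive and a (common) negative example,
# but with alpha = 0.75 the positive's loss is weighted 3x more.
pos = balanced_ce(0.6, 1, alpha=0.75)
neg = balanced_ce(0.4, 0, alpha=0.75)
print(round(pos / neg, 6))
```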

Adding these weights does help with class imbalance, but as the Focal Loss paper points out:

When the class imbalance is large, it overwhelms the cross entropy loss during training: easily classified negatives make up most of the loss and dominate the gradient. Although α balances the importance of positive/negative examples, it does not distinguish between easy/hard examples.

What the authors are saying is:

Adding α does give different weights to different classes and so balances the importance of positive and negative samples, but in most cases this alone is not enough. We also need to reduce the loss contributed by easily classified samples; otherwise, these easy samples dominate training.
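A toy calculation makes this concrete. The numbers are made up for illustration: 1,000 easy negatives that the model already classifies well (p_t = 0.95) and 10 hard positives (p_t = 0.3). The sketch compares the share of the total loss that the hard samples get under plain cross entropy and under Focal Loss with γ = 2 (α omitted for simplicity).

```python
import math


def focal_loss(pt: float, gamma: float) -> float:
    """Focal Loss without alpha; gamma = 0 is plain cross entropy."""
    return -((1.0 - pt) ** gamma) * math.log(pt)


# Hypothetical batch: many easy negatives, a few hard positives.
n_easy, pt_easy = 1000, 0.95
n_hard, pt_hard = 10, 0.3


def hard_share(gamma: float) -> float:
    """Fraction of the total loss contributed by the hard samples."""
    easy = n_easy * focal_loss(pt_easy, gamma)
    hard = n_hard * focal_loss(pt_hard, gamma)
    return hard / (easy + hard)


print(f"CE (gamma=0): hard samples = {hard_share(0):.1%} of the loss")
print(f"FL (gamma=2): hard samples = {hard_share(2):.1%} of the loss")
```

Under cross entropy the 1,000 easy negatives account for about 81% of the loss; with γ = 2 the 10 hard samples account for roughly 98% of it.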

So how does Focal Loss deal with this? It multiplies the cross entropy by a modulating factor (1 − p_t)^γ, which reduces the loss in the range of easily classified samples, as described above.

Putting it all together, Focal Loss is expressed as:

$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$

Is it clearer now?

How do you implement it in code?

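Here is an implementation of weighted Focal Loss for binary targets in PyTorch. It expects raw logits as `inputs` and 0/1 float `targets` of the same shape.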

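```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedFocalLoss(nn.Module):
    "Weighted (alpha-balanced) version of Focal Loss for binary targets"
    def __init__(self, alpha=.25, gamma=2):
        super().__init__()
        # alpha_t lookup table indexed by the target: 1 - alpha for negatives (0),
        # alpha for positives (1), following the Focal Loss paper's convention
        self.alpha = torch.tensor([1 - alpha, alpha])
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Per-element BCE on raw logits; this equals -log(p_t)
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # Look up alpha_t per element on the inputs' device (no hard-coded .cuda()),
        # reshaped back to the targets' shape so it lines up with BCE_loss
        alpha = self.alpha.to(inputs.device)
        at = alpha.gather(0, targets.long().view(-1)).view_as(targets)
        pt = torch.exp(-BCE_loss)  # recover p_t from -log(p_t)
        F_loss = at * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()
```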

If you understand what α and γ mean, this implementation should make sense: it computes the per-element BCE loss, recovers p_t = exp(−BCE), and multiplies the BCE by the factors α_t and (1 − p_t)^γ, exactly as described above.
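To sanity-check a framework implementation like the one above, a framework-free reference can help. The sketch below is an illustration, not the author's code: it assumes the standard numerically stable form of BCE on logits, max(x, 0) − x·y + log(1 + exp(−|x|)), and the paper's α_t convention (α for positives). All function names are hypothetical.

```python
import math


def bce_with_logits(x: float, y: float) -> float:
    """Numerically stable binary cross entropy on a raw logit x, target y in {0, 1}."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def focal_loss_from_logit(x: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Per-example weighted Focal Loss: alpha_t * (1 - p_t)**gamma * BCE."""
    bce = bce_with_logits(x, y)
    pt = math.exp(-bce)  # BCE = -log(p_t), so p_t = exp(-BCE)
    at = alpha if y == 1 else 1.0 - alpha
    return at * (1.0 - pt) ** gamma * bce


def mean_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Average over a batch, like a 'mean' reduction."""
    losses = [focal_loss_from_logit(x, y, alpha, gamma) for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)


# A confident correct prediction contributes far less than a confident mistake.
print(focal_loss_from_logit(4.0, 1))   # p = sigmoid(4) ≈ 0.982, true label 1
print(focal_loss_from_logit(-4.0, 1))  # p ≈ 0.018, true label 1
```

Fed the same logits and targets, `mean_focal_loss` and a PyTorch module should agree up to floating-point error, provided both use the same α value and α_t convention.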

Credits


  • fig-1 and fig-2 are from Lecture 9 of the fast.ai 2018 course.
