focal loss

The number of training examples per class is often imbalanced, which can make a network markedly less reliable at identifying the under-represented classes. One strategy is to collect more data for those classes. Another is to modify the loss function so that training puts greater weight on scarce, hard examples; a loss designed for this purpose is the focal loss.
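The focal loss scales the usual cross-entropy term by a modulating factor (1 - p_t)^γ, where p_t is the predicted probability of the true class: confidently correct examples contribute almost nothing, while hard examples keep most of their weight. A minimal sketch of that factor (the function name `focal_weight` is just an illustrative choice, not part of any library):

```python
def focal_weight(p_t, gamma=2.0):
    # modulating factor (1 - p_t)^gamma from the focal loss definition
    return (1.0 - p_t) ** gamma

# well-classified example (p_t = 0.9): factor is small, ~0.01
easy = focal_weight(0.9)
# hard example (p_t = 0.1): factor stays large, ~0.81
hard = focal_weight(0.1)
print(easy, hard)
```

With γ = 0 the factor is 1 everywhere and the loss reduces to ordinary weighted cross-entropy; larger γ focuses training more strongly on hard examples.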

example of focal loss implemented in Keras with the TensorFlow backend


In [1]:
from keras import backend as K
import tensorflow as tf

def focal_loss(gamma=2., alpha=.25):
    def focal_loss_fixed(y_true, y_pred):
        # clip predictions so K.log never sees exactly 0 or 1
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
        # p_t for positives (1 elsewhere) and for negatives (0 elsewhere),
        # so entries of the other class contribute zero loss to each term
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        # cross-entropy down-weighted by the modulating factor (1 - p_t)^gamma
        return (-K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1))
                - K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0)))
    return focal_loss_fixed


Using TensorFlow backend.

In [2]:
model.compile(optimizer=optimizer, loss=[focal_loss(alpha=.25, gamma=2)])
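To check the behavior of the loss without building a model, the same computation can be re-implemented in plain NumPy (a sketch for inspection only; `focal_loss_np` and the example arrays are illustrative names, not library API). Confident, correct predictions should yield a much smaller loss than uncertain ones:

```python
import numpy as np

def focal_loss_np(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    # NumPy counterpart of the Keras focal loss above
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    pt_1 = np.where(y_true == 1, y_pred, 1.0)
    pt_0 = np.where(y_true == 0, y_pred, 0.0)
    return (-np.sum(alpha * (1.0 - pt_1) ** gamma * np.log(pt_1))
            - np.sum((1.0 - alpha) * pt_0 ** gamma * np.log(1.0 - pt_0)))

y_true = np.array([1.0, 1.0, 0.0, 0.0])
easy = np.array([0.95, 0.9, 0.1, 0.05])  # confident, correct predictions
hard = np.array([0.55, 0.4, 0.6, 0.45])  # uncertain predictions
print(focal_loss_np(y_true, easy), focal_loss_np(y_true, hard))
```

The easy batch produces a near-zero loss while the hard batch dominates, which is exactly the focusing effect the modulating factor is meant to provide.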