class WeightedFocalLoss(nn.Module):
    """Class-weighted focal loss for multi-class classification.

    Per sample this computes ``w[y] * (1 - p_t) ** gamma * CE(x, y)``, where
    ``CE`` is the unreduced cross-entropy from ``F.cross_entropy``,
    ``p_t = exp(-CE)`` is the softmax probability of the target class, and
    ``w[y]`` is the weight of that sample's target class. The per-sample
    values are averaged with ``.mean()``.

    Args:
        weights: Per-class weights indexed by integer class label. Anything
            accepted by ``torch.as_tensor`` (list, tuple, tensor); integer
            weights are promoted to float. Stored as a non-persistent buffer
            so ``module.to(device)`` moves it together with the model.
        gamma: Focusing exponent applied to ``(1 - p_t)``; larger values
            down-weight confidently-correct samples more strongly.
    """

    def __init__(self, weights, gamma=1.1):
        super().__init__()
        weight_tensor = torch.as_tensor(weights)
        if not weight_tensor.is_floating_point():
            weight_tensor = weight_tensor.float()
        # A buffer (rather than a plain attribute) follows the module across
        # devices; persistent=False keeps the state_dict keys unchanged.
        self.register_buffer("weights", weight_tensor, persistent=False)
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean class-weighted focal loss as a scalar tensor.

        Args:
            inputs: Unnormalized logits passed to ``F.cross_entropy`` after
                squeezing singleton dims (presumably (N, C) — confirm).
            targets: Integer class indices, squeezed the same way
                (e.g. (N, 1) becomes (N,)).
        """
        logits = inputs.squeeze()
        labels = targets.squeeze()
        # Multi-class cross-entropy per sample (not binary CE).
        ce = F.cross_entropy(logits, labels, reduction="none")
        # CE = -log p_t, so exp(-CE) recovers the target-class probability.
        p_true = torch.exp(-ce)
        focal = self.weights[labels] * (1 - p_true) ** self.gamma * ce
        return focal.mean()


# Focal Loss for Multi-class Classification
# Loss Function
#
# Extending normal Focal Loss