class WeightedFocalLoss(nn.Module):
    """Class-weighted focal loss for multi-class classification.

    For every sample the loss is
    ``weights[target] * (1 - pt) ** gamma * CE``, where ``CE`` is the
    unreduced cross-entropy of the sample and ``pt = exp(-CE)``.
    The per-sample values are averaged into a single scalar.

    Args:
        weights: Object indexed with the target tensor to look up one
            weight per sample (``weights[targets]``).
            NOTE(review): presumably a 1-D tensor of per-class weights
            living on the same device as the targets -- confirm with callers.
        gamma: Exponent applied to ``(1 - pt)``; ``gamma=0`` removes the
            focusing term and leaves weighted cross-entropy.
    """

    def __init__(self, weights, gamma=1.1):
        super().__init__()
        self.weights = weights
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean weighted focal loss.

        Args:
            inputs: Raw scores passed to ``F.cross_entropy`` after squeezing.
            targets: Class indices passed to ``F.cross_entropy`` after
                squeezing; also used to index ``self.weights``.

        Returns:
            Scalar tensor: the mean of the per-sample focal losses.
        """
        # Drop singleton dimensions so that e.g. (N, 1) targets line up
        # with (N, C) inputs.
        # NOTE(review): squeeze() also removes a batch dimension of size 1.
        inputs = inputs.squeeze()
        targets = targets.squeeze()

        # Keep per-sample values; the reduction happens after re-weighting.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) gives the term the focal factor (1 - pt) is built from.
        pt = torch.exp(-ce_loss)
        focal_loss = self.weights[targets] * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()
# Focal Loss for Multi-class Classification
# Loss Function
# Extending normal Focal Loss