import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal loss for imbalanced multi-class classification.

    The per-sample loss is ``alpha * (1 - pt) ** gamma * ce``, where ``ce``
    is the unreduced cross-entropy of the sample and ``pt = exp(-ce)``.
    With ``gamma == 0`` this is ``alpha`` times plain cross-entropy.

    Args:
        alpha: Scalar factor multiplied into every per-sample loss.
        gamma: Exponent applied to ``(1 - pt)``.
        reduction: ``'mean'`` averages the per-sample losses, ``'sum'``
            adds them up, ``'none'`` returns them unreduced.

    Raises:
        ValueError: If ``reduction`` is not ``'none'``, ``'mean'`` or
            ``'sum'``.
    """

    _VALID_REDUCTIONS = ("none", "mean", "sum")

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        # Fail at construction time rather than silently falling back to a
        # different reduction than the caller asked for.
        if reduction not in self._VALID_REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._VALID_REDUCTIONS}, "
                f"got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss of ``inputs`` against ``targets``.

        Both arguments are passed unchanged to ``F.cross_entropy`` and must
        satisfy its input requirements.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``; for ``'none'`` a
            tensor with the shape of the unreduced cross-entropy.

        Raises:
            ValueError: If ``self.reduction`` was changed to an unsupported
                value after construction.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-ce) equals the softmax probability of the target class.
        pt = torch.exp(-ce_loss)
        loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        raise ValueError(
            f"reduction must be one of {self._VALID_REDUCTIONS}, "
            f"got {self.reduction!r}"
        )