Spaces:
Build error
Build error
| from typing import Any | |
| from pytorch_toolbelt.losses import BinaryFocalLoss | |
| from torch import nn | |
| from torch.nn.modules.loss import BCEWithLogitsLoss | |
class WeightedLosses(nn.Module):
    """Weighted sum of several loss terms evaluated on the same inputs.

    ``forward`` calls every loss with the same positional/keyword arguments and
    returns ``sum(w_i * loss_i(*args, **kwargs))``.

    Args:
        losses: Iterable of loss callables (typically ``nn.Module`` instances).
            When every entry is an ``nn.Module`` they are stored in an
            ``nn.ModuleList`` so that ``.to(device)``, ``state_dict`` and
            module hooks treat them as proper submodules; otherwise they are
            kept in a plain list.
        weights: Iterable of multipliers, one per loss, in the same order.

    Raises:
        ValueError: If ``losses`` and ``weights`` have different lengths.
    """

    def __init__(self, losses, weights):
        super().__init__()
        # Materialize once so generators are not exhausted after the first
        # forward pass and so lengths can be checked.
        losses = list(losses)
        weights = list(weights)
        # zip() would silently drop the unmatched tail, hiding a config error.
        if len(losses) != len(weights):
            raise ValueError(
                f"Got {len(losses)} losses but {len(weights)} weights; "
                "they must have the same length."
            )
        # A plain list hides modules from nn.Module bookkeeping: buffers such
        # as BCEWithLogitsLoss.pos_weight would not follow .to(device).
        if all(isinstance(loss, nn.Module) for loss in losses):
            self.losses = nn.ModuleList(losses)
        else:
            self.losses = losses
        self.weights = weights

    def forward(self, *args: Any, **kwargs: Any):
        """Return the weighted sum of all losses applied to the same inputs.

        Returns the Python int ``0`` when no losses were configured.
        """
        cum_loss = 0
        for loss, w in zip(self.losses, self.weights):
            # Call the loss itself (not .forward) so module hooks still run
            # and plain callables are supported.
            cum_loss += w * loss(*args, **kwargs)
        return cum_loss
class BinaryCrossentropy(BCEWithLogitsLoss):
    """Alias of ``torch.nn.BCEWithLogitsLoss`` under a project-local name.

    Adds no behavior of its own: the constructor arguments and ``forward``
    are inherited unchanged, so inputs are raw logits (the parent loss
    applies the sigmoid internally).
    """
class FocalLoss(BinaryFocalLoss):
    """Thin wrapper around pytorch_toolbelt's ``BinaryFocalLoss``.

    The only local choice is the default focusing parameter ``gamma=3``;
    every argument is forwarded positionally, in declaration order, to the
    parent constructor. NOTE(review): the upstream default for ``gamma`` may
    differ — confirm against the installed pytorch_toolbelt version.

    Args:
        alpha: Forwarded unchanged to ``BinaryFocalLoss``.
        gamma: Focusing parameter; defaults to 3 here.
        ignore_index: Forwarded unchanged to ``BinaryFocalLoss``.
        reduction: Forwarded unchanged; defaults to ``"mean"``.
        normalized: Forwarded unchanged; defaults to ``False``.
        reduced_threshold: Forwarded unchanged to ``BinaryFocalLoss``.
    """

    def __init__(
        self,
        alpha=None,
        gamma=3,
        ignore_index=None,
        reduction="mean",
        normalized=False,
        reduced_threshold=None,
    ):
        # Passed positionally, so this order must match the parent's
        # __init__ — presumably (alpha, gamma, ignore_index, reduction,
        # normalized, reduced_threshold); verify when upgrading the library.
        parent_args = (
            alpha,
            gamma,
            ignore_index,
            reduction,
            normalized,
            reduced_threshold,
        )
        super(FocalLoss, self).__init__(*parent_args)