Spaces:
Sleeping
Sleeping
import re | |
import torch.nn as nn | |
class BaseObject(nn.Module):
    """``nn.Module`` base class that exposes a human-readable ``__name__``.

    Args:
        name: Optional explicit name. When omitted, ``__name__`` is derived
            from the class name converted to snake_case
            (e.g. ``DiceLoss`` -> ``"dice_loss"``).
    """

    def __init__(self, name=None):
        super().__init__()
        self._name = name

    @property
    def __name__(self):
        """Return the explicit name if one was given, else the snake_case class name."""
        # Must be a property: callers read ``obj.__name__`` as a string
        # (formatting it, calling ``.split`` on it). As a plain method,
        # instance access would yield a bound method instead of the name.
        if self._name is not None:
            return self._name
        name = self.__class__.__name__
        # Two-pass CamelCase -> snake_case: first insert "_" before each
        # capitalised word, then between a lowercase letter/digit and a capital.
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
class Metric(BaseObject):
    """Base class for metric objects.

    Adds no behaviour of its own beyond ``BaseObject`` (naming via ``name``
    or the snake_case class name).
    """
class Loss(BaseObject):
    """Base class for losses that can be combined with ``+`` and scaled with ``*``.

    ``loss_a + loss_b`` builds a ``SumOfLosses``; ``k * loss`` or
    ``loss * k`` (``k`` an ``int`` or ``float``) builds a ``MultipliedLoss``.
    ``sum(losses)`` also works, because ``0 + loss`` returns ``loss`` itself.
    """

    def __add__(self, other):
        """Return ``SumOfLosses(self, other)``; raise ValueError if ``other`` is not a Loss."""
        if isinstance(other, Loss):
            return SumOfLosses(self, other)
        raise ValueError("Loss should be inherited from `Loss` class")

    def __radd__(self, other):
        """Handle ``other + self`` where ``other`` is not a Loss."""
        # Built-in ``sum()`` starts from the integer 0, so treat ``0 + loss``
        # as the identity. Any other operand goes through ``__add__``'s check.
        if isinstance(other, (int, float)) and other == 0:
            return self
        return self.__add__(other)

    def __mul__(self, value):
        """Return ``MultipliedLoss(self, value)``; raise ValueError for non-numeric ``value``."""
        if isinstance(value, (int, float)):
            return MultipliedLoss(self, value)
        raise ValueError("Loss can only be multiplied by a number (`int` or `float`)")

    def __rmul__(self, other):
        """Handle ``other * self``; scaling is commutative."""
        return self.__mul__(other)
class SumOfLosses(Loss):
    """Loss whose value is the sum of two component losses.

    Args:
        l1: First component loss.
        l2: Second component loss.

    The name is ``"<l1 name> + <l2 name>"``.
    """

    def __init__(self, l1, l2):
        name = "{} + {}".format(l1.__name__, l2.__name__)
        super().__init__(name=name)
        self.l1 = l1
        self.l2 = l2

    def forward(self, *inputs):
        """Evaluate both components on ``inputs`` and return their sum."""
        # Composite losses evaluate their operands through ``.forward``.
        # Defining ``forward`` here, rather than only overriding ``__call__``,
        # lets a SumOfLosses be nested in another composite, e.g. ``(a + b) + c``.
        # Without it, ``.forward`` would resolve to ``nn.Module``'s default,
        # which raises NotImplementedError.
        return self.l1.forward(*inputs) + self.l2.forward(*inputs)

    def __call__(self, *inputs):
        # Preserve the original dispatch: evaluate ``forward`` directly,
        # bypassing ``nn.Module.__call__``.
        return self.forward(*inputs)
class MultipliedLoss(Loss):
    """Loss scaled by a constant factor.

    Args:
        loss: The loss to scale.
        multiplier: Factor applied to ``loss``'s value.

    The name is ``"<multiplier> * <loss name>"``. The loss name is
    parenthesised when it contains ``+``.
    """

    def __init__(self, loss, multiplier):
        # Parenthesise summed names so the display name reflects precedence,
        # e.g. "2 * (a + b)" rather than "2 * a + b".
        if "+" in loss.__name__:
            name = "{} * ({})".format(multiplier, loss.__name__)
        else:
            name = "{} * {}".format(multiplier, loss.__name__)
        super().__init__(name=name)
        self.loss = loss
        self.multiplier = multiplier

    def forward(self, *inputs):
        """Evaluate the wrapped loss on ``inputs`` and scale it by ``multiplier``."""
        # Composite losses evaluate their operands through ``.forward``.
        # Defining ``forward`` here, rather than only overriding ``__call__``,
        # lets a MultipliedLoss be nested in another composite, e.g. ``a + 2 * b``.
        # Without it, ``.forward`` would resolve to ``nn.Module``'s default,
        # which raises NotImplementedError.
        return self.multiplier * self.loss.forward(*inputs)

    def __call__(self, *inputs):
        # Preserve the original dispatch: evaluate ``forward`` directly,
        # bypassing ``nn.Module.__call__``.
        return self.forward(*inputs)