import torch
import torch.nn as nn
import torch.nn.functional as F


class KLDivLoss(nn.Module):
    """Kullback-Leibler divergence loss; a thin module wrapper around ``F.kl_div``.

    ``F.kl_div`` expects ``inputs`` to already be **log-probabilities**
    (e.g. the output of ``F.log_softmax``), not probabilities or raw logits.
    ``targets`` are probabilities by default, or log-probabilities when
    ``log_target`` is True.

    Args:
        reduction: Reduction mode forwarded verbatim to ``F.kl_div``
            (e.g. 'none', 'batchmean', 'sum', 'mean'). Defaults to
            'batchmean'.
        log_target: If True, ``targets`` are interpreted as log-probabilities.
            Forwarded verbatim to ``F.kl_div``. Defaults to False, which
            preserves the previous behaviour of this class.
    """

    def __init__(self, reduction: str = 'batchmean', *, log_target: bool = False) -> None:
        super().__init__()
        self.reduction = reduction
        self.log_target = log_target

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the KL divergence of ``targets`` from ``exp(inputs)``.

        Args:
            inputs: Log-probabilities of the predicted distribution.
            targets: Target distribution (log space iff ``self.log_target``).

        Returns:
            The loss tensor produced by ``F.kl_div`` under ``self.reduction``.
        """
        return F.kl_div(
            inputs,
            targets,
            reduction=self.reduction,
            log_target=self.log_target,
        )