import math

import torch
import torch.nn.functional as F
from torch import nn


class SentimentWeightedLoss(nn.Module):
    """BCEWithLogits + dynamic per-sample weighting.

    Each sample is weighted by:
      • length_weight: sqrt(num_tokens) / sqrt(max_tokens in batch), clamped to [0.1, 1.0]
      • confidence_weight: |sigmoid(logit) - 0.5| * 2 (higher confidence ⇒ larger weight)

    The two weights are combined multiplicatively, then normalized to have mean ≈ 1.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.min_len_weight_sqrt = 0.1  # floor for the length weight

    def forward(self, logits, targets, lengths):
        # Empty batches produce a zero loss that still participates in autograd.
        if lengths.numel() == 0:
            return torch.tensor(0.0, device=logits.device, requires_grad=logits.requires_grad)

        base_loss = self.bce(logits.view(-1), targets.view(-1).float())

        # Confidence weight: 0 at p = 0.5, 1 at p = 0 or p = 1.
        prob = torch.sigmoid(logits.view(-1))
        confidence_weight = (prob - 0.5).abs() * 2

        # Length weight: sqrt scaling relative to the longest review in the batch.
        length_weight = torch.sqrt(lengths.float()) / math.sqrt(lengths.max().item())
        length_weight = length_weight.clamp(self.min_len_weight_sqrt, 1.0)

        # Combine multiplicatively, then normalize so the loss scale stays stable.
        weights = confidence_weight * length_weight
        weights = weights / (weights.mean() + 1e-8)
        return (base_loss * weights).mean()
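
# A minimal usage sketch (not part of the original module): the batch size, logits,
# labels, and token counts below are illustrative assumptions, shown only to make the
# expected call signature (logits [B], binary targets [B], lengths [B]) concrete.
def _demo_sentiment_weighted_loss() -> torch.Tensor:
    torch.manual_seed(0)
    logits = torch.randn(4)                       # raw model outputs
    targets = torch.tensor([1.0, 0.0, 1.0, 0.0])  # binary sentiment labels
    lengths = torch.tensor([12, 87, 250, 5])      # review lengths in tokens
    return SentimentWeightedLoss()(logits, targets, lengths)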


class SentimentFocalLoss(nn.Module):
    """
    This loss function incorporates:
        1. Base BCEWithLogitsLoss.
        2. Label smoothing.
        3. Focal Loss modulation to focus more on hard examples (can be reversed to
           focus on easy examples).
        4. Sample weighting based on review length.
        5. Sample weighting based on prediction confidence.

    The final loss for each sample is calculated roughly as:
        Loss_sample = FocalModulator(pt, gamma) * BCE(logits, smoothed_targets) * NormalizedExternalWeight
        NormalizedExternalWeight = (ConfidenceWeight * LengthWeight) / Mean(ConfidenceWeight * LengthWeight)
    """

    def __init__(self, gamma_focal: float = 0.1, label_smoothing_epsilon: float = 0.05):
        """
        Args:
            gamma_focal (float): Gamma parameter for Focal Loss.
                - If gamma_focal > 0 (e.g., 2.0), applies standard Focal Loss,
                  down-weighting easy examples (focus on hard examples).
                - If gamma_focal < 0 (e.g., -2.0), applies a reversed Focal Loss,
                  down-weighting hard examples (focus on easy examples by up-weighting pt).
                - If gamma_focal = 0, no Focal Loss modulation is applied.
            label_smoothing_epsilon (float): Epsilon for label smoothing (0.0 <= epsilon < 1.0).
                Converts hard labels (0, 1) to soft labels (epsilon, 1 - epsilon);
                if 0.0, no label smoothing is applied.
        """
        super().__init__()
        if not (0.0 <= label_smoothing_epsilon < 1.0):
            raise ValueError("label_smoothing_epsilon must be in [0.0, 1.0).")

        self.gamma_focal = gamma_focal
        self.label_smoothing_epsilon = label_smoothing_epsilon

        self.bce_loss_no_reduction = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, logits: torch.Tensor, targets: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Computes the custom loss.

        Args:
            logits (torch.Tensor): Raw logits from the model. Expected shape [B] or [B, 1].
            targets (torch.Tensor): Ground truth labels (0 or 1). Expected shape [B] or [B, 1].
            lengths (torch.Tensor): Number of tokens in each review. Expected shape [B].

        Returns:
            torch.Tensor: The computed scalar loss.
        """
        B = logits.size(0)
        if B == 0:
            return torch.tensor(0.0, device=logits.device, requires_grad=True)

        logits_flat = logits.view(-1)
        original_targets_flat = targets.view(-1).float()

        # Label smoothing: map hard labels {0, 1} to {epsilon, 1 - epsilon}.
        if self.label_smoothing_epsilon > 0:
            targets_for_bce = (
                original_targets_flat * (1.0 - self.label_smoothing_epsilon)
                + (1.0 - original_targets_flat) * self.label_smoothing_epsilon
            )
        else:
            targets_for_bce = original_targets_flat

        # Per-sample BCE against the (possibly smoothed) targets.
        base_bce_loss_terms = self.bce_loss_no_reduction(logits_flat, targets_for_bce)

        # Focal modulation. pt is the predicted probability of the *original* class.
        probs = torch.sigmoid(logits_flat)
        pt = torch.where(original_targets_flat.bool(), probs, 1.0 - probs)

        focal_modulator = torch.ones_like(pt)
        if self.gamma_focal > 0:
            # Standard focal loss: down-weight easy examples (high pt).
            focal_modulator = (1.0 - pt + 1e-8).pow(self.gamma_focal)
        elif self.gamma_focal < 0:
            # Reversed focal loss: down-weight hard examples (low pt).
            focal_modulator = (pt + 1e-8).pow(abs(self.gamma_focal))

        modulated_loss_terms = focal_modulator * base_bce_loss_terms

        # Confidence weight: 0 at p = 0.5, 1 at p = 0 or p = 1.
        confidence_w = (probs - 0.5).abs() * 2.0

        # Length weight: sqrt scaling relative to the longest review in the batch.
        lengths_flat = lengths.view(-1).float()
        max_len_in_batch = lengths_flat.max().item()

        if max_len_in_batch == 0:
            length_w = torch.ones_like(lengths_flat)
        else:
            length_w = torch.sqrt(lengths_flat) / (math.sqrt(max_len_in_batch) + 1e-8)
            length_w = torch.clamp(length_w, 0.0, 1.0)

        # Combine the external weights and normalize them to mean ≈ 1 so the overall
        # loss scale stays comparable across batches.
        external_weights = confidence_w * length_w

        if external_weights.sum() > 1e-8:
            normalized_external_weights = external_weights / (external_weights.mean() + 1e-8)
        else:
            # Degenerate case (all weights ~0): fall back to uniform weighting.
            normalized_external_weights = torch.ones_like(external_weights)

        final_loss_terms_per_sample = modulated_loss_terms * normalized_external_weights

        loss = final_loss_terms_per_sample.mean()

        return loss
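

# A minimal smoke-test sketch (not part of the original module); the batch size,
# random inputs, and gamma/epsilon values are illustrative assumptions. Run as a
# script to compare the plain weighted loss with the focal variant: gamma_focal > 0
# should down-weight confidently classified samples relative to SentimentWeightedLoss.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(8)                               # raw model outputs for a batch of 8
    targets = (torch.rand(8) > 0.5).float()               # random binary sentiment labels
    lengths = torch.randint(low=1, high=300, size=(8,))   # review lengths in tokens

    weighted = SentimentWeightedLoss()(logits, targets, lengths)
    focal = SentimentFocalLoss(gamma_focal=2.0, label_smoothing_epsilon=0.05)(logits, targets, lengths)
    print(f"weighted: {weighted.item():.4f}  focal: {focal.item():.4f}")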