# src/head.py
import torch
from torch import nn


class SimilarityHead(nn.Module):
    """Learnable affine calibration of similarity scores: ``alpha * sim + beta``.

    ``alpha`` is parameterized as ``exp(log_alpha)``, so it stays strictly
    positive no matter how ``log_alpha`` is updated; ``beta`` is an
    unconstrained bias. Both are stored as float32 ``nn.Parameter`` scalars.

    Args:
        init_alpha: Initial scale. Must be finite and strictly positive,
            because it is stored in log-space.
        init_beta: Initial bias.

    Raises:
        ValueError: If ``init_alpha`` is not finite and strictly positive
            (after conversion to float32).
    """

    def __init__(self, init_alpha: float = 10.0, init_beta: float = 0.0):
        super().__init__()
        alpha_t = torch.tensor(float(init_alpha), dtype=torch.float32)
        # log() of a non-positive or non-finite value yields NaN / +-inf, which
        # would silently poison every forward pass and gradient. Fail fast
        # instead. The check is done on the float32 value so that float32
        # overflow (e.g. 1e300 -> inf) is also rejected.
        if not (bool(torch.isfinite(alpha_t)) and alpha_t.item() > 0.0):
            raise ValueError(
                f"init_alpha must be finite and > 0, got {init_alpha!r}"
            )
        # Initialize in log-space for stability; alpha = exp(log_alpha) > 0.
        self.log_alpha = nn.Parameter(torch.log(alpha_t))
        self.beta = nn.Parameter(torch.tensor(init_beta, dtype=torch.float32))

    def forward(self, sim: torch.Tensor) -> torch.Tensor:
        """Return ``exp(log_alpha) * sim + beta`` as a flat 1-D tensor.

        Args:
            sim: Similarity scores, expected shape ``[B]`` or ``[B, 1]``.
                Any other shape is flattened to ``[sim.numel()]``.

        Returns:
            Tensor of shape ``[sim.numel()]`` (i.e. ``[B]`` for the expected
            inputs).
        """
        alpha = torch.exp(self.log_alpha)
        # reshape (rather than view) also accepts inputs whose strides cannot
        # be flattened without a copy; it still returns a view when possible.
        return alpha * sim.reshape(-1) + self.beta