import torch
from torch import nn
import torch.nn.functional as F


class CustomLinear(nn.Linear):
    """Linear layer that records extra construction arguments.

    `init_eye_val` and `is_diagonal` are stored as attributes; they are not
    used inside this class (e.g., they may be read by an external
    initialization routine).
    """

    def __init__(self, *args, init_eye_val=0.0, is_diagonal=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_eye_val = init_eye_val
        self.is_diagonal = is_diagonal


class CustomLinearInitialized(nn.Linear):
    """Linear layer that records an `init_fun` callable for later use."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None, init_fun=None) -> None:
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        self.init_fun = init_fun


class CustomDiagonalLinear(nn.Module):
    """Diagonal linear layer: per-feature scale plus optional bias.

    Every weight entry is initialized to `init_eye_val`; the bias (if any)
    is initialized to zero.
    """

    def __init__(self, d_model, bias=True, init_eye_val=0.0):
        super().__init__()
        self.init_eye_val = init_eye_val
        self.weight = nn.Parameter(torch.full((d_model,), init_eye_val))
        self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None

    def forward(self, input):
        out = input * self.weight
        if self.bias is not None:
            out = out + self.bias
        return out


class Gate(nn.Module):
    """Learnable per-item gate, applied along one dimension of a 4D tensor."""

    def __init__(self, items, init_val=0.0):
        super().__init__()
        self.init_val = init_val
        self.gate = nn.Parameter(torch.full((items,), init_val))

    def forward(self, input, dim):
        if input.ndim != 4:
            raise ValueError('input must be a 4D tensor')
        # Reshape the gate so it broadcasts along `dim` of the 4D input.
        shape = [1] * 4
        shape[dim] = -1
        return input * self.gate.view(*shape)


class AttentivePoolingClassifier(nn.Module):
    def __init__(self, d_model, num_classes, hidden_dim=128):
        """
        Attentive Pooling Classifier

        Args:
            d_model: Input feature dimension (D)
            num_classes: Number of output classes (V)
            hidden_dim: Hidden dimension for attention mechanism
        """
        super().__init__()

        # Attention mechanism for pooling [B, T, D] -> [B, D]
        self.attention_projection = nn.Linear(d_model, hidden_dim)
        self.attention_weights = nn.Linear(hidden_dim, 1)

        # Classifier [B, D] -> [B, V]
        self.classifier = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x, apply_stop_gradient=True):
        """
        Forward pass

        Args:
            x: Input tensor of shape [B, T, D]
            apply_stop_gradient: Whether to detach the input so gradients do
                not flow back into the feature extractor

        Returns:
            logits: Output logits [B, V]
            attention_weights: Attention weights [B, T]
            pooled_features: Pooled features [B, D]
        """
        # Apply stop gradient if specified
        if apply_stop_gradient:
            x = x.detach()

        # Compute attention weights
        # x: [B, T, D] -> [B, T, hidden_dim]
        att_proj = torch.tanh(self.attention_projection(x))
        # att_proj: [B, T, hidden_dim] -> [B, T, 1] -> [B, T]
        attention_scores = self.attention_weights(att_proj).squeeze(-1)
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Apply attentive pooling: [B, T, D] * [B, T, 1] -> [B, D]
        pooled_features = torch.sum(x * attention_weights.unsqueeze(-1), dim=1)

        # Classification
        logits = self.classifier(pooled_features)

        return logits, attention_weights, pooled_features
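

# Minimal usage sketch (not part of the original module): the shapes, sizes,
# and the `__main__` guard below are illustrative assumptions showing how the
# layers above compose on dummy data.
if __name__ == "__main__":
    B, T, D, V = 2, 10, 64, 5  # hypothetical batch, sequence, feature, class sizes
    x = torch.randn(B, T, D)

    # Per-feature scale (weights initialized to 1.0 here) plus bias.
    diag = CustomDiagonalLinear(D, bias=True, init_eye_val=1.0)
    y = diag(x)  # [B, T, D]

    # Gate a 4D tensor along its head dimension (dim=1).
    heads = 4
    gate = Gate(items=heads, init_val=1.0)
    z = gate(torch.randn(B, heads, T, D // heads), dim=1)  # [B, heads, T, D // heads]

    # Attentive pooling classifier: [B, T, D] -> logits [B, V].
    clf = AttentivePoolingClassifier(d_model=D, num_classes=V, hidden_dim=128)
    logits, att, pooled = clf(y, apply_stop_gradient=True)
    print(logits.shape, att.shape, pooled.shape)  # [2, 5], [2, 10], [2, 64]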