Lakoc committed
Commit 1f6e305 · verified · 1 Parent(s): 9151753

Update layers.py

Files changed (1):
  layers.py +0 -58
layers.py CHANGED
@@ -39,61 +39,3 @@ class Gate(nn.Module):
         shape = [1] * 4
         shape[dim] = -1
         return input * self.gate.view(*shape)
-
-
-class AttentivePoolingClassifier(nn.Module):
-    def __init__(self, d_model, num_classes, hidden_dim=128):
-        """
-        Attentive Pooling Classifier
-
-        Args:
-            d_model: Input feature dimension (D)
-            num_classes: Number of output classes (V)
-            hidden_dim: Hidden dimension for attention mechanism
-        """
-        super(AttentivePoolingClassifier, self).__init__()
-
-        # Attention mechanism for pooling [B,T,D] -> [B,D]
-        self.attention_projection = nn.Linear(d_model, hidden_dim)
-        self.attention_weights = nn.Linear(hidden_dim, 1)
-
-        # Classifier [B,D] -> [B,V]
-        self.classifier = nn.Sequential(
-            nn.Linear(d_model, hidden_dim),
-            nn.ReLU(),
-            nn.Dropout(0.1),
-            nn.Linear(hidden_dim, num_classes)
-        )
-
-    def forward(self, x, apply_stop_gradient=True):
-        """
-        Forward pass
-
-        Args:
-            x: Input tensor of shape [B, T, D]
-            apply_stop_gradient: Whether to apply stop gradient
-
-        Returns:
-            logits: Output logits [B, V]
-            attention_weights: Attention weights [B, T]
-            pooled_features: Pooled features [B, D]
-        """
-        # Apply stop gradient if specified
-        if apply_stop_gradient:
-            x = x.detach()
-
-        # Compute attention weights
-        # x: [B, T, D] -> [B, T, hidden_dim]
-        att_proj = torch.tanh(self.attention_projection(x))
-
-        # att_proj: [B, T, hidden_dim] -> [B, T, 1] -> [B, T]
-        attention_scores = self.attention_weights(att_proj).squeeze(-1)
-        attention_weights = F.softmax(attention_scores, dim=-1)
-
-        # Apply attentive pooling: [B, T, D] * [B, T, 1] -> [B, D]
-        pooled_features = torch.sum(x * attention_weights.unsqueeze(-1), dim=1)
-
-        # Classification
-        logits = self.classifier(pooled_features)
-
-        return logits
 
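For reference, a minimal sketch of exercising the removed AttentivePoolingClassifier, illustrative only: it assumes the class definition from the diff above is in scope along with standard PyTorch imports, and the shapes and hyperparameters are made-up values, not anything from the repository.

import torch
import torch.nn as nn
import torch.nn.functional as F  # the class body uses F.softmax

B, T, D, V = 2, 50, 256, 10               # batch, time steps, feature dim, classes
model = AttentivePoolingClassifier(d_model=D, num_classes=V)

x = torch.randn(B, T, D)                  # e.g. frame-level encoder features
logits = model(x)                         # input is detached by default (apply_stop_gradient=True)
print(logits.shape)                       # torch.Size([2, 10])

Note that the forward pass as removed returns only logits, even though its docstring also lists attention_weights and pooled_features.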