Update utils/xlnet_model.py
Browse files- utils/xlnet_model.py +3 -2
utils/xlnet_model.py
CHANGED
@@ -24,13 +24,14 @@ class XLNetAnswerAssessmentModel(nn.Module):
|
|
24 |
hidden = 768
|
25 |
self.fc1 = nn.Linear(hidden, 256)
|
26 |
self.fc2 = nn.Linear(256, 64)
|
27 |
-
self.
|
28 |
|
29 |
def forward(self, input_ids, attention_mask=None):
|
30 |
pooled = self.xlnet(input_ids, attention_mask).last_hidden_state.mean(1)
|
31 |
x = torch.relu(self.fc1(pooled))
|
32 |
x = torch.relu(self.fc2(x))
|
33 |
-
return torch.sigmoid(self.
|
|
|
34 |
|
35 |
# Initialize model and tokenizer
|
36 |
xlnet_available = False
|
|
|
24 |
hidden = 768
|
25 |
self.fc1 = nn.Linear(hidden, 256)
|
26 |
self.fc2 = nn.Linear(256, 64)
|
27 |
+
self.output = nn.Linear(64, 1) # ← Change from `self.out` to `self.output`
|
28 |
|
29 |
def forward(self, input_ids, attention_mask=None):
|
30 |
pooled = self.xlnet(input_ids, attention_mask).last_hidden_state.mean(1)
|
31 |
x = torch.relu(self.fc1(pooled))
|
32 |
x = torch.relu(self.fc2(x))
|
33 |
+
return torch.sigmoid(self.output(x)) # ← And change here too
|
34 |
+
|
35 |
|
36 |
# Initialize model and tokenizer
|
37 |
xlnet_available = False
|