yeswanthvarma commited on
Commit
7aeb013
·
verified ·
1 Parent(s): 1bc09e1

Update utils/xlnet_model.py

Browse files
Files changed (1) hide show
  1. utils/xlnet_model.py +3 -2
utils/xlnet_model.py CHANGED
@@ -24,13 +24,14 @@ class XLNetAnswerAssessmentModel(nn.Module):
24
  hidden = 768
25
  self.fc1 = nn.Linear(hidden, 256)
26
  self.fc2 = nn.Linear(256, 64)
27
- self.out = nn.Linear(64, 1)
28
 
29
  def forward(self, input_ids, attention_mask=None):
30
  pooled = self.xlnet(input_ids, attention_mask).last_hidden_state.mean(1)
31
  x = torch.relu(self.fc1(pooled))
32
  x = torch.relu(self.fc2(x))
33
- return torch.sigmoid(self.out(x))
 
34
 
35
  # Initialize model and tokenizer
36
  xlnet_available = False
 
24
  hidden = 768
25
  self.fc1 = nn.Linear(hidden, 256)
26
  self.fc2 = nn.Linear(256, 64)
27
+ self.output = nn.Linear(64, 1) # ← Change from `self.out` to `self.output`
28
 
29
  def forward(self, input_ids, attention_mask=None):
30
  pooled = self.xlnet(input_ids, attention_mask).last_hidden_state.mean(1)
31
  x = torch.relu(self.fc1(pooled))
32
  x = torch.relu(self.fc2(x))
33
+ return torch.sigmoid(self.output(x)) # ← And change here too
34
+
35
 
36
  # Initialize model and tokenizer
37
  xlnet_available = False