ErfanMoosaviMonazzah
commited on
Commit
•
095611c
1
Parent(s):
189fb58
Upload model
Browse files
modeling_backpack_gpt2_nli.py
CHANGED
@@ -57,7 +57,7 @@ class BackpackGPT2NLIModel(GPT2PreTrainedModel):
|
|
57 |
|
58 |
def predict(self, input_ids=None, attention_mask=None):
|
59 |
logits = self.forward(input_ids, attention_mask, labels=None)
|
60 |
-
p = torch.argmax(
|
61 |
labels = [self.config.id2label[index] for index in p]
|
62 |
return labels
|
63 |
|
|
|
57 |
|
58 |
def predict(self, input_ids=None, attention_mask=None):
|
59 |
logits = self.forward(input_ids, attention_mask, labels=None)
|
60 |
+
p = torch.argmax(logits, axis=1)
|
61 |
labels = [self.config.id2label[index] for index in p]
|
62 |
return labels
|
63 |
|