Update app.py
app.py CHANGED
@@ -36,29 +36,7 @@ def get_top95(y_predict, convert_target):
 # Creating the customized model, by adding a drop out and a dense layer on top of distil bert to get the final output for the model.
 from transformers import DistilBertModel, DistilBertTokenizer
 
-class DistillBERTClass(torch.nn.Module):
-    def __init__(self):
-        super(DistillBERTClass, self).__init__()
-        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
-        self.pre_classifier = torch.nn.Linear(768, 768)
-        self.dropout = torch.nn.Dropout(0.3)
-        self.classifier = torch.nn.Linear(768, 8)
 
-    def forward(self, input_ids, attention_mask):
-        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
-        hidden_state = output_1[0]
-        pooler = hidden_state[:, 0]
-        pooler = self.pre_classifier(pooler)
-        pooler = torch.nn.ReLU()(pooler)
-        pooler = self.dropout(pooler)
-        output = self.classifier(pooler)
-        return output
-
-
-model = DistillBERTClass()
-LEARNING_RATE = 1e-05
-
-optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
 model = torch.load("pytorch_distilbert_news (3).bin", map_location=torch.device('cpu'))
 # model.load_state_dict(checkpoint['model'])
 # optimizer.load_state_dict(checkpoint['opt'])
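This hunk drops the in-app model definition (DistilBERT with a 768-to-768 dense layer, 0.3 dropout, and an 8-way classifier head) along with the optimizer setup, and relies on `torch.load` to unpickle the complete module from `pytorch_distilbert_news (3).bin`. One caveat: unpickling a whole module generally requires the original class definition to be importable at load time, so deleting `DistillBERTClass` can break that load. A minimal sketch of the more portable state_dict pattern, assuming (this commit does not confirm it) that the .bin file holds only weights:

# Sketch only: assumes "pytorch_distilbert_news (3).bin" stores a state_dict,
# not a pickled module. The class mirrors the one removed in this commit.
import torch
from transformers import DistilBertModel

class DistillBERTClass(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(768, 768)  # dense layer over the [CLS] state
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)        # 8 target classes

    def forward(self, input_ids, attention_mask):
        hidden_state = self.l1(input_ids=input_ids, attention_mask=attention_mask)[0]
        pooled = torch.nn.ReLU()(self.pre_classifier(hidden_state[:, 0]))  # [CLS] token
        return self.classifier(self.dropout(pooled))

model = DistillBERTClass()
model.load_state_dict(torch.load("pytorch_distilbert_news (3).bin",
                                 map_location=torch.device("cpu")))
model.eval()  # inference mode: disables dropout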
@@ -90,6 +68,7 @@ def get_predict(title, abstract):
         attention_mask=inputs['attention_mask'],
     )
     logits = outputs[0]
+    print(logits)
     y_predict = torch.nn.functional.softmax(logits).cpu().detach().numpy()
     file_path = "sample.json"
 
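The second hunk adds a `print(logits)` debug line before the softmax. Separately, calling `torch.nn.functional.softmax(logits)` with no `dim` argument triggers PyTorch's implicit-dimension deprecation warning; a sketch with the axis spelled out, assuming `logits` has shape `[batch_size, num_classes]`:

# Sketch: explicit softmax axis; dim=1 assumes logits is [batch_size, num_classes].
probs = torch.nn.functional.softmax(logits, dim=1)
y_predict = probs.cpu().detach().numpy()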