Spaces:
Runtime error
Runtime error
Update util.py
Browse files
util.py
CHANGED
|
@@ -3,9 +3,58 @@ import torch
|
|
| 3 |
from transformers import BertTokenizer, BertModel
|
| 4 |
from huggingface_hub import hf_hub_url, cached_download
|
| 5 |
|
| 6 |
-
def get_cls_layer():
|
| 7 |
-
config_file_url = hf_hub_url(
|
| 8 |
value = cached_download(config_file_url)
|
| 9 |
return torch.load(value, map_location=torch.device('cpu'))
|
| 10 |
|
| 11 |
-
get_cls_layer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from transformers import BertTokenizer, BertModel
|
| 4 |
from huggingface_hub import hf_hub_url, cached_download
|
| 5 |
|
| 6 |
+
def get_cls_layer(repo_id="furrutiav/beto_coherence", filename="cls_layer.torch", revision=None):
    """Download a serialized object from the Hugging Face Hub and load it on CPU.

    Args:
        repo_id: Hub repository that hosts the file.
        filename: Name of the serialized file inside the repository.
        revision: Optional branch, tag or commit hash to download from.
            ``None`` leaves the choice of revision to ``hf_hub_url``.

    Returns:
        The object that ``torch.load`` deserializes from the downloaded file,
        with tensors mapped to the CPU device.
    """
    file_url = hf_hub_url(repo_id, filename=filename, revision=revision)
    local_path = cached_download(file_url)
    # SECURITY NOTE(review): torch.load unpickles the downloaded file, which can
    # execute arbitrary code. Only point this at repositories you trust, and
    # prefer passing a commit hash as `revision` so the content cannot change.
    return torch.load(local_path, map_location=torch.device('cpu'))
|
| 10 |
|
| 11 |
+
# Hub repository and pinned commit shared by the encoder and its tokenizer.
_BETO_REPO = "furrutiav/beto_coherence"
_BETO_REVISION = "df96f50cfb1e3f7923912a25b1c3a865116fae4a"

# Object loaded from the Hub by get_cls_layer(); C1Classifier applies it to
# the encoder's first-position hidden state.
cls_layer = get_cls_layer()

# Encoder weights, pinned to a fixed commit.
beto_model = BertModel.from_pretrained(_BETO_REPO, revision=_BETO_REVISION)

# Matching tokenizer; do_lower_case=False keeps the input casing.
beto_tokenizer = BertTokenizer.from_pretrained(
    _BETO_REPO,
    revision=_BETO_REVISION,
    do_lower_case=False,
)

# Switch the encoder to evaluation mode; the return value is bound to `e`.
e = beto_model.eval()
|
| 18 |
+
|
| 19 |
+
def _fit_to_length(tokens, maxlen):
    """Return *tokens* padded with '[PAD]' to *maxlen*, or cut to *maxlen* ending in '[SEP]'."""
    if len(tokens) < maxlen:
        return tokens + ['[PAD]'] * (maxlen - len(tokens))
    # Too long (or exactly full): keep the first maxlen-1 tokens and close
    # the segment with a separator so it always ends in '[SEP]'.
    return tokens[:maxlen - 1] + ['[SEP]']


def preproccesing(Q, A, maxlen=60):
    """Tokenize a (Q, A) pair into token ids and an attention mask.

    Both inputs are converted with ``str()``, whitespace runs (including
    newlines) are collapsed to single spaces, and an empty result is replaced
    by the literal text "nan". Each text is tokenized with ``beto_tokenizer``
    and padded or truncated to exactly ``maxlen`` tokens; the two fixed-length
    segments are then concatenated, so padding for Q sits between Q and A.

    Args:
        Q: First text of the pair (any object; ``str()`` is applied).
        A: Second text of the pair (any object; ``str()`` is applied).
        maxlen: Length of each of the two segments, in tokens.

    Returns:
        A tuple ``(tokens_ids_tensor, attn_mask)`` of 1-D tensors of length
        ``2 * maxlen``. ``attn_mask`` is a long tensor holding 0 at '[PAD]'
        positions and 1 everywhere else.
    """
    # str.split() with no argument already splits on newlines, so a single
    # split/join normalizes all whitespace.
    Q = " ".join(str(Q).split()) or "nan"
    A = " ".join(str(A).split()) or "nan"

    # First segment: [CLS] Q [SEP], fixed to maxlen tokens.
    tokens1 = _fit_to_length(
        ['[CLS]'] + beto_tokenizer.tokenize(Q) + ['[SEP]'], maxlen
    )
    # Second segment: A [SEP], fixed to maxlen tokens.
    tokens2 = _fit_to_length(beto_tokenizer.tokenize(A) + ['[SEP]'], maxlen)

    tokens_ids = beto_tokenizer.convert_tokens_to_ids(tokens1 + tokens2)
    tokens_ids_tensor = torch.tensor(tokens_ids)

    # Look up the id of the same '[PAD]' token used for padding above rather
    # than hard-coding it, so the mask stays correct for any vocabulary.
    pad_id = beto_tokenizer.convert_tokens_to_ids('[PAD]')
    attn_mask = (tokens_ids_tensor != pad_id).long()
    return tokens_ids_tensor, attn_mask
|
| 45 |
+
|
| 46 |
+
def C1Classifier(Q, A, is_probs=True):
    """Score a single (Q, A) pair with the BETO encoder and the loaded head.

    The pair is encoded by ``preproccesing``, passed through ``beto_model`` as
    a batch of one, and the hidden state at position 0 (the '[CLS]' token) is
    fed to ``cls_layer``. A sigmoid is applied to the resulting logits.

    Args:
        Q: First text of the pair.
        A: Second text of the pair.
        is_probs: If true, return the sigmoid scores; otherwise return the
            index of the highest score.

    Returns:
        When ``is_probs`` is true, a NumPy array with the sigmoid scores of
        the single input pair. Otherwise the NumPy integer index of the
        largest score along dimension 1.
        NOTE(review): assumes ``cls_layer`` returns a (1, num_classes) tensor
        - confirm against the serialized head.
    """
    tokens_ids_tensor, attn_mask = preproccesing(Q, A)

    # Inference only: skip autograd bookkeeping so no graph is built or kept
    # alive for each call.
    with torch.no_grad():
        encoded = beto_model(
            tokens_ids_tensor.unsqueeze(0),
            attention_mask=attn_mask.unsqueeze(0),
        )
        cls_rep = encoded.last_hidden_state[:, 0]
        probs = torch.sigmoid(cls_layer(cls_rep))

    if is_probs:
        return probs.numpy()[0]
    # Hard prediction: index of the highest-scoring class.
    return probs.argmax(1).numpy()[0]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|