Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
3 |
-
import streamlit as st
|
4 |
import json
|
5 |
import torch
|
6 |
from torch.nn import functional as F
|
@@ -32,18 +31,20 @@ abstract = st.text_input("Abstract", value="Random variable")
|
|
32 |
|
33 |
def get_logits(title, abstract):
|
34 |
text = title + "###" + abstract
|
35 |
-
|
|
|
36 |
return logits
|
37 |
|
38 |
def get_ans(logits):
|
39 |
-
|
40 |
-
|
41 |
cum_sum = 0
|
42 |
i = 0
|
43 |
-
while cum_sum < 0.95:
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
47 |
|
48 |
if title or abstract:
|
49 |
logits = get_logits(title, abstract)
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
|
|
3 |
import json
|
4 |
import torch
|
5 |
from torch.nn import functional as F
|
|
|
31 |
|
32 |
def get_logits(title, abstract):
|
33 |
text = title + "###" + abstract
|
34 |
+
inputs = tokenizer(text, return_tensors="pt")
|
35 |
+
logits = model(**inputs)['logits']
|
36 |
return logits
|
37 |
|
38 |
def get_ans(logits):
|
39 |
+
logits = F.softmax(logits, dim=1)
|
40 |
+
ind = torch.argsort(logits, dim=1, descending=True).flatten()
|
41 |
cum_sum = 0
|
42 |
i = 0
|
43 |
+
while cum_sum < 0.95 and i < len(ind):
|
44 |
+
idx = ind[i].item()
|
45 |
+
cum_sum += logits[0][idx].item()
|
46 |
+
st.write(f"label: {ind2label.get(str(idx))} with probability: {logits[0][idx].item() * 100:.2f}%")
|
47 |
+
i += 1
|
48 |
|
49 |
if title or abstract:
|
50 |
logits = get_logits(title, abstract)
|