Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from transformers import pipeline
|
| 3 |
-
import streamlit as st
|
| 4 |
import json
|
| 5 |
import torch
|
| 6 |
from torch.nn import functional as F
|
|
@@ -32,18 +31,20 @@ abstract = st.text_input("Abstract", value="Random variable")
|
|
| 32 |
|
| 33 |
def get_logits(title, abstract):
|
| 34 |
text = title + "###" + abstract
|
| 35 |
-
|
|
|
|
| 36 |
return logits
|
| 37 |
|
| 38 |
def get_ans(logits):
|
| 39 |
-
|
| 40 |
-
|
| 41 |
cum_sum = 0
|
| 42 |
i = 0
|
| 43 |
-
while cum_sum < 0.95:
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
| 47 |
|
| 48 |
if title or abstract:
|
| 49 |
logits = get_logits(title, abstract)
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from transformers import pipeline
|
|
|
|
| 3 |
import json
|
| 4 |
import torch
|
| 5 |
from torch.nn import functional as F
|
|
|
|
| 31 |
|
| 32 |
def get_logits(title, abstract):
|
| 33 |
text = title + "###" + abstract
|
| 34 |
+
inputs = tokenizer(text, return_tensors="pt")
|
| 35 |
+
logits = model(**inputs)['logits']
|
| 36 |
return logits
|
| 37 |
|
| 38 |
def get_ans(logits):
|
| 39 |
+
logits = F.softmax(logits, dim=1)
|
| 40 |
+
ind = torch.argsort(logits, dim=1, descending=True).flatten()
|
| 41 |
cum_sum = 0
|
| 42 |
i = 0
|
| 43 |
+
while cum_sum < 0.95 and i < len(ind):
|
| 44 |
+
idx = ind[i].item()
|
| 45 |
+
cum_sum += logits[0][idx].item()
|
| 46 |
+
st.write(f"label: {ind2label.get(str(idx))} with probability: {logits[0][idx].item() * 100:.2f}%")
|
| 47 |
+
i += 1
|
| 48 |
|
| 49 |
if title or abstract:
|
| 50 |
logits = get_logits(title, abstract)
|