ppapenj commited on
Commit
e526340
·
verified ·
1 Parent(s): 489c19d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
- import streamlit as st
4
  import json
5
  import torch
6
  from torch.nn import functional as F
@@ -32,18 +31,20 @@ abstract = st.text_input("Abstract", value="Random variable")
32
 
33
  def get_logits(title, abstract):
34
  text = title + "###" + abstract
35
- logits = model(tokenizer(text, return_tensors="pt")['input_ids'])['logits']
 
36
  return logits
37
 
38
  def get_ans(logits):
39
- ind = torch.argsort(logits, dim=1, descending=True)
40
- logits = F.softmax(logits)
41
  cum_sum = 0
42
  i = 0
43
- while cum_sum < 0.95:
44
- cum_sum += logits[0][ind[i]]
45
- st.write(f"label {ind2label[ind[i]]} with probability {logits[0][ind[i]] * 100}%")
46
- i +=1
 
47
 
48
  if title or abstract:
49
  logits = get_logits(title, abstract)
 
1
  import streamlit as st
2
  from transformers import pipeline
 
3
  import json
4
  import torch
5
  from torch.nn import functional as F
 
31
 
32
  def get_logits(title, abstract):
33
  text = title + "###" + abstract
34
+ inputs = tokenizer(text, return_tensors="pt")
35
+ logits = model(**inputs)['logits']
36
  return logits
37
 
38
  def get_ans(logits):
39
+ logits = F.softmax(logits, dim=1)
40
+ ind = torch.argsort(logits, dim=1, descending=True).flatten()
41
  cum_sum = 0
42
  i = 0
43
+ while cum_sum < 0.95 and i < len(ind):
44
+ idx = ind[i].item()
45
+ cum_sum += logits[0][idx].item()
46
+ st.write(f"label: {ind2label.get(str(idx))} with probability: {logits[0][idx].item() * 100:.2f}%")
47
+ i += 1
48
 
49
  if title or abstract:
50
  logits = get_logits(title, abstract)