ysda_hf / app.py
ppapenj's picture
Create app.py
3d4f52d verified
raw
history blame
1.57 kB
import streamlit as st
from transformers import pipeline
import streamlit as st
import json
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
@st.cache_resource
def load_dicts():
    """Load the label <-> class-index lookup tables from JSON files.

    Both files are opened by relative path, i.e. from the current working
    directory, and the result is cached via ``st.cache_resource``.

    Note: JSON object keys are always strings, so if ``ind2label.json`` holds
    an object its keys decode as ``"0"``, ``"1"``, ... rather than ints;
    callers indexing by class id must account for that.

    Returns:
        ``(label2ind, ind2label)`` exactly as decoded by ``json.load``.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale encoding (e.g. cp1252 on Windows) to decode label names.
    with open("label2ind.json", "r", encoding="utf-8") as fh:
        label2ind = json.load(fh)
    with open("ind2label.json", "r", encoding="utf-8") as fh:
        ind2label = json.load(fh)
    return label2ind, ind2label
@st.cache_resource
def load_model():
    """Build the tokenizer and the fine-tuned sequence classifier once.

    The tokenizer comes from the PubMedBERT hub repo; the classifier weights
    come from a local training checkpoint. The classification head size is
    taken from the module-level ``label2ind`` mapping, so this must only be
    called after that mapping has been loaded.

    Returns:
        ``(tokenizer, model)`` tuple, cached via ``st.cache_resource``.
    """
    tokenizer_source = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
    checkpoint_dir = "my_model/checkpoint-23000"
    head_size = len(label2ind)

    tok = AutoTokenizer.from_pretrained(tokenizer_source)
    classifier = AutoModelForSequenceClassification.from_pretrained(
        checkpoint_dir,
        num_labels=head_size,
        problem_type="single_label_classification",
    )
    return tok, classifier
# Both loaders are wrapped in st.cache_resource, so after the first run these
# calls return the cached objects instead of re-reading files / weights.
label2ind, ind2label = load_dicts()
# load_model() reads the module-level label2ind, so it has to run second.
tokenizer, model = load_model()

# Free-text inputs; the non-empty defaults mean the page classifies something
# on first load.
title = st.text_input(label="Title", value="Math")
abstract = st.text_input(label="Abstract", value="Random variable")
def get_logits(title, abstract):
    """Run the classifier on one title/abstract pair and return raw logits.

    The two fields are joined with the literal separator ``"###"`` before
    tokenization (presumably the format used at fine-tuning time - confirm
    against the training script).

    Args:
        title: Paper title entered by the user.
        abstract: Paper abstract entered by the user.

    Returns:
        The model's logits tensor for the single input (batch size 1).
    """
    text = title + "###" + abstract
    # The encoder has a fixed number of position embeddings; without
    # truncation a long abstract overflows them and the forward pass fails.
    # tokenizer.model_max_length may be an "unbounded" sentinel for some hub
    # tokenizers, so cap it by the model's own limit.
    max_len = min(tokenizer.model_max_length, model.config.max_position_embeddings)
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_len,
    )
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs["logits"]
def get_ans(logits):
ind = torch.argsort(logits, dim=1, descending=True)
logits = F.softmax(logits)
cum_sum = 0
i = 0
while cum_sum < 0.95:
cum_sum += logits[0][ind[i]]
st.write(f"label {ind2label[ind[i]]} with probability {logits[0][ind[i]] * 100}%")
i +=1
# Only run the model when the user has supplied some non-blank text.
if title.strip() or abstract.strip():
    # Classify the current inputs and render the labels that together cover
    # most of the probability mass.
    logits = get_logits(title, abstract)
    get_ans(logits)