article_tagger / app.py
minority169's picture
Update app.py
8f704b9 verified
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
@st.cache_resource
def load_model():
model_path = "./models"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
return model, tokenizer
def get_top_k_predictions(probs, k=0.95):
sorted_indices = np.argsort(probs)[::-1]
cumulative_prob = 0
selected_indices = []
for idx in sorted_indices:
cumulative_prob += probs[idx]
selected_indices.append(idx)
if cumulative_prob >= k:
break
return selected_indices
def main():
st.title("Классификатор научных статей")
st.write("""
Введите название статьи и её аннотацию, чтобы определить тематику.
Если аннотация не указана, классификация будет выполнена только по названию.
""")
model, tokenizer = load_model()
title = st.text_input("Название статьи:", "")
abstract = st.text_area("Аннотация:", "")
if st.button("Классифицировать"):
if not title and not abstract:
st.error("Нужно хотя бы название статьи")
return
text = title
if abstract:
text += " " + abstract
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
outputs = model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].numpy()
top_indices = get_top_k_predictions(probs)
st.subheader("Результаты классификации:")
for idx in top_indices:
label = model.config.id2label[idx]
prob = probs[idx] * 100
st.write(f"- {label}: {prob:.2f}%")
if __name__ == "__main__":
main()