|
import streamlit as st |
|
import torch |
|
import numpy as np |
|
|
|
@st.cache(allow_output_mutation=True)
def Model():
    """Build the DeBERTa classifier and its tokenizer.

    The base "microsoft/deberta-base" checkpoint is instantiated with an
    8-label classification head, and the weights stored in
    ``model_weights.pt`` are then loaded into it on the CPU.

    The result is cached by Streamlit so the loading work is not repeated
    on every script rerun.
    NOTE(review): ``st.cache`` is deprecated in recent Streamlit releases
    in favour of ``st.cache_resource`` -- confirm the targeted version.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Imported lazily so the import cost is only paid when the (cached)
    # loader actually runs.
    from transformers import DebertaForSequenceClassification, DebertaTokenizer

    tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    classifier = DebertaForSequenceClassification.from_pretrained(
        "microsoft/deberta-base",
        num_labels=8,
    )

    # map_location keeps the checkpoint loadable on machines without a GPU.
    cpu = torch.device('cpu')
    weights = torch.load('model_weights.pt', map_location=cpu)
    classifier.load_state_dict(weights)

    return classifier, tokenizer
|
|
|
def Predict(model, tokenizer, text):
    """Return the class probabilities the model assigns to ``text``.

    Args:
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object with a ``logits`` tensor.
        tokenizer: Callable that turns ``text`` into a mapping of PyTorch
            tensors (called with ``return_tensors="pt"``).
        text: The text to classify; it is truncated to 512 tokens.

    Returns:
        A 1-D NumPy array with the softmax probabilities of the first
        (index 0) row of the model output.
    """
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )

    # Pure inference: disabling autograd avoids building a gradient graph
    # on every prediction, which saves memory and time.
    with torch.no_grad():
        output = model(**encoded)
        # Softmax over the class dimension turns raw logits into
        # probabilities.
        probabilities = output.logits.softmax(dim=1)

    return probabilities.numpy()[0]
|
|
|
def Print(logits, dictionary, threshold=0.95):
    """Render the most probable classes as a numbered Streamlit list.

    Classes are listed in order of decreasing score until the scores of
    the classes already listed add up to at least ``threshold``.

    Args:
        logits: Sequence of per-class scores. Despite the name, the
            caller passes softmax probabilities.
        dictionary: Class names, looked up by the position of the score
            in ``logits``.
        threshold: Cumulative score at which the listing stops.
    """
    # Pair every score with its class index and rank by score. Using
    # enumerate() keeps this independent of the number of classes.
    ranked = sorted(enumerate(logits), key=lambda pair: pair[1], reverse=True)

    cumulative = 0
    # Iterating over ``ranked`` bounds the loop, so scores that never reach
    # the threshold cannot run past the end of the list.
    for position, (class_index, score) in enumerate(ranked, start=1):
        if cumulative >= threshold:
            break
        st.markdown(f"{position}. {dictionary[class_index]}")
        cumulative += score
|
|
|
def filter(title, abstract):
    """Check that the entered title and abstract are long enough.

    The title needs at least 10 characters. The abstract may be left
    empty, but when it is present it needs at least 20 characters. If a
    check fails, a warning is rendered through Streamlit.

    Returns:
        True when the input passes both checks, False otherwise.
    """
    if len(title) >= 10 and (len(abstract) == 0 or len(abstract) >= 20):
        return True

    st.markdown("Хммм... Вы точно не ошиблись? Слишком маленькое название или описание.")
    return False
|
|
|
st.title('Классификация статьи по названию и описанию')

# Two free-form inputs: the article title and its abstract.
title = st.text_area("Введите название статьи:")
abstract = st.text_area("Введите описание статьи:")

# Class names; Print() looks them up by position in the score vector.
dictionary = [
    'computer science',
    'economics',
    'Electrical Engineering and Systems Science',
    'math',
    'physics',
    'quantitative biology',
    'quantitative finance',
    'statistics',
]

# The classifier receives the title and the abstract as a single string.
text = f"{title}. {abstract}"

# Only run the model when the input passes the length checks.
if filter(title, abstract):
    model, tokenizer = Model()
    probabilities = Predict(model, tokenizer, text)
    st.header("Топ 95%:")
    Print(probabilities, dictionary)
|
|
|
|
|
|