import streamlit as st
import torch
import time
import pandas as pd
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
@st.cache_resource
def load_model_and_tokenizer():
model = DistilBertForSequenceClassification.from_pretrained(
"./results/checkpoint-64800"
)
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")
model.eval()
return model, tokenizer
model, tokenizer = load_model_and_tokenizer()
def classify_text(text: str) -> (str, float, dict):
"""Токенизация текста, запуск инференса, измерение времени выполнения и визуализация вероятностей."""
encoding = tokenizer(
text, return_tensors="pt", padding=True, truncation=True, max_length=128
)
start_time = time.time()
with torch.no_grad():
outputs = model(**encoding)
logits = outputs.logits
elapsed_time = time.time() - start_time
probabilities_tensor = torch.softmax(logits, dim=1).squeeze()
probabilities = probabilities_tensor.tolist()
predicted_class_id = torch.argmax(logits, dim=1).item()
id2label = model.config.id2label # Предполагается, что id2label задан при обучении.
predicted_label = id2label[predicted_class_id] if id2label else str(predicted_class_id)
if id2label:
probabilities_dict = { id2label[i]: probabilities[i] for i in range(len(probabilities)) }
else:
probabilities_dict = { str(i): probabilities[i] for i in range(len(probabilities)) }
return predicted_label, elapsed_time, probabilities_dict
st.markdown(
"""
""", unsafe_allow_html=True
)
st.markdown("
Классификация тэгов в статьях архива
", unsafe_allow_html=True)
st.markdown("Дизайн почти как от студии Артемия Лебедева
", unsafe_allow_html=True)
st.write("Это демо-приложение позволяет классифицировать текст статьи или абстракта с помощью модели DistilBERT. Был использован весь архив от 2016 года.")
st.sidebar.title("Опции")
mode = st.sidebar.radio("Выберите режим ввода:", ["Ручной ввод", "Загрузка файла", "Пример"])
st.sidebar.markdown("---")
st.sidebar.info(
"Вы можете ввести текст вручную, загрузить файл (.txt) или использовать готовый пример."
)
st.sidebar.markdown("---")
if st.sidebar.button("Очистить кэш"):
st.cache_resource.clear() # Сброс кэша (при перезапуске приложения)
st.sidebar.success("Кэш очищен!")
user_text = ""
if mode == "Ручной ввод":
title_input = st.text_area("Название (можно пропустить)", "", height=100)
abstract_input = st.text_area("Абстракт (тоже не обязателен)", "", height=150)
user_text = title_input.strip() + " " + abstract_input.strip()
if st.button("Классифицировать", key="manual"):
if user_text.strip() == "":
st.error("Пожалуйста, введите или загрузите текст для классификации.")
else:
with st.spinner("Модель обрабатывает текст..."):
predicted_label, inference_time, probabilities_dict = classify_text(user_text)
st.success(f"Предсказанный тэг: **{predicted_label}**")
st.info(f"Время инференса: {inference_time:.2f} секунд")
st.markdown("### Распределение вероятностей")
chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс")
st.bar_chart(chart_data)
elif mode == "Загрузка файла":
uploaded_file = st.file_uploader("Загрузите текстовый файл (.txt)", type=["txt"])
if uploaded_file is not None:
file_text = uploaded_file.read().decode("utf-8")
st.text_area("Содержимое файла", file_text, height=150)
user_text = file_text
if st.button("Классифицировать", key="file"):
if user_text.strip() == "":
st.error("Пожалуйста, введите или загрузите текст для классификации.")
else:
with st.spinner("Модель обрабатывает текст..."):
predicted_label, inference_time, probabilities_dict = classify_text(user_text)
st.success(f"Предсказанный тэг: **{predicted_label}**")
st.info(f"Время инференса: {inference_time:.2f} секунд")
st.markdown("### Распределение вероятностей")
chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс")
st.bar_chart(chart_data)
elif mode == "Пример":
sample_text = (
"Non-invertible symmetries of two-dimensional Non-Linear Sigma Models"
)
st.text_area("Пример текста", sample_text, height=150)
if st.button("Использовать пример", key="sample"):
user_text = sample_text
with st.spinner("Модель обрабатывает текст..."):
predicted_label, inference_time, probabilities_dict = classify_text(user_text)
st.success(f"Предсказанный тэг: **{predicted_label}**")
st.info(f"Время инференса: {inference_time:.2f} секунд")
st.markdown("### Распределение вероятностей")
chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс")
st.bar_chart(chart_data)
st.markdown("---")
st.markdown("### О модели")
st.write(
"Модель использует **DistilBERT** для классификации текста. "
"Данный демо-проект вдохновлён эстетикой студии Артемия Лебедева."
)
st.markdown("### Об авторе: физик-теоретик")
st.write(
"Этот проект был создан для 4 лабораторки."
)
st.sidebar.markdown("### Контакты")
st.sidebar.write("Есть вопросы или предложения? Пишите на fedor.popov@phystech.edu")