import streamlit as st import torch import time import pandas as pd from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast @st.cache_resource def load_model_and_tokenizer(): model = DistilBertForSequenceClassification.from_pretrained( "./results/checkpoint-64800" ) tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased") model.eval() return model, tokenizer model, tokenizer = load_model_and_tokenizer() def classify_text(text: str) -> (str, float, dict): """Токенизация текста, запуск инференса, измерение времени выполнения и визуализация вероятностей.""" encoding = tokenizer( text, return_tensors="pt", padding=True, truncation=True, max_length=128 ) start_time = time.time() with torch.no_grad(): outputs = model(**encoding) logits = outputs.logits elapsed_time = time.time() - start_time probabilities_tensor = torch.softmax(logits, dim=1).squeeze() probabilities = probabilities_tensor.tolist() predicted_class_id = torch.argmax(logits, dim=1).item() id2label = model.config.id2label # Предполагается, что id2label задан при обучении. predicted_label = id2label[predicted_class_id] if id2label else str(predicted_class_id) if id2label: probabilities_dict = { id2label[i]: probabilities[i] for i in range(len(probabilities)) } else: probabilities_dict = { str(i): probabilities[i] for i in range(len(probabilities)) } return predicted_label, elapsed_time, probabilities_dict st.markdown( """ """, unsafe_allow_html=True ) st.markdown("

Классификация тэгов в статьях архива

", unsafe_allow_html=True) st.markdown("

Дизайн почти как от студии Артемия Лебедева

", unsafe_allow_html=True) st.write("Это демо-приложение позволяет классифицировать текст статьи или абстракта с помощью модели DistilBERT. Был использован весь архив от 2016 года.") st.sidebar.title("Опции") mode = st.sidebar.radio("Выберите режим ввода:", ["Ручной ввод", "Загрузка файла", "Пример"]) st.sidebar.markdown("---") st.sidebar.info( "Вы можете ввести текст вручную, загрузить файл (.txt) или использовать готовый пример." ) st.sidebar.markdown("---") if st.sidebar.button("Очистить кэш"): st.cache_resource.clear() # Сброс кэша (при перезапуске приложения) st.sidebar.success("Кэш очищен!") user_text = "" if mode == "Ручной ввод": title_input = st.text_area("Название (можно пропустить)", "", height=100) abstract_input = st.text_area("Абстракт (тоже не обязателен)", "", height=150) user_text = title_input.strip() + " " + abstract_input.strip() if st.button("Классифицировать", key="manual"): if user_text.strip() == "": st.error("Пожалуйста, введите или загрузите текст для классификации.") else: with st.spinner("Модель обрабатывает текст..."): predicted_label, inference_time, probabilities_dict = classify_text(user_text) st.success(f"Предсказанный тэг: **{predicted_label}**") st.info(f"Время инференса: {inference_time:.2f} секунд") st.markdown("### Распределение вероятностей") chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс") st.bar_chart(chart_data) elif mode == "Загрузка файла": uploaded_file = st.file_uploader("Загрузите текстовый файл (.txt)", type=["txt"]) if uploaded_file is not None: file_text = uploaded_file.read().decode("utf-8") st.text_area("Содержимое файла", file_text, height=150) user_text = file_text if st.button("Классифицировать", key="file"): if user_text.strip() == "": st.error("Пожалуйста, введите или загрузите текст для классификации.") else: with st.spinner("Модель обрабатывает текст..."): predicted_label, inference_time, probabilities_dict = classify_text(user_text) st.success(f"Предсказанный тэг: **{predicted_label}**") st.info(f"Время инференса: {inference_time:.2f} секунд") st.markdown("### Распределение вероятностей") chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс") st.bar_chart(chart_data) elif mode == "Пример": sample_text = ( "Non-invertible symmetries of two-dimensional Non-Linear Sigma Models" ) st.text_area("Пример текста", sample_text, height=150) if st.button("Использовать пример", key="sample"): user_text = sample_text with st.spinner("Модель обрабатывает текст..."): predicted_label, inference_time, probabilities_dict = classify_text(user_text) st.success(f"Предсказанный тэг: **{predicted_label}**") st.info(f"Время инференса: {inference_time:.2f} секунд") st.markdown("### Распределение вероятностей") chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс") st.bar_chart(chart_data) st.markdown("---") st.markdown("### О модели") st.write( "Модель использует **DistilBERT** для классификации текста. " "Данный демо-проект вдохновлён эстетикой студии Артемия Лебедева." ) st.markdown("### Об авторе: физик-теоретик") st.write( "Этот проект был создан для 4 лабораторки." ) st.sidebar.markdown("### Контакты") st.sidebar.write("Есть вопросы или предложения? Пишите на fedor.popov@phystech.edu")