Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
import time | |
import pandas as pd | |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast | |
def load_model_and_tokenizer(): | |
model = DistilBertForSequenceClassification.from_pretrained( | |
"./results/checkpoint-64800" | |
) | |
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased") | |
model.eval() | |
return model, tokenizer | |
model, tokenizer = load_model_and_tokenizer() | |
def classify_text(text: str) -> (str, float, dict): | |
"""Токенизация текста, запуск инференса, измерение времени выполнения и визуализация вероятностей.""" | |
encoding = tokenizer( | |
text, return_tensors="pt", padding=True, truncation=True, max_length=128 | |
) | |
start_time = time.time() | |
with torch.no_grad(): | |
outputs = model(**encoding) | |
logits = outputs.logits | |
elapsed_time = time.time() - start_time | |
probabilities_tensor = torch.softmax(logits, dim=1).squeeze() | |
probabilities = probabilities_tensor.tolist() | |
predicted_class_id = torch.argmax(logits, dim=1).item() | |
id2label = model.config.id2label # Предполагается, что id2label задан при обучении. | |
predicted_label = id2label[predicted_class_id] if id2label else str(predicted_class_id) | |
if id2label: | |
probabilities_dict = { id2label[i]: probabilities[i] for i in range(len(probabilities)) } | |
else: | |
probabilities_dict = { str(i): probabilities[i] for i in range(len(probabilities)) } | |
return predicted_label, elapsed_time, probabilities_dict | |
st.markdown( | |
""" | |
<style> | |
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;700&display=swap'); | |
body { | |
background-color: #1a1a1a; | |
color: #f5f5f5; | |
font-family: 'Montserrat', sans-serif; | |
} | |
h1 { | |
text-align: center; | |
font-size: 3em; | |
font-weight: 700; | |
margin-bottom: 0.2em; | |
} | |
h2 { | |
text-align: center; | |
font-size: 1.5em; | |
font-weight: 400; | |
margin-bottom: 1em; | |
} | |
.stButton>button { | |
background-color: #ff4500; | |
color: white; | |
font-size: 18px; | |
font-weight: bold; | |
border: none; | |
border-radius: 5px; | |
padding: 10px 20px; | |
} | |
.stTextArea>div>div>textarea { | |
background-color: #333; | |
color: #f5f5f5; | |
border: 2px solid #ff4500; | |
border-radius: 5px; | |
padding: 10px; | |
} | |
</style> | |
""", unsafe_allow_html=True | |
) | |
st.markdown("<h1>Классификация тэгов в статьях архива</h1>", unsafe_allow_html=True) | |
st.markdown("<h2>Дизайн почти как от студии Артемия Лебедева</h2>", unsafe_allow_html=True) | |
st.write("Это демо-приложение позволяет классифицировать текст статьи или абстракта с помощью модели DistilBERT. Был использован весь архив от 2016 года.") | |
st.sidebar.title("Опции") | |
mode = st.sidebar.radio("Выберите режим ввода:", ["Ручной ввод", "Загрузка файла", "Пример"]) | |
st.sidebar.markdown("---") | |
st.sidebar.info( | |
"Вы можете ввести текст вручную, загрузить файл (.txt) или использовать готовый пример." | |
) | |
st.sidebar.markdown("---") | |
if st.sidebar.button("Очистить кэш"): | |
st.cache_resource.clear() # Сброс кэша (при перезапуске приложения) | |
st.sidebar.success("Кэш очищен!") | |
user_text = "" | |
if mode == "Ручной ввод": | |
title_input = st.text_area("Название (можно пропустить)", "", height=100) | |
abstract_input = st.text_area("Абстракт (тоже не обязателен)", "", height=150) | |
user_text = title_input.strip() + " " + abstract_input.strip() | |
if st.button("Классифицировать", key="manual"): | |
if user_text.strip() == "": | |
st.error("Пожалуйста, введите или загрузите текст для классификации.") | |
else: | |
with st.spinner("Модель обрабатывает текст..."): | |
predicted_label, inference_time, probabilities_dict = classify_text(user_text) | |
st.success(f"Предсказанный тэг: **{predicted_label}**") | |
st.info(f"Время инференса: {inference_time:.2f} секунд") | |
st.markdown("### Распределение вероятностей") | |
chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс") | |
st.bar_chart(chart_data) | |
elif mode == "Загрузка файла": | |
uploaded_file = st.file_uploader("Загрузите текстовый файл (.txt)", type=["txt"]) | |
if uploaded_file is not None: | |
file_text = uploaded_file.read().decode("utf-8") | |
st.text_area("Содержимое файла", file_text, height=150) | |
user_text = file_text | |
if st.button("Классифицировать", key="file"): | |
if user_text.strip() == "": | |
st.error("Пожалуйста, введите или загрузите текст для классификации.") | |
else: | |
with st.spinner("Модель обрабатывает текст..."): | |
predicted_label, inference_time, probabilities_dict = classify_text(user_text) | |
st.success(f"Предсказанный тэг: **{predicted_label}**") | |
st.info(f"Время инференса: {inference_time:.2f} секунд") | |
st.markdown("### Распределение вероятностей") | |
chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс") | |
st.bar_chart(chart_data) | |
elif mode == "Пример": | |
sample_text = ( | |
"Non-invertible symmetries of two-dimensional Non-Linear Sigma Models" | |
) | |
st.text_area("Пример текста", sample_text, height=150) | |
if st.button("Использовать пример", key="sample"): | |
user_text = sample_text | |
with st.spinner("Модель обрабатывает текст..."): | |
predicted_label, inference_time, probabilities_dict = classify_text(user_text) | |
st.success(f"Предсказанный тэг: **{predicted_label}**") | |
st.info(f"Время инференса: {inference_time:.2f} секунд") | |
st.markdown("### Распределение вероятностей") | |
chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс") | |
st.bar_chart(chart_data) | |
st.markdown("---") | |
st.markdown("### О модели") | |
st.write( | |
"Модель использует **DistilBERT** для классификации текста. " | |
"Данный демо-проект вдохновлён эстетикой студии Артемия Лебедева." | |
) | |
st.markdown("### Об авторе: физик-теоретик") | |
st.write( | |
"Этот проект был создан для 4 лабораторки." | |
) | |
st.sidebar.markdown("### Контакты") | |
st.sidebar.write("Есть вопросы или предложения? Пишите на [email protected]") | |