Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
import streamlit as st
|
2 |
import torch
|
|
|
|
|
3 |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
|
4 |
|
5 |
-
|
6 |
@st.cache_resource
|
7 |
def load_model_and_tokenizer():
|
8 |
model = DistilBertForSequenceClassification.from_pretrained(
|
@@ -14,25 +16,32 @@ def load_model_and_tokenizer():
|
|
14 |
|
15 |
model, tokenizer = load_model_and_tokenizer()
|
16 |
|
17 |
-
def classify_text(text: str) -> str:
|
18 |
-
"""Токенизация
|
19 |
encoding = tokenizer(
|
20 |
text, return_tensors="pt", padding=True, truncation=True, max_length=128
|
21 |
)
|
|
|
22 |
with torch.no_grad():
|
23 |
outputs = model(**encoding)
|
24 |
logits = outputs.logits
|
|
|
|
|
|
|
25 |
predicted_class_id = torch.argmax(logits, dim=1).item()
|
26 |
-
id2label = model.config.id2label # Предполагается, что id2label
|
27 |
predicted_label = id2label[predicted_class_id] if id2label else str(predicted_class_id)
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
# Добавление кастомного CSS для стилизации в духе Артемия Лебедева.
|
31 |
st.markdown(
|
32 |
"""
|
33 |
<style>
|
34 |
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;700&display=swap');
|
35 |
-
|
36 |
body {
|
37 |
background-color: #1a1a1a;
|
38 |
color: #f5f5f5;
|
@@ -70,19 +79,84 @@ st.markdown(
|
|
70 |
""", unsafe_allow_html=True
|
71 |
)
|
72 |
|
73 |
-
|
74 |
st.markdown("<h1>Классификация тэгов в статьях архива</h1>", unsafe_allow_html=True)
|
75 |
st.markdown("<h2>Дизайн почти как от студии Артемия Лебедева</h2>", unsafe_allow_html=True)
|
76 |
-
st.write("
|
77 |
|
78 |
-
# Текстовые поля для ввода названия и абстракта.
|
79 |
-
title = st.text_area("Название", "")
|
80 |
-
abstract = st.text_area("Абстракт", "")
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
st.success(f"Предсказанный тэг: **{predicted_label}**")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
+
import time
|
4 |
+
import pandas as pd
|
5 |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
|
6 |
|
7 |
+
|
8 |
@st.cache_resource
|
9 |
def load_model_and_tokenizer():
|
10 |
model = DistilBertForSequenceClassification.from_pretrained(
|
|
|
16 |
|
17 |
model, tokenizer = load_model_and_tokenizer()
|
18 |
|
19 |
+
def classify_text(text: str) -> (str, float, dict):
|
20 |
+
"""Токенизация текста, запуск инференса, измерение времени выполнения и визуализация вероятностей."""
|
21 |
encoding = tokenizer(
|
22 |
text, return_tensors="pt", padding=True, truncation=True, max_length=128
|
23 |
)
|
24 |
+
start_time = time.time()
|
25 |
with torch.no_grad():
|
26 |
outputs = model(**encoding)
|
27 |
logits = outputs.logits
|
28 |
+
elapsed_time = time.time() - start_time
|
29 |
+
probabilities_tensor = torch.softmax(logits, dim=1).squeeze()
|
30 |
+
probabilities = probabilities_tensor.tolist()
|
31 |
predicted_class_id = torch.argmax(logits, dim=1).item()
|
32 |
+
id2label = model.config.id2label # Предполагается, что id2label задан при обучении.
|
33 |
predicted_label = id2label[predicted_class_id] if id2label else str(predicted_class_id)
|
34 |
+
if id2label:
|
35 |
+
probabilities_dict = { id2label[i]: probabilities[i] for i in range(len(probabilities)) }
|
36 |
+
else:
|
37 |
+
probabilities_dict = { str(i): probabilities[i] for i in range(len(probabilities)) }
|
38 |
+
return predicted_label, elapsed_time, probabilities_dict
|
39 |
+
|
40 |
|
|
|
41 |
st.markdown(
|
42 |
"""
|
43 |
<style>
|
44 |
@import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;700&display=swap');
|
|
|
45 |
body {
|
46 |
background-color: #1a1a1a;
|
47 |
color: #f5f5f5;
|
|
|
79 |
""", unsafe_allow_html=True
|
80 |
)
|
81 |
|
82 |
+
|
83 |
st.markdown("<h1>Классификация тэгов в статьях архива</h1>", unsafe_allow_html=True)
|
84 |
st.markdown("<h2>Дизайн почти как от студии Артемия Лебедева</h2>", unsafe_allow_html=True)
|
85 |
+
st.write("Это демо-приложение позволяет классифицировать текст статьи или абстракта с помощью модели DistilBERT. Был использован весь архив от 2016 года.")
|
86 |
|
|
|
|
|
|
|
87 |
|
88 |
+
st.sidebar.title("Опции")
|
89 |
+
mode = st.sidebar.radio("Выберите режим ввода:", ["Ручной ввод", "Загрузка файла", "Пример"])
|
90 |
+
st.sidebar.markdown("---")
|
91 |
+
st.sidebar.info(
|
92 |
+
"Вы можете ввести текст вручную, загрузить файл (.txt) или использовать готовый пример."
|
93 |
+
)
|
94 |
+
st.sidebar.markdown("---")
|
95 |
+
if st.sidebar.button("Очистить кэш"):
|
96 |
+
st.cache_resource.clear() # Сброс кэша (при перезапуске приложения)
|
97 |
+
st.sidebar.success("Кэш очищен!")
|
98 |
+
|
99 |
+
|
100 |
+
user_text = ""
|
101 |
+
if mode == "Ручной ввод":
|
102 |
+
title_input = st.text_area("Название (можно пропустить)", "", height=100)
|
103 |
+
abstract_input = st.text_area("Абстракт (тоже не обязателен)", "", height=150)
|
104 |
+
user_text = title_input.strip() + " " + abstract_input.strip()
|
105 |
+
if st.button("Классифицировать", key="manual"):
|
106 |
+
if user_text.strip() == "":
|
107 |
+
st.error("Пожалуйста, введите или загрузите текст для классификации.")
|
108 |
+
else:
|
109 |
+
with st.spinner("Модель обрабатывает текст..."):
|
110 |
+
predicted_label, inference_time, probabilities_dict = classify_text(user_text)
|
111 |
+
st.success(f"Предсказанный тэг: **{predicted_label}**")
|
112 |
+
st.info(f"Время инференса: {inference_time:.2f} секунд")
|
113 |
+
st.markdown("### Распределение вероятностей")
|
114 |
+
chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс")
|
115 |
+
st.bar_chart(chart_data)
|
116 |
+
elif mode == "Загрузка файла":
|
117 |
+
uploaded_file = st.file_uploader("Загрузите текстовый файл (.txt)", type=["txt"])
|
118 |
+
if uploaded_file is not None:
|
119 |
+
file_text = uploaded_file.read().decode("utf-8")
|
120 |
+
st.text_area("Содержимое файла", file_text, height=150)
|
121 |
+
user_text = file_text
|
122 |
+
if st.button("Классифицировать", key="file"):
|
123 |
+
if user_text.strip() == "":
|
124 |
+
st.error("Пожалуйста, введите или загрузите текст для классификации.")
|
125 |
+
else:
|
126 |
+
with st.spinner("Модель обрабатывает текст..."):
|
127 |
+
predicted_label, inference_time, probabilities_dict = classify_text(user_text)
|
128 |
+
st.success(f"Предсказанный тэг: **{predicted_label}**")
|
129 |
+
st.info(f"Время инференса: {inference_time:.2f} секунд")
|
130 |
+
st.markdown("### Распределение вероятностей")
|
131 |
+
chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс")
|
132 |
+
st.bar_chart(chart_data)
|
133 |
+
elif mode == "Пример":
|
134 |
+
sample_text = (
|
135 |
+
"Non-invertible symmetries of two-dimensional Non-Linear Sigma Models"
|
136 |
+
)
|
137 |
+
st.text_area("Пример текста", sample_text, height=150)
|
138 |
+
if st.button("Использовать пример", key="sample"):
|
139 |
+
user_text = sample_text
|
140 |
+
with st.spinner("Модель обрабатывает текст..."):
|
141 |
+
predicted_label, inference_time, probabilities_dict = classify_text(user_text)
|
142 |
st.success(f"Предсказанный тэг: **{predicted_label}**")
|
143 |
+
st.info(f"Время инференса: {inference_time:.2f} секунд")
|
144 |
+
st.markdown("### Распределение вероятностей")
|
145 |
+
chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс")
|
146 |
+
st.bar_chart(chart_data)
|
147 |
+
|
148 |
+
|
149 |
+
st.markdown("---")
|
150 |
+
st.markdown("### О модели")
|
151 |
+
st.write(
|
152 |
+
"Модель использует **DistilBERT** для классификации текста. "
|
153 |
+
"Данный демо-проект вдохновлён эстетикой студии Артемия Лебедева."
|
154 |
+
)
|
155 |
+
|
156 |
+
st.markdown("### Об авторе: физик-теоретик")
|
157 |
+
st.write(
|
158 |
+
"Этот проект был создан для 4 лабораторки."
|
159 |
+
)
|
160 |
+
|
161 |
+
st.sidebar.markdown("### Контакты")
|
162 |
+
st.sidebar.write("Есть вопросы или предложения? Пишите на [email protected]")
|