fpopov1993 commited on
Commit
93fdb7c
·
verified ·
1 Parent(s): d6af057

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -18
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import streamlit as st
2
  import torch
 
 
3
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
4
 
5
- # Кэширование загрузки модели для ускорения перезапуска приложения.
6
  @st.cache_resource
7
  def load_model_and_tokenizer():
8
  model = DistilBertForSequenceClassification.from_pretrained(
@@ -14,25 +16,32 @@ def load_model_and_tokenizer():
14
 
15
  model, tokenizer = load_model_and_tokenizer()
16
 
17
- def classify_text(text: str) -> str:
18
- """Токенизация текста и запуск инференса."""
19
  encoding = tokenizer(
20
  text, return_tensors="pt", padding=True, truncation=True, max_length=128
21
  )
 
22
  with torch.no_grad():
23
  outputs = model(**encoding)
24
  logits = outputs.logits
 
 
 
25
  predicted_class_id = torch.argmax(logits, dim=1).item()
26
- id2label = model.config.id2label # Предполагается, что id2label был задан при обучении.
27
  predicted_label = id2label[predicted_class_id] if id2label else str(predicted_class_id)
28
- return predicted_label
 
 
 
 
 
29
 
30
- # Добавление кастомного CSS для стилизации в духе Артемия Лебедева.
31
  st.markdown(
32
  """
33
  <style>
34
  @import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;700&display=swap');
35
-
36
  body {
37
  background-color: #1a1a1a;
38
  color: #f5f5f5;
@@ -70,19 +79,84 @@ st.markdown(
70
  """, unsafe_allow_html=True
71
  )
72
 
73
- # Заголовки и описание приложения.
74
  st.markdown("<h1>Классификация тэгов в статьях архива</h1>", unsafe_allow_html=True)
75
  st.markdown("<h2>Дизайн почти как от студии Артемия Лебедева</h2>", unsafe_allow_html=True)
76
- st.write("Введите название статьи или её абстракт и мы подберём подходящий тэг.")
77
 
78
- # Текстовые поля для ввода названия и абстракта.
79
- title = st.text_area("Название", "")
80
- abstract = st.text_area("Абстракт", "")
81
 
82
- if st.button("Классифицировать"):
83
- user_text = title.strip() + " " + abstract.strip()
84
- if user_text.strip() == "":
85
- st.error("Пожалуйста, введите текст для классификации.")
86
- else:
87
- predicted_label = classify_text(user_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  st.success(f"Предсказанный тэг: **{predicted_label}**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
+ import time
4
+ import pandas as pd
5
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
6
 
7
+
8
  @st.cache_resource
9
  def load_model_and_tokenizer():
10
  model = DistilBertForSequenceClassification.from_pretrained(
 
16
 
17
  model, tokenizer = load_model_and_tokenizer()
18
 
19
+ def classify_text(text: str) -> (str, float, dict):
20
+ """Токенизация текста, запуск инференса, измерение времени выполнения и визуализация вероятностей."""
21
  encoding = tokenizer(
22
  text, return_tensors="pt", padding=True, truncation=True, max_length=128
23
  )
24
+ start_time = time.time()
25
  with torch.no_grad():
26
  outputs = model(**encoding)
27
  logits = outputs.logits
28
+ elapsed_time = time.time() - start_time
29
+ probabilities_tensor = torch.softmax(logits, dim=1).squeeze()
30
+ probabilities = probabilities_tensor.tolist()
31
  predicted_class_id = torch.argmax(logits, dim=1).item()
32
+ id2label = model.config.id2label # Предполагается, что id2label задан при обучении.
33
  predicted_label = id2label[predicted_class_id] if id2label else str(predicted_class_id)
34
+ if id2label:
35
+ probabilities_dict = { id2label[i]: probabilities[i] for i in range(len(probabilities)) }
36
+ else:
37
+ probabilities_dict = { str(i): probabilities[i] for i in range(len(probabilities)) }
38
+ return predicted_label, elapsed_time, probabilities_dict
39
+
40
 
 
41
  st.markdown(
42
  """
43
  <style>
44
  @import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;700&display=swap');
 
45
  body {
46
  background-color: #1a1a1a;
47
  color: #f5f5f5;
 
79
  """, unsafe_allow_html=True
80
  )
81
 
82
+
83
  st.markdown("<h1>Классификация тэгов в статьях архива</h1>", unsafe_allow_html=True)
84
  st.markdown("<h2>Дизайн почти как от студии Артемия Лебедева</h2>", unsafe_allow_html=True)
85
+ st.write("Это демо-приложение позволяет классифицировать текст статьи или абстракта с помощью модели DistilBERT. Был использован весь архив от 2016 года.")
86
 
 
 
 
87
 
88
+ st.sidebar.title("Опции")
89
+ mode = st.sidebar.radio("Выберите режим ввода:", ["Ручной ввод", "Загрузка файла", "Пример"])
90
+ st.sidebar.markdown("---")
91
+ st.sidebar.info(
92
+ "Вы можете ввести текст вручную, загрузить файл (.txt) или использовать готовый пример."
93
+ )
94
+ st.sidebar.markdown("---")
95
+ if st.sidebar.button("Очистить кэш"):
96
+ st.cache_resource.clear() # Сброс кэша (при перезапуске приложения)
97
+ st.sidebar.success("Кэш очищен!")
98
+
99
+
100
+ user_text = ""
101
+ if mode == "Ручной ввод":
102
+ title_input = st.text_area("Название (можно пропустить)", "", height=100)
103
+ abstract_input = st.text_area("Абстракт (тоже не обязателен)", "", height=150)
104
+ user_text = title_input.strip() + " " + abstract_input.strip()
105
+ if st.button("Классифицировать", key="manual"):
106
+ if user_text.strip() == "":
107
+ st.error("Пожалуйста, введите или загрузите текст для классификации.")
108
+ else:
109
+ with st.spinner("Модель обрабатывает текст..."):
110
+ predicted_label, inference_time, probabilities_dict = classify_text(user_text)
111
+ st.success(f"Предсказанный тэг: **{predicted_label}**")
112
+ st.info(f"Время инференса: {inference_time:.2f} секунд")
113
+ st.markdown("### Распределение вероятностей")
114
+ chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс")
115
+ st.bar_chart(chart_data)
116
+ elif mode == "Загрузка файла":
117
+ uploaded_file = st.file_uploader("Загрузите текстовый файл (.txt)", type=["txt"])
118
+ if uploaded_file is not None:
119
+ file_text = uploaded_file.read().decode("utf-8")
120
+ st.text_area("Содержимое файла", file_text, height=150)
121
+ user_text = file_text
122
+ if st.button("Классифицировать", key="file"):
123
+ if user_text.strip() == "":
124
+ st.error("Пожалуйста, введите или загрузите текст для классификации.")
125
+ else:
126
+ with st.spinner("Модель обрабатывает текст..."):
127
+ predicted_label, inference_time, probabilities_dict = classify_text(user_text)
128
+ st.success(f"Предсказанный тэг: **{predicted_label}**")
129
+ st.info(f"Время инференса: {inference_time:.2f} секунд")
130
+ st.markdown("### Распределение вероятностей")
131
+ chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс")
132
+ st.bar_chart(chart_data)
133
+ elif mode == "Пример":
134
+ sample_text = (
135
+ "Non-invertible symmetries of two-dimensional Non-Linear Sigma Models"
136
+ )
137
+ st.text_area("Пример текста", sample_text, height=150)
138
+ if st.button("Использовать пример", key="sample"):
139
+ user_text = sample_text
140
+ with st.spinner("Модель обрабатывает текст..."):
141
+ predicted_label, inference_time, probabilities_dict = classify_text(user_text)
142
  st.success(f"Предсказанный тэг: **{predicted_label}**")
143
+ st.info(f"Время инференса: {inference_time:.2f} секунд")
144
+ st.markdown("### Распределение вероятностей")
145
+ chart_data = pd.DataFrame(probabilities_dict.items(), columns=["Класс", "Вероятность"]).set_index("Класс")
146
+ st.bar_chart(chart_data)
147
+
148
+
149
+ st.markdown("---")
150
+ st.markdown("### О модели")
151
+ st.write(
152
+ "Модель использует **DistilBERT** для классификации текста. "
153
+ "Данный демо-проект вдохновлён эстетикой студии Артемия Лебедева."
154
+ )
155
+
156
+ st.markdown("### Об авторе: физик-теоретик")
157
+ st.write(
158
+ "Этот проект был создан для 4 лабораторки."
159
+ )
160
+
161
+ st.sidebar.markdown("### Контакты")
162
+ st.sidebar.write("Есть вопросы или предложения? Пишите на [email protected]")