SDC-multi-classifier / classifier_app.py
DocUA's picture
Для деплоя на HF
9177daf
from sdc_classifier import SDCClassifier
from dotenv import load_dotenv
import torch
import json
import os
from typing import Dict, Tuple, Optional, Any, List
from dataclasses import dataclass, field
import pandas as pd
# Load environment variables
load_dotenv()
@dataclass
class Config:
# DEFAULT_CLASSES_FILE: str = "classes.json"
DEFAULT_CLASSES_FILE: str = "kw_questions.json"
DEFAULT_SIGNATURES_FILE: str = "signatures.npz"
CACHE_FILE: str = "embeddings_cache.db"
MODEL_INFO_FILE: str = "model_info.json"
DEFAULT_OPENAI_MODELS: List[str] = field(default_factory=lambda: ["text-embedding-3-large"])
DEFAULT_LOCAL_MODEL: str = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
config = Config()
class ClassifierApp:
def __init__(self):
self.classifier = None
self.initial_info = {
"status": "initializing",
"model_info": {},
"classes_info": {},
"errors": []
}
# self.model_type = "Local" # Додати цей рядок
self.model_type = "OpenAI" # Нова версія
def initialize_environment(self) -> Tuple[Dict, Optional[SDCClassifier]]:
"""Ініціалізація середовища при першому запуску"""
try:
# Перевіряємо наявність необхідних файлів
if not os.path.exists(config.DEFAULT_CLASSES_FILE):
self.initial_info["errors"].append(
f"ПОМИЛКА: Файл {config.DEFAULT_CLASSES_FILE} не знайдено!"
)
self.initial_info["status"] = "error"
print(f"\nПомилка: Файл {config.DEFAULT_CLASSES_FILE} не знайдено!")
return self.initial_info, None
print("\nСтворення класифікатора...")
try:
# Визначаємо яка модель використовувалась для сигнатур
signatures_model = None
if os.path.exists(config.MODEL_INFO_FILE):
with open(config.MODEL_INFO_FILE, 'r') as f:
model_info = json.load(f)
if not model_info.get('using_local', True):
signatures_model = "text-embedding-3-large" # Модель, яка використовувалась
# Створюємо класифікатор з тією ж моделлю
self.classifier = SDCClassifier(openai_api_key=os.getenv("OPENAI_API_KEY"))
print(f"Використовується модель: {signatures_model or 'local'}")
except Exception as e:
print(f"\nПомилка при створенні класифікатора: {str(e)}")
self.initial_info["errors"].append(f"Помилка при створенні класифікатора: {str(e)}")
self.initial_info["status"] = "error"
return self.initial_info, None
print("\nЗавантаження класів...")
try:
classes = self.classifier.load_classes(config.DEFAULT_CLASSES_FILE)
self.initial_info["classes_info"] = {
"total_classes": len(classes),
"classes_list": list(classes.keys()),
"hints_per_class": {cls: len(hints) for cls, hints in classes.items()}
}
except Exception as e:
print(f"\nПомилка при завантаженні класів: {str(e)}")
self.initial_info["errors"].append(f"Помилка при завантаженні класів: {str(e)}")
self.initial_info["status"] = "error"
return self.initial_info, None
print("\nПеревірка та завантаження сигнатур...")
if os.path.exists(config.DEFAULT_SIGNATURES_FILE):
try:
self.classifier.load_signatures(config.DEFAULT_SIGNATURES_FILE)
self.initial_info["status"] = "success"
print("Сигнатури завантажено успішно")
except Exception as e:
print(f"\nПомилка при завантаженні сигнатур: {str(e)}")
self.initial_info["errors"].append(f"Помилка при завантаженні сигнатур: {str(e)}")
self.initial_info["status"] = "error"
return self.initial_info, None
else:
print("\nСтворення нових сигнатур...")
self.initial_info["status"] = "creating_signatures"
try:
result = self.classifier.initialize_signatures(
force_rebuild=True,
signatures_file=config.DEFAULT_SIGNATURES_FILE
)
if isinstance(result, str) and "error" in result.lower():
print(f"\nПомилка при створенні сигнатур: {result}")
self.initial_info["errors"].append(result)
self.initial_info["status"] = "error"
return self.initial_info, None
except Exception as e:
print(f"\nПомилка при створенні сигнатур: {str(e)}")
self.initial_info["errors"].append(f"Помилка при створенні сигнатур: {str(e)}")
self.initial_info["status"] = "error"
return self.initial_info, None
print("\nЗбереження інформації про модель...")
try:
self.classifier.save_model_info(config.MODEL_INFO_FILE)
with open(config.MODEL_INFO_FILE, "r") as f:
self.initial_info["model_info"] = json.load(f)
self.initial_info["status"] = "success"
print("\nІніціалізація завершена успішно")
return self.initial_info, self.classifier
except Exception as e:
print(f"\nПомилка при збереженні інформації про модель: {str(e)}")
self.initial_info["errors"].append(f"Помилка при читанні model_info: {str(e)}")
self.initial_info["status"] = "error"
return self.initial_info, None
except Exception as e:
print(f"\nЗагальна помилка при ініціалізації: {str(e)}")
self.initial_info["errors"].append(f"ПОМИЛКА при ініціалізації: {str(e)}")
self.initial_info["status"] = "error"
return self.initial_info, None
def create_classifier(
self,
model_type: str,
openai_model: Optional[str] = None,
local_model: Optional[str] = None,
device: Optional[str] = None
) -> SDCClassifier:
"""Створення класифікатора з відповідними параметрами"""
classifier = SDCClassifier()
if model_type == "OpenAI":
if hasattr(classifier, 'set_openai_model'):
classifier.set_openai_model(openai_model)
else:
if hasattr(classifier, 'set_local_model'):
classifier.set_local_model(local_model, device)
return classifier
def update_model_inputs(
self,
model_type: str,
openai_model: str,
local_model: str,
device: str
) -> Dict[str, Any]:
"""Оновлення моделі та інтерфейсу при зміні типу моделі"""
try:
self.classifier = self.create_classifier(
model_type=model_type,
openai_model=openai_model if model_type == "OpenAI" else None,
local_model=local_model if model_type == "Local" else None,
device=device if model_type == "Local" else None
)
self.classifier.restore_base_state()
result = self.classifier.initialize_signatures()
self.classifier.save_model_info(config.MODEL_INFO_FILE)
with open(config.MODEL_INFO_FILE, "r") as f:
model_info = json.load(f)
new_system_info = {
"status": "success",
"model_info": model_info,
"classes_info": {
"total_classes": len(self.classifier.classes_json),
"classes_list": list(self.classifier.classes_json.keys()),
"hints_per_class": {cls: len(hints) for cls, hints in self.classifier.classes_json.items()}
},
"errors": []
}
return {
"model_choice": gr.update(visible=model_type == "OpenAI"),
"local_model_path": gr.update(visible=model_type == "Local"),
"device_choice": gr.update(visible=model_type == "Local"),
"system_info": new_system_info,
"system_md": self.update_system_markdown(new_system_info),
"build_out": f"Модель змінено на {model_type}",
"cache_stats": self.classifier.get_cache_stats()
}
except Exception as e:
error_info = {
"status": "error",
"errors": [str(e)],
"model_info": {},
"classes_info": {}
}
return {
"model_choice": gr.update(visible=model_type == "OpenAI"),
"local_model_path": gr.update(visible=model_type == "Local"),
"device_choice": gr.update(visible=model_type == "Local"),
"system_info": error_info,
"system_md": self.update_system_markdown(error_info),
"build_out": f"Помилка: {str(e)}",
"cache_stats": {}
}
def update_classifier_settings(
self,
json_file: Optional[str],
model_type: str,
openai_model: str,
local_model: str,
device: str,
force_rebuild: bool
) -> Tuple[str, Dict, Dict, str]:
"""Оновлення налаштувань класифікатора"""
try:
self.classifier = self.create_classifier(
model_type=model_type,
openai_model=openai_model if model_type == "OpenAI" else None,
local_model=local_model if model_type == "Local" else None,
device=device if model_type == "Local" else None
)
if json_file is not None:
with open(json_file.name, 'r', encoding='utf-8') as f:
new_classes = json.load(f)
self.classifier.load_classes(new_classes)
else:
self.classifier.restore_base_state()
result = self.classifier.initialize_signatures(
force_rebuild=force_rebuild,
signatures_file=config.DEFAULT_SIGNATURES_FILE if not force_rebuild else None
)
self.classifier.save_model_info(config.MODEL_INFO_FILE)
with open(config.MODEL_INFO_FILE, "r") as f:
model_info = json.load(f)
new_system_info = {
"status": "success",
"model_info": model_info,
"classes_info": {
"total_classes": len(self.classifier.classes_json),
"classes_list": list(self.classifier.classes_json.keys()),
"hints_per_class": {
cls: len(hints)
for cls, hints in self.classifier.classes_json.items()
}
},
"errors": []
}
return (
result,
self.classifier.get_cache_stats(),
new_system_info,
self.update_system_markdown(new_system_info)
)
except Exception as e:
error_info = {
"status": "error",
"errors": [str(e)],
"model_info": {},
"classes_info": {}
}
return (
f"Помилка: {str(e)}",
self.classifier.get_cache_stats(),
error_info,
self.update_system_markdown(error_info)
)
def process_single_text(self, text: str, threshold: float) -> Dict:
"""Обробка одного тексту"""
try:
if self.classifier is None:
raise ValueError("Класифікатор не ініціалізовано")
return self.classifier.process_single_text(text, threshold)
except Exception as e:
return {"error": str(e)}
def load_data(self, csv_path: str, emb_path: str) -> str:
"""Завантаження даних для пакетної обробки"""
try:
if self.classifier is None:
raise ValueError("Класифікатор не ініціалізовано")
return self.classifier.load_data(csv_path, emb_path)
except Exception as e:
return f"Помилка: {str(e)}"
def classify_batch(self, filter_str: str, threshold: float):
"""Пакетна класифікація"""
try:
if self.classifier is None:
raise ValueError("Класифікатор не ініціалізовано")
return self.classifier.classify_rows(filter_str, threshold)
except Exception as e:
return None
def save_results(self) -> str:
"""Збереження результатів"""
try:
if self.classifier is None:
raise ValueError("Класифікатор не ініціалізовано")
return self.classifier.save_results()
except Exception as e:
return f"Помилка: {str(e)}"
def sync_system_info(self) -> Dict:
"""Синхронізація системної інформації"""
try:
if self.classifier is None:
raise ValueError("Класифікатор не ініціалізовано")
self.classifier.save_model_info(config.MODEL_INFO_FILE)
with open(config.MODEL_INFO_FILE, "r") as f:
model_info = json.load(f)
self.initial_info = {
"status": "success",
"model_info": model_info,
"classes_info": {
"total_classes": len(self.classifier.classes_json),
"classes_list": list(self.classifier.classes_json.keys()),
"hints_per_class": {
cls: len(hints)
for cls, hints in self.classifier.classes_json.items()
}
},
"errors": []
}
return self.initial_info
except Exception as e:
self.initial_info = {
"status": "error",
"model_info": {},
"classes_info": {},
"errors": [str(e)]
}
return self.initial_info
def evaluate_batch(self, csv_file, threshold: float) -> tuple[pd.DataFrame, str]:
"""
Оцінка пакетної класифікації
Args:
csv_file: завантажений CSV файл від gradio
threshold: поріг впевненості
Returns:
tuple[pd.DataFrame, str]: результати та статистика
"""
try:
if self.classifier is None:
return None, "Помилка: Класифікатор не ініціалізовано"
# Перевірка на None
if csv_file is None:
return None, "Помилка: Файл не завантажено"
# Зберігаємо тимчасовий файл
temp_path = "temp_upload.csv"
if hasattr(csv_file, 'name'):
# Якщо це файловий об'єкт від gradio
import shutil
shutil.copy2(csv_file.name, temp_path)
else:
# Якщо це шлях до файлу
temp_path = str(csv_file)
# Виконуємо класифікацію
results_df, statistics = self.classifier.evaluate_classification(temp_path, threshold)
# Формуємо текст статистики
stats_md = f"""### Статистика класифікації
- Всього зразків: {statistics['total_samples']}
- Правильний клас на першому місці: {statistics['correct_first_place']['count']} ({statistics['correct_first_place']['percentage']}%)
- Правильний клас в топ-3: {statistics['in_top3']['count']} ({statistics['in_top3']['percentage']}%)
- Правильний клас не знайдено: {statistics['not_found']['count']} ({statistics['not_found']['percentage']}%)
#### Середня впевненість для правильних класифікацій: {statistics['mean_confidence_correct']}%
#### Розподіл впевненості:
- 90-100%: {statistics['confidence_distribution']['90-100%']['count']} ({statistics['confidence_distribution']['90-100%']['percentage']}%)
- 70-90%: {statistics['confidence_distribution']['70-90%']['count']} ({statistics['confidence_distribution']['70-90%']['percentage']}%)
- 50-70%: {statistics['confidence_distribution']['50-70%']['count']} ({statistics['confidence_distribution']['50-70%']['percentage']}%)
- <50%: {statistics['confidence_distribution']['<50%']['count']} ({statistics['confidence_distribution']['<50%']['percentage']}%)
"""
# Зберігаємо результати для подальшого використання
self.current_evaluation_results = results_df
# Видаляємо тимчасовий файл якщо він був створений
if temp_path == "temp_upload.csv" and os.path.exists(temp_path):
os.remove(temp_path)
return results_df, stats_md
except Exception as e:
# У випадку помилки спробуємо видалити тимчасовий файл
if os.path.exists("temp_upload.csv"):
os.remove("temp_upload.csv")
return None, f"Помилка: {str(e)}"
def save_evaluation_results(self) -> tuple[str, str]:
"""
Зберігає результати останньої оцінки класифікації та готує файл для завантаження
Returns:
tuple[str, str]: (шлях до файлу, повідомлення про статус)
"""
try:
if not hasattr(self, 'current_evaluation_results'):
return None, "Помилка: Немає результатів для збереження"
output_path = "evaluation_results.csv"
self.current_evaluation_results.to_csv(output_path, index=False)
return output_path, f"Результати збережено у файл {output_path}"
except Exception as e:
return None, f"Помилка при збереженні: {str(e)}"
@staticmethod
def update_system_markdown(info: Dict) -> str:
"""Оновлення Markdown з системною інформацією"""
if info["status"] == "success":
return f"""
### Поточна конфігурація:
- Модель: {info['model_info'].get('using_local', 'OpenAI')}
- Кількість класів: {info['classes_info']['total_classes']}
- Класи: {', '.join(info['classes_info']['classes_list'])}
"""
else:
return f"""
### Помилки ініціалізації:
{chr(10).join('- ' + err for err in info['errors'])}
"""