Spaces:
Runtime error
Runtime error
from sdc_classifier import SDCClassifier | |
from dotenv import load_dotenv | |
import torch | |
import json | |
import os | |
from typing import Dict, Tuple, Optional, Any, List | |
from dataclasses import dataclass, field | |
import pandas as pd | |
# Load environment variables | |
load_dotenv() | |
class Config: | |
# DEFAULT_CLASSES_FILE: str = "classes.json" | |
DEFAULT_CLASSES_FILE: str = "kw_questions.json" | |
DEFAULT_SIGNATURES_FILE: str = "signatures.npz" | |
CACHE_FILE: str = "embeddings_cache.db" | |
MODEL_INFO_FILE: str = "model_info.json" | |
DEFAULT_OPENAI_MODELS: List[str] = field(default_factory=lambda: ["text-embedding-3-large"]) | |
DEFAULT_LOCAL_MODEL: str = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext" | |
config = Config() | |
class ClassifierApp: | |
def __init__(self): | |
self.classifier = None | |
self.initial_info = { | |
"status": "initializing", | |
"model_info": {}, | |
"classes_info": {}, | |
"errors": [] | |
} | |
# self.model_type = "Local" # Додати цей рядок | |
self.model_type = "OpenAI" # Нова версія | |
def initialize_environment(self) -> Tuple[Dict, Optional[SDCClassifier]]: | |
"""Ініціалізація середовища при першому запуску""" | |
try: | |
# Перевіряємо наявність необхідних файлів | |
if not os.path.exists(config.DEFAULT_CLASSES_FILE): | |
self.initial_info["errors"].append( | |
f"ПОМИЛКА: Файл {config.DEFAULT_CLASSES_FILE} не знайдено!" | |
) | |
self.initial_info["status"] = "error" | |
print(f"\nПомилка: Файл {config.DEFAULT_CLASSES_FILE} не знайдено!") | |
return self.initial_info, None | |
print("\nСтворення класифікатора...") | |
try: | |
# Визначаємо яка модель використовувалась для сигнатур | |
signatures_model = None | |
if os.path.exists(config.MODEL_INFO_FILE): | |
with open(config.MODEL_INFO_FILE, 'r') as f: | |
model_info = json.load(f) | |
if not model_info.get('using_local', True): | |
signatures_model = "text-embedding-3-large" # Модель, яка використовувалась | |
# Створюємо класифікатор з тією ж моделлю | |
self.classifier = SDCClassifier(openai_api_key=os.getenv("OPENAI_API_KEY")) | |
print(f"Використовується модель: {signatures_model or 'local'}") | |
except Exception as e: | |
print(f"\nПомилка при створенні класифікатора: {str(e)}") | |
self.initial_info["errors"].append(f"Помилка при створенні класифікатора: {str(e)}") | |
self.initial_info["status"] = "error" | |
return self.initial_info, None | |
print("\nЗавантаження класів...") | |
try: | |
classes = self.classifier.load_classes(config.DEFAULT_CLASSES_FILE) | |
self.initial_info["classes_info"] = { | |
"total_classes": len(classes), | |
"classes_list": list(classes.keys()), | |
"hints_per_class": {cls: len(hints) for cls, hints in classes.items()} | |
} | |
except Exception as e: | |
print(f"\nПомилка при завантаженні класів: {str(e)}") | |
self.initial_info["errors"].append(f"Помилка при завантаженні класів: {str(e)}") | |
self.initial_info["status"] = "error" | |
return self.initial_info, None | |
print("\nПеревірка та завантаження сигнатур...") | |
if os.path.exists(config.DEFAULT_SIGNATURES_FILE): | |
try: | |
self.classifier.load_signatures(config.DEFAULT_SIGNATURES_FILE) | |
self.initial_info["status"] = "success" | |
print("Сигнатури завантажено успішно") | |
except Exception as e: | |
print(f"\nПомилка при завантаженні сигнатур: {str(e)}") | |
self.initial_info["errors"].append(f"Помилка при завантаженні сигнатур: {str(e)}") | |
self.initial_info["status"] = "error" | |
return self.initial_info, None | |
else: | |
print("\nСтворення нових сигнатур...") | |
self.initial_info["status"] = "creating_signatures" | |
try: | |
result = self.classifier.initialize_signatures( | |
force_rebuild=True, | |
signatures_file=config.DEFAULT_SIGNATURES_FILE | |
) | |
if isinstance(result, str) and "error" in result.lower(): | |
print(f"\nПомилка при створенні сигнатур: {result}") | |
self.initial_info["errors"].append(result) | |
self.initial_info["status"] = "error" | |
return self.initial_info, None | |
except Exception as e: | |
print(f"\nПомилка при створенні сигнатур: {str(e)}") | |
self.initial_info["errors"].append(f"Помилка при створенні сигнатур: {str(e)}") | |
self.initial_info["status"] = "error" | |
return self.initial_info, None | |
print("\nЗбереження інформації про модель...") | |
try: | |
self.classifier.save_model_info(config.MODEL_INFO_FILE) | |
with open(config.MODEL_INFO_FILE, "r") as f: | |
self.initial_info["model_info"] = json.load(f) | |
self.initial_info["status"] = "success" | |
print("\nІніціалізація завершена успішно") | |
return self.initial_info, self.classifier | |
except Exception as e: | |
print(f"\nПомилка при збереженні інформації про модель: {str(e)}") | |
self.initial_info["errors"].append(f"Помилка при читанні model_info: {str(e)}") | |
self.initial_info["status"] = "error" | |
return self.initial_info, None | |
except Exception as e: | |
print(f"\nЗагальна помилка при ініціалізації: {str(e)}") | |
self.initial_info["errors"].append(f"ПОМИЛКА при ініціалізації: {str(e)}") | |
self.initial_info["status"] = "error" | |
return self.initial_info, None | |
def create_classifier( | |
self, | |
model_type: str, | |
openai_model: Optional[str] = None, | |
local_model: Optional[str] = None, | |
device: Optional[str] = None | |
) -> SDCClassifier: | |
"""Створення класифікатора з відповідними параметрами""" | |
classifier = SDCClassifier() | |
if model_type == "OpenAI": | |
if hasattr(classifier, 'set_openai_model'): | |
classifier.set_openai_model(openai_model) | |
else: | |
if hasattr(classifier, 'set_local_model'): | |
classifier.set_local_model(local_model, device) | |
return classifier | |
def update_model_inputs( | |
self, | |
model_type: str, | |
openai_model: str, | |
local_model: str, | |
device: str | |
) -> Dict[str, Any]: | |
"""Оновлення моделі та інтерфейсу при зміні типу моделі""" | |
try: | |
self.classifier = self.create_classifier( | |
model_type=model_type, | |
openai_model=openai_model if model_type == "OpenAI" else None, | |
local_model=local_model if model_type == "Local" else None, | |
device=device if model_type == "Local" else None | |
) | |
self.classifier.restore_base_state() | |
result = self.classifier.initialize_signatures() | |
self.classifier.save_model_info(config.MODEL_INFO_FILE) | |
with open(config.MODEL_INFO_FILE, "r") as f: | |
model_info = json.load(f) | |
new_system_info = { | |
"status": "success", | |
"model_info": model_info, | |
"classes_info": { | |
"total_classes": len(self.classifier.classes_json), | |
"classes_list": list(self.classifier.classes_json.keys()), | |
"hints_per_class": {cls: len(hints) for cls, hints in self.classifier.classes_json.items()} | |
}, | |
"errors": [] | |
} | |
return { | |
"model_choice": gr.update(visible=model_type == "OpenAI"), | |
"local_model_path": gr.update(visible=model_type == "Local"), | |
"device_choice": gr.update(visible=model_type == "Local"), | |
"system_info": new_system_info, | |
"system_md": self.update_system_markdown(new_system_info), | |
"build_out": f"Модель змінено на {model_type}", | |
"cache_stats": self.classifier.get_cache_stats() | |
} | |
except Exception as e: | |
error_info = { | |
"status": "error", | |
"errors": [str(e)], | |
"model_info": {}, | |
"classes_info": {} | |
} | |
return { | |
"model_choice": gr.update(visible=model_type == "OpenAI"), | |
"local_model_path": gr.update(visible=model_type == "Local"), | |
"device_choice": gr.update(visible=model_type == "Local"), | |
"system_info": error_info, | |
"system_md": self.update_system_markdown(error_info), | |
"build_out": f"Помилка: {str(e)}", | |
"cache_stats": {} | |
} | |
def update_classifier_settings( | |
self, | |
json_file: Optional[str], | |
model_type: str, | |
openai_model: str, | |
local_model: str, | |
device: str, | |
force_rebuild: bool | |
) -> Tuple[str, Dict, Dict, str]: | |
"""Оновлення налаштувань класифікатора""" | |
try: | |
self.classifier = self.create_classifier( | |
model_type=model_type, | |
openai_model=openai_model if model_type == "OpenAI" else None, | |
local_model=local_model if model_type == "Local" else None, | |
device=device if model_type == "Local" else None | |
) | |
if json_file is not None: | |
with open(json_file.name, 'r', encoding='utf-8') as f: | |
new_classes = json.load(f) | |
self.classifier.load_classes(new_classes) | |
else: | |
self.classifier.restore_base_state() | |
result = self.classifier.initialize_signatures( | |
force_rebuild=force_rebuild, | |
signatures_file=config.DEFAULT_SIGNATURES_FILE if not force_rebuild else None | |
) | |
self.classifier.save_model_info(config.MODEL_INFO_FILE) | |
with open(config.MODEL_INFO_FILE, "r") as f: | |
model_info = json.load(f) | |
new_system_info = { | |
"status": "success", | |
"model_info": model_info, | |
"classes_info": { | |
"total_classes": len(self.classifier.classes_json), | |
"classes_list": list(self.classifier.classes_json.keys()), | |
"hints_per_class": { | |
cls: len(hints) | |
for cls, hints in self.classifier.classes_json.items() | |
} | |
}, | |
"errors": [] | |
} | |
return ( | |
result, | |
self.classifier.get_cache_stats(), | |
new_system_info, | |
self.update_system_markdown(new_system_info) | |
) | |
except Exception as e: | |
error_info = { | |
"status": "error", | |
"errors": [str(e)], | |
"model_info": {}, | |
"classes_info": {} | |
} | |
return ( | |
f"Помилка: {str(e)}", | |
self.classifier.get_cache_stats(), | |
error_info, | |
self.update_system_markdown(error_info) | |
) | |
def process_single_text(self, text: str, threshold: float) -> Dict: | |
"""Обробка одного тексту""" | |
try: | |
if self.classifier is None: | |
raise ValueError("Класифікатор не ініціалізовано") | |
return self.classifier.process_single_text(text, threshold) | |
except Exception as e: | |
return {"error": str(e)} | |
def load_data(self, csv_path: str, emb_path: str) -> str: | |
"""Завантаження даних для пакетної обробки""" | |
try: | |
if self.classifier is None: | |
raise ValueError("Класифікатор не ініціалізовано") | |
return self.classifier.load_data(csv_path, emb_path) | |
except Exception as e: | |
return f"Помилка: {str(e)}" | |
def classify_batch(self, filter_str: str, threshold: float): | |
"""Пакетна класифікація""" | |
try: | |
if self.classifier is None: | |
raise ValueError("Класифікатор не ініціалізовано") | |
return self.classifier.classify_rows(filter_str, threshold) | |
except Exception as e: | |
return None | |
def save_results(self) -> str: | |
"""Збереження результатів""" | |
try: | |
if self.classifier is None: | |
raise ValueError("Класифікатор не ініціалізовано") | |
return self.classifier.save_results() | |
except Exception as e: | |
return f"Помилка: {str(e)}" | |
def sync_system_info(self) -> Dict: | |
"""Синхронізація системної інформації""" | |
try: | |
if self.classifier is None: | |
raise ValueError("Класифікатор не ініціалізовано") | |
self.classifier.save_model_info(config.MODEL_INFO_FILE) | |
with open(config.MODEL_INFO_FILE, "r") as f: | |
model_info = json.load(f) | |
self.initial_info = { | |
"status": "success", | |
"model_info": model_info, | |
"classes_info": { | |
"total_classes": len(self.classifier.classes_json), | |
"classes_list": list(self.classifier.classes_json.keys()), | |
"hints_per_class": { | |
cls: len(hints) | |
for cls, hints in self.classifier.classes_json.items() | |
} | |
}, | |
"errors": [] | |
} | |
return self.initial_info | |
except Exception as e: | |
self.initial_info = { | |
"status": "error", | |
"model_info": {}, | |
"classes_info": {}, | |
"errors": [str(e)] | |
} | |
return self.initial_info | |
def evaluate_batch(self, csv_file, threshold: float) -> tuple[pd.DataFrame, str]: | |
""" | |
Оцінка пакетної класифікації | |
Args: | |
csv_file: завантажений CSV файл від gradio | |
threshold: поріг впевненості | |
Returns: | |
tuple[pd.DataFrame, str]: результати та статистика | |
""" | |
try: | |
if self.classifier is None: | |
return None, "Помилка: Класифікатор не ініціалізовано" | |
# Перевірка на None | |
if csv_file is None: | |
return None, "Помилка: Файл не завантажено" | |
# Зберігаємо тимчасовий файл | |
temp_path = "temp_upload.csv" | |
if hasattr(csv_file, 'name'): | |
# Якщо це файловий об'єкт від gradio | |
import shutil | |
shutil.copy2(csv_file.name, temp_path) | |
else: | |
# Якщо це шлях до файлу | |
temp_path = str(csv_file) | |
# Виконуємо класифікацію | |
results_df, statistics = self.classifier.evaluate_classification(temp_path, threshold) | |
# Формуємо текст статистики | |
stats_md = f"""### Статистика класифікації | |
- Всього зразків: {statistics['total_samples']} | |
- Правильний клас на першому місці: {statistics['correct_first_place']['count']} ({statistics['correct_first_place']['percentage']}%) | |
- Правильний клас в топ-3: {statistics['in_top3']['count']} ({statistics['in_top3']['percentage']}%) | |
- Правильний клас не знайдено: {statistics['not_found']['count']} ({statistics['not_found']['percentage']}%) | |
#### Середня впевненість для правильних класифікацій: {statistics['mean_confidence_correct']}% | |
#### Розподіл впевненості: | |
- 90-100%: {statistics['confidence_distribution']['90-100%']['count']} ({statistics['confidence_distribution']['90-100%']['percentage']}%) | |
- 70-90%: {statistics['confidence_distribution']['70-90%']['count']} ({statistics['confidence_distribution']['70-90%']['percentage']}%) | |
- 50-70%: {statistics['confidence_distribution']['50-70%']['count']} ({statistics['confidence_distribution']['50-70%']['percentage']}%) | |
- <50%: {statistics['confidence_distribution']['<50%']['count']} ({statistics['confidence_distribution']['<50%']['percentage']}%) | |
""" | |
# Зберігаємо результати для подальшого використання | |
self.current_evaluation_results = results_df | |
# Видаляємо тимчасовий файл якщо він був створений | |
if temp_path == "temp_upload.csv" and os.path.exists(temp_path): | |
os.remove(temp_path) | |
return results_df, stats_md | |
except Exception as e: | |
# У випадку помилки спробуємо видалити тимчасовий файл | |
if os.path.exists("temp_upload.csv"): | |
os.remove("temp_upload.csv") | |
return None, f"Помилка: {str(e)}" | |
def save_evaluation_results(self) -> tuple[str, str]: | |
""" | |
Зберігає результати останньої оцінки класифікації та готує файл для завантаження | |
Returns: | |
tuple[str, str]: (шлях до файлу, повідомлення про статус) | |
""" | |
try: | |
if not hasattr(self, 'current_evaluation_results'): | |
return None, "Помилка: Немає результатів для збереження" | |
output_path = "evaluation_results.csv" | |
self.current_evaluation_results.to_csv(output_path, index=False) | |
return output_path, f"Результати збережено у файл {output_path}" | |
except Exception as e: | |
return None, f"Помилка при збереженні: {str(e)}" | |
def update_system_markdown(info: Dict) -> str: | |
"""Оновлення Markdown з системною інформацією""" | |
if info["status"] == "success": | |
return f""" | |
### Поточна конфігурація: | |
- Модель: {info['model_info'].get('using_local', 'OpenAI')} | |
- Кількість класів: {info['classes_info']['total_classes']} | |
- Класи: {', '.join(info['classes_info']['classes_list'])} | |
""" | |
else: | |
return f""" | |
### Помилки ініціалізації: | |
{chr(10).join('- ' + err for err in info['errors'])} | |
""" |