# utils/tilmash_translation.py
import logging
import os
import re
import threading
import time
import uuid
from typing import Iterator

from dotenv import load_dotenv
from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TranslationPipeline

from config import DEFAULT_CONFIG
from .chunking import chunk_text_with_separators

# Load environment variables from .env file
load_dotenv()
hf_token = os.getenv('HF_TOKEN')
if not hf_token:
    logging.warning("HF_TOKEN not found in environment variables. Model downloading might fail.")
else:
    login(token=hf_token)

# Global tilmash lock file
LOCK_DIR = os.path.join("local_llms", "locks")
os.makedirs(LOCK_DIR, exist_ok=True)
TILMASH_LOCK_FILE = os.path.join(LOCK_DIR, "tilmash.lock")

# Get session timeout from config
SESSION_TIMEOUT = DEFAULT_CONFIG["SESSION_TIMEOUT"]

class ExclusiveResourceLock:
    """File-based lock for exclusive GPU resource access across processes."""

    def __init__(self, lock_file, timeout=SESSION_TIMEOUT):
        self.lock_file = lock_file
        self.timeout = timeout
        self.lock_id = str(uuid.uuid4())
        self.acquired = False

    def acquire(self):
        """Acquire the exclusive lock, retrying until the timeout expires."""
        start_time = time.time()
        while time.time() - start_time < self.timeout:
            try:
                try:
                    # Create the lock file atomically; O_EXCL guarantees this
                    # fails if another process already holds the lock.
                    fd = os.open(self.lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
                    with os.fdopen(fd, 'w') as f:
                        f.write(f"{self.lock_id}\n{os.getpid()}\n{time.time()}")
                    self.acquired = True
                    return True
                except FileExistsError:
                    # Check if the existing lock file is stale (older than 5 minutes)
                    lock_time = os.path.getmtime(self.lock_file)
                    if time.time() - lock_time > 300:  # 5 minutes
                        try:
                            # Remove the stale lock and retry immediately
                            os.remove(self.lock_file)
                            continue
                        except OSError:
                            pass
                # Wait before retrying
                time.sleep(1)
            except Exception as e:
                logging.error(f"Lock acquisition error: {str(e)}")
                time.sleep(1)
        return False

    def release(self):
        """Release the lock if we own it."""
        if not self.acquired:
            return
        try:
            if os.path.exists(self.lock_file):
                with open(self.lock_file, 'r') as f:
                    content = f.read().split('\n')
                if content and content[0] == self.lock_id:
                    os.remove(self.lock_file)
            self.acquired = False
        except Exception as e:
            logging.error(f"Lock release error: {str(e)}")

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
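
# Usage sketch (illustrative; the lock is not currently invoked elsewhere in
# this module): serialize GPU access across processes by wrapping inference
# in the lock. The `run_inference` helper below is hypothetical.
#
#     with ExclusiveResourceLock(TILMASH_LOCK_FILE) as lock:
#         if lock.acquired:
#             run_inference()
#         else:
#             logging.warning("Timed out waiting for the Tilmash lock")
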
class TilmashTranslator:
    """
    Thread-safe translator using the Tilmash model
    """

    def __init__(self, use_gpu=None):
        """Initialize the Tilmash translator."""
        # Reentrant lock guarding model state
        self._lock = threading.RLock()
        self.initialized = False
        self.model = None
        self.tokenizer = None
        # Get session ID
        import streamlit as st
        self.session_id = getattr(st.session_state, 'session_id', str(uuid.uuid4()))
        # Decide whether to use the GPU
        self.use_gpu = use_gpu
        self.device = "cpu"
        # Check GPU availability via PyTorch
        if use_gpu is not False:  # None or True
            try:
                import torch
                logging.info(f"GPU check: use_gpu={use_gpu}, PyTorch version {torch.__version__}")
                if torch.cuda.is_available():
                    self.device = "cuda"
                    gpu_info = torch.cuda.get_device_name(0)
                    cuda_ver = torch.version.cuda
                    logging.info(f"CUDA GPU available and will be used for Tilmash: {gpu_info}, CUDA {cuda_ver}")
                elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                    self.device = "mps"  # Apple Silicon
                    logging.info("Apple MPS available and will be used for Tilmash")
                else:
                    logging.warning("No GPU detected. Diagnostics:")
                    if hasattr(torch.cuda, "is_available"):
                        logging.warning(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
                    if hasattr(torch, "version") and hasattr(torch.version, "cuda"):
                        logging.warning(f"torch.version.cuda: {torch.version.cuda}")
                    if hasattr(torch, "__config__"):
                        logging.warning(f"PyTorch build configuration: {torch.__config__.show()}")
                    if hasattr(torch.cuda, "get_arch_list") and callable(torch.cuda.get_arch_list):
                        logging.warning(f"CUDA architectures: {torch.cuda.get_arch_list()}")
            except ImportError as e:
                logging.warning(f"PyTorch is not installed or cannot be imported: {str(e)}")
            except Exception as e:
                logging.warning(f"Error while checking GPU: {str(e)}")

    def load_model(self):
        """Load the Tilmash model if not already loaded."""
        with self._lock:
            if self.initialized:
                return self.model, self.tokenizer
            try:
                model_name = "issai/tilmash"
                cache_dir = "local_llms"
                # Ensure cache directory exists
                os.makedirs(cache_dir, exist_ok=True)
                try:
                    # First try to load the model from the local cache
                    logging.info(f"Loading Tilmash model for session {self.session_id[:8]} on {self.device}...")
                    try:
                        self.tokenizer = AutoTokenizer.from_pretrained(
                            model_name,
                            cache_dir=cache_dir,
                            local_files_only=True
                        )
                        self.model = AutoModelForSeq2SeqLM.from_pretrained(
                            model_name,
                            cache_dir=cache_dir,
                            local_files_only=True
                        )
                        # Move the model to the target device (CPU/GPU)
                        self.model = self.model.to(self.device)
                        # On GPU, switch to half precision for faster inference
                        if self.device in ["cuda", "mps"]:
                            self.model = self.model.half()
                        logging.info(f"Successfully loaded model from local cache on {self.device}.")
                    except OSError:
                        # If local loading fails, download the model
                        logging.info(f"Model not found locally. Downloading from Hugging Face to {self.device}...")
                        self.tokenizer = AutoTokenizer.from_pretrained(
                            model_name,
                            cache_dir=cache_dir,
                            local_files_only=False
                        )
                        self.model = AutoModelForSeq2SeqLM.from_pretrained(
                            model_name,
                            cache_dir=cache_dir,
                            local_files_only=False
                        )
                        # Move the model to the target device (CPU/GPU)
                        self.model = self.model.to(self.device)
                        # On GPU, switch to half precision for faster inference
                        if self.device in ["cuda", "mps"]:
                            self.model = self.model.half()
                        logging.info(f"Successfully downloaded and loaded the model on {self.device}.")
                    self.initialized = True
                    return self.model, self.tokenizer
                except ValueError as e:
                    logging.error(f"Invalid model configuration: {str(e)}")
                    raise ValueError(f"Failed to load model: {str(e)}") from e
                except Exception as e:
                    logging.error(f"Unexpected error during model initialization: {str(e)}")
                    raise RuntimeError(f"Failed to load model: {str(e)}") from e
            except Exception as e:
                logging.error(f"Failed to load Tilmash model: {str(e)}")
                raise

    def unload_model(self):
        """Unload the model to free memory."""
        with self._lock:
            if self.initialized:
                logging.info("Unloading Tilmash model to free memory...")
                self.model = None
                self.tokenizer = None
                self.initialized = False
                # Force garbage collection
                import gc
                gc.collect()
                logging.info("Tilmash model unloaded")
    def create_pipeline(self, src_lang, tgt_lang, max_length=512):
        """Create a translation pipeline with the loaded model."""
        with self._lock:
            lang_map = {
                'ru': 'rus_Cyrl',
                'en': 'eng_Latn',
                'kk': 'kaz_Cyrl'
            }
            # Validate language pair
            if src_lang not in lang_map or tgt_lang not in lang_map:
                raise ValueError(f"Unsupported language pair: {src_lang} -> {tgt_lang}")
            # Make sure the model is loaded
            if not self.initialized:
                self.load_model()
            # Configure the translation pipeline with tuned generation parameters
            pipeline = TranslationPipeline(
                model=self.model,
                tokenizer=self.tokenizer,
                src_lang=lang_map[src_lang],
                tgt_lang=lang_map[tgt_lang],
                max_length=max_length,
                num_beams=7,
                early_stopping=True,
                repetition_penalty=1.3,
                no_repeat_ngram_size=2,
                length_penalty=1.1,
                truncation=True,
                clean_up_tokenization_spaces=True,
                device=self.device  # Explicitly pin the pipeline to the selected device
            )
            return pipeline

    def translate(self, text, src_lang, tgt_lang, max_length=512):
        """Translate text using the Tilmash model."""
        with self._lock:
            try:
                pipeline = self.create_pipeline(src_lang, tgt_lang, max_length)
                # Split text into sentences for better quality
                sentences = re.split(r'(?<=[.!?]) +', text)
                translated_sentences = []
                for sentence in sentences:
                    if sentence.strip():
                        result = pipeline(sentence)
                        translated_sentence = _extract_translation(result)
                        translated_sentences.append(translated_sentence)
                return ' '.join(translated_sentences)
            except Exception as e:
                logging.error(f"Translation error: {str(e)}")
                return f"Error: {str(e)}"
    def translate_streaming(self, text, src_lang, tgt_lang, max_length=512) -> Iterator[str]:
        """Stream translation results sentence by sentence."""
        try:
            # Make sure the model is loaded - must be done in the locked section
            with self._lock:
                if not self.initialized:
                    self.load_model()
                pipeline = self.create_pipeline(src_lang, tgt_lang, max_length)
            # Check whether the text is too large for single-pass processing,
            # judged by paragraph count as well as raw length
            paragraphs = re.split(r'\n\s*\n', text)
            is_large_text = len(paragraphs) > 3 or len(text) > 1000  # Multiple paragraphs or long text
            if is_large_text:
                # Process paragraph by paragraph for structured documents
                for i, paragraph in enumerate(paragraphs):
                    if not paragraph.strip():
                        yield "\n\n"
                        continue
                    # If the paragraph itself is too large, process it sentence by sentence
                    if len(paragraph) > 800:
                        sentences = re.split(r'(?<=[.!?])\s+', paragraph)
                        for sentence in sentences:
                            if not sentence.strip():
                                continue
                            try:
                                # Only lock the actual model inference
                                with self._lock:
                                    result = pipeline(sentence)
                                translated = _extract_translation(result)
                                yield translated + " "
                            except Exception as e:
                                logging.error(f"Error translating sentence: {str(e)}")
                                yield f"[Error: {str(e)}] "
                    else:
                        # Process the whole paragraph at once
                        try:
                            # Only lock the actual model inference
                            with self._lock:
                                result = pipeline(paragraph)
                            translated = _extract_translation(result)
                            yield translated
                            # Add a paragraph break after each paragraph
                            if i < len(paragraphs) - 1:
                                yield "\n\n"
                        except Exception as e:
                            logging.error(f"Error translating paragraph: {str(e)}")
                            yield f"[Error translating paragraph: {str(e)}]\n\n"
            else:
                # For short texts, process the entire text at once
                try:
                    # Only lock the actual model inference
                    with self._lock:
                        result = pipeline(text)
                    translated = _extract_translation(result)
                    yield translated
                except Exception as e:
                    logging.error(f"Error translating text: {str(e)}")
                    yield f"[Error: {str(e)}]"
        except Exception as e:
            logging.error(f"Streaming translation error: {str(e)}")
            yield f"Error initializing translation: {str(e)}"
def tilmash_translate(input_text, src_lang, tgt_lang, max_length=512, use_gpu=None):
    """Main translation function with structure preservation."""
    try:
        translator = TilmashTranslator(use_gpu=use_gpu)
        return translator.translate(input_text, src_lang, tgt_lang, max_length)
    except Exception as e:
        logging.error(f"Translation failed: {str(e)}")
        return f"Translation error: {str(e)}"

def tilmash_translate_streaming(input_text, src_lang, tgt_lang, max_length=512, use_gpu=None) -> Iterator[str]:
    """Streaming version of the translation function that yields translated sentences one by one."""
    try:
        translator = TilmashTranslator(use_gpu=use_gpu)
        logging.info(f"Starting translation with use_gpu={use_gpu}, device={translator.device}")
        yield from translator.translate_streaming(input_text, src_lang, tgt_lang, max_length)
    except Exception as e:
        logging.error(f"Streaming translation failed: {str(e)}")
        yield f"Translation error: {str(e)}"

def display_tilmash_streaming_translation(text: str, src_lang: str, tgt_lang: str) -> tuple:
    """
    Display streaming translation in a Streamlit app.

    Args:
        text: Text to translate
        src_lang: Source language code ('en', 'ru', 'kk')
        tgt_lang: Target language code ('en', 'ru', 'kk')

    Returns:
        tuple: (translated_text, needs_chunking)
    """
    import streamlit as st
    if not text:
        return "", False
    # Check if the text needs chunking
    needs_chunking = len(text) > 1000  # Roughly 250 tokens
    # Create a placeholder for streaming output
    placeholder = st.empty()
    result = ""
    # Stream the translation
    for sentence in tilmash_translate_streaming(text, src_lang, tgt_lang):
        result += sentence
        placeholder.markdown(result)
    return result, needs_chunking

def _extract_translation(result):
    """Safely extract the translation text from pipeline output."""
    try:
        if isinstance(result, list) and len(result) > 0:
            return result[0].get('translation_text', '').strip()
        return ""
    except Exception as e:
        logging.error(f"Translation extraction error: {str(e)}")
        return ""
def _process_large_text(text, src_lang, pipeline, tokenizer, max_length):
    """Process long documents with structure preservation."""
    try:
        chunks_with_seps = chunk_text_with_separators(
            text=text,
            tokenizer=tokenizer,
            max_tokens=int(0.9 * max_length),
            lang='russian' if src_lang in ['ru', 'kk'] else 'english'
        )
    except Exception as e:
        logging.error(f"Chunking failed: {str(e)}")
        return ""
    translations = []
    prev_separator = None
    for chunk_idx, (chunk, separator) in enumerate(chunks_with_seps):
        if not chunk.strip():
            translations.append(separator)
            continue
        try:
            # Process the chunk through the translation pipeline
            result = pipeline(chunk)
            translated = _extract_translation(result)
            # Preserve the original document structure
            if prev_separator:
                translations.append(prev_separator)
            # Add indentation for list items and tables
            if _is_structured_element(chunk):
                translated = _preserve_structure(translated, chunk)
            translations.append(translated)
            prev_separator = separator
        except Exception as e:
            logging.error(f"Chunk {chunk_idx + 1} error: {str(e)}")
            translations.append(f"<<ERROR: {chunk[:50]}...>>{separator or ' '}")
            prev_separator = separator
    # Assemble the final text with cleanup
    final_text = ''.join(translations).strip()
    return _postprocess_translation(final_text)

def _is_structured_element(text):
    """Check if text contains document structure elements."""
    return any([
        re.match(r'^\s*(\d+\.|\-|\*)\s', text),  # List items
        re.search(r':\s*$', text) and re.search(r'[A-ZА-Я]{3,}', text),  # Headers
        re.search(r'\|.+\|', text),  # Tables
        re.search(r'\b(Таблица|Table)\b', text, re.IGNORECASE)  # Table labels
    ])
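
# Examples of what the checks above match (illustrative):
#
#     _is_structured_element("1. First item")     ->  True   (list item)
#     _is_structured_element("| col1 | col2 |")   ->  True   (table row)
#     _is_structured_element("Plain sentence.")   ->  False
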
def _preserve_structure(translated, original):
    """Maintain original formatting in translated structured elements."""
    # Preserve list indentation
    if re.match(r'^\s*(\d+\.|\-|\*)\s', original):
        return '\n' + translated.lstrip()
    # Preserve table formatting
    if '|' in original:
        return translated.replace(' | ', '|').replace('| ', '|').replace(' |', '|')
    return translated

def _postprocess_translation(text):
    """Final cleanup of translated text."""
    # Fix list numbering
    text = re.sub(r'\n(\d+)\.\s*\n', r'\n\1. ', text)
    # Rejoin a line ending in a colon with the line that follows it
    text = re.sub(r'(:\s*)\n(\S)', r'\1\2', text)
    # Normalize whitespace around punctuation
    text = re.sub(r'([,:;])\s+', r'\1 ', text)
    text = re.sub(r'\s+([.!?])', r'\1', text)
    # Normalize guillemets to straight quotes
    text = text.replace('«', '"').replace('»', '"')
    return text
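
# Minimal smoke test (illustrative). Assumes the "issai/tilmash" weights are
# already cached locally or downloadable with the configured HF_TOKEN, and
# that streamlit is importable even outside a running app. Because of the
# relative import above, run it as a module: python -m utils.tilmash_translation
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print(tilmash_translate("Hello, world!", src_lang="en", tgt_lang="kk"))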