"""Lightweight translation service built on Helsinki-NLP MarianMT models.

Requests are queued and processed by a single background worker thread;
callers poll for results using the task ID returned by translate().
"""

from transformers import MarianMTModel, MarianTokenizer
import torch
import threading
import queue
import time
import uuid
import gc
import traceback


class Translator:
    def __init__(self):
        # MarianMT checkpoints available for each supported language pair.
        self.models = {
            'en-fr': 'Helsinki-NLP/opus-mt-en-fr',
            'en-es': 'Helsinki-NLP/opus-mt-en-es',
            'en-de': 'Helsinki-NLP/opus-mt-en-de',
            'en-hi': 'Helsinki-NLP/opus-mt-en-hi',
            'fr-en': 'Helsinki-NLP/opus-mt-fr-en',
            'es-en': 'Helsinki-NLP/opus-mt-es-en',
            'de-en': 'Helsinki-NLP/opus-mt-de-en',
            'hi-en': 'Helsinki-NLP/opus-mt-hi-en',
        }

        # Cache of loaded models/tokenizers, capped to keep memory usage low.
        self.loaded_models = {}
        self.loaded_tokenizers = {}
        self.max_models_in_memory = 1

        # Human-readable names for the supported language codes.
        self.languages = {
            'en': 'English',
            'fr': 'French',
            'es': 'Spanish',
            'de': 'German',
            'hi': 'Hindi',
        }

        # Requests are queued and handled one at a time by a daemon worker thread.
        self.translation_queue = queue.Queue()
        self.translation_results = {}
        self.worker_thread = threading.Thread(target=self._translation_worker, daemon=True)
        self.worker_thread.start()

    def get_available_languages(self):
        """Return available languages"""
        return self.languages

    def get_model_name(self, source_lang, target_lang):
        """Get the appropriate model name for the language pair"""
        lang_pair = f"{source_lang}-{target_lang}"
        return self.models.get(lang_pair)

    def load_model(self, model_name):
        """Load model and tokenizer if not already loaded, with memory management"""
        if model_name in self.loaded_models:
            return self.loaded_models[model_name], self.loaded_tokenizers[model_name]

        # Evict cached models once the limit is hit so only one pair stays resident.
        if len(self.loaded_models) >= self.max_models_in_memory:
            print("Memory limit reached. Clearing models...")
            self.loaded_models = {}
            self.loaded_tokenizers = {}
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"Loading model: {model_name}")
        try:
            self.loaded_tokenizers[model_name] = MarianTokenizer.from_pretrained(model_name)
            self.loaded_models[model_name] = MarianMTModel.from_pretrained(
                model_name,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16,
                local_files_only=False,
                force_download=False,
            )
            gc.collect()
            return self.loaded_models[model_name], self.loaded_tokenizers[model_name]
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            print(f"Detailed error: {traceback.format_exc()}")
            raise

    def _translation_worker(self):
        """Background worker that processes translation requests"""
        while True:
            task_id = None
            try:
                task_id, text, source_lang, target_lang = self.translation_queue.get()

                model_name = self.get_model_name(source_lang, target_lang)
                if not model_name:
                    self.translation_results[task_id] = f"Translation not available for {source_lang} to {target_lang}"
                else:
                    try:
                        model, tokenizer = self.load_model(model_name)

                        # Long inputs are split into fixed-size character chunks; each chunk
                        # is still subject to the tokenizer's 512-token truncation.
                        max_length = 512
                        if len(text) > max_length * 2:
                            chunks = [text[i:i + max_length * 2] for i in range(0, len(text), max_length * 2)]
                            translated_chunks = []

                            for chunk in chunks:
                                inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
                                with torch.no_grad():
                                    translated = model.generate(
                                        **inputs,
                                        num_beams=2,
                                        early_stopping=True,
                                        max_length=max_length,
                                    )
                                chunk_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
                                translated_chunks.append(chunk_text)

                            self.translation_results[task_id] = " ".join(translated_chunks)
                        else:
                            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
                            with torch.no_grad():
                                translated = model.generate(
                                    **inputs,
                                    num_beams=2,
                                    early_stopping=True,
                                    max_length=max_length * 2,
                                )
                            translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
                            self.translation_results[task_id] = translated_text[0]

                        # Release memory after each job; helpful on low-RAM hosts.
                        gc.collect()
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                    except Exception as e:
                        print(f"Translation error: {e}")
                        self.translation_results[task_id] = f"Error translating text: {str(e)[:100]}"

                self.translation_queue.task_done()
            except Exception as e:
                print(f"Error in translation worker: {e}")
                # Only record the failure and mark the task done if one was actually dequeued.
                if task_id is not None:
                    self.translation_results[task_id] = "Server error occurred during translation"
                    self.translation_queue.task_done()
                continue

    def translate(self, text, source_lang, target_lang):
        """Translate text from source language to target language"""
        task_id = str(uuid.uuid4())

        self.translation_queue.put((task_id, text, source_lang, target_lang))

        return {"task_id": task_id, "status": "processing"}

    def get_translation_result(self, task_id):
        """Get the result of a translation task by its ID"""
        if task_id in self.translation_results:
            # Results are handed out once and then discarded.
            result = self.translation_results[task_id]
            del self.translation_results[task_id]
            return {"status": "completed", "translation": result}
        else:
            return {"status": "processing"}
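

# --- Usage sketch (illustrative, not part of the service) ---------------------
# A minimal example of how the asynchronous API above is meant to be used:
# translate() enqueues a job and returns a task ID, and the caller polls
# get_translation_result() until the worker thread publishes the result.
# The sample sentence and 0.5 s polling interval are arbitrary choices, and
# the first run will download the MarianMT checkpoint from the Hugging Face Hub.
if __name__ == "__main__":
    translator = Translator()
    job = translator.translate("Hello, how are you?", "en", "fr")

    # Poll until the background worker finishes this task.
    result = translator.get_translation_result(job["task_id"])
    while result["status"] != "completed":
        time.sleep(0.5)
        result = translator.get_translation_result(job["task_id"])

    print(result["translation"])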