Spaces:
Sleeping
Sleeping
from transformers import AutoProcessor | |
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq | |
import torch | |
import os | |
import base64 | |
import tempfile | |
from fastapi import FastAPI, Request | |
from fastapi.responses import JSONResponse | |
import uvicorn | |
import deepl | |
from dotenv import load_dotenv | |
import soundfile as sf | |
import logging | |
# --- Basic Configuration ---
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# --- Load environment variables and initialize DeepL ---
load_dotenv()
DEEPL_AUTH_KEY = os.environ.get("DEEPL_AUTH_KEY")

# Remains None whenever DeepL cannot be used (no key, or client setup failed).
deepl_translator = None
if not DEEPL_AUTH_KEY:
    logging.warning("DEEPL_AUTH_KEY not found. DeepL will be unavailable.")
else:
    try:
        deepl_translator = deepl.Translator(DEEPL_AUTH_KEY)
    except Exception as init_error:
        # Best-effort: the service keeps running without DeepL.
        logging.error(f"Error initializing DeepL translator: {init_error}")
    else:
        logging.info("DeepL translator initialized successfully.")
# --- Load Models ---
logging.info("Loading all models...")

# ASR Model
asr_model_id = "openai/whisper-base"
asr_model = None
asr_processor = None
try:
    asr_model = ORTModelForSpeechSeq2Seq.from_pretrained(
        asr_model_id,
        provider="CPUExecutionProvider",
    )
    asr_processor = AutoProcessor.from_pretrained(asr_model_id)

    # Per the original author's note: the model's default 'forced_decoder_ids'
    # clashes with recent library versions, which need the attribute to exist
    # but be None. Clear it on the model config first, then on the generation
    # config (same order as before).
    for holder_attr, holder_label in (
        ("config", "model config"),
        ("generation_config", "generation_config"),
    ):
        holder = getattr(asr_model, holder_attr)
        if hasattr(holder, "forced_decoder_ids"):
            logging.info(f"Found conflicting 'forced_decoder_ids' in {holder_label}. Setting to None.")
            holder.forced_decoder_ids = None

    logging.info("ASR model and processor loaded and configured successfully.")
except Exception as load_error:
    # Loading failures are logged, not raised; transcription checks for None.
    logging.error(f"Fatal error loading ASR model: {load_error}", exc_info=True)
# Translation Pipelines
from transformers import pipeline

# Hugging Face model id for each directly supported "<source>-<target>" pair.
_TRANSLATION_MODEL_IDS = {
    "en-zh": "Helsinki-NLP/opus-mt-en-zh",
    "zh-en": "Helsinki-NLP/opus-mt-zh-en",
    "en-ja": "staka/fugumt-en-ja",
    "ja-en": "Helsinki-NLP/opus-mt-ja-en",
    "en-ko": "Helsinki-NLP/opus-mt-tc-big-en-ko",
    "ko-en": "Helsinki-NLP/opus-mt-ko-en",
}

# Load every pipeline independently so that one model failing to load does not
# leave the whole mapping empty; pairs that loaded remain usable.
translators = {}
for _pair, _model_id in _TRANSLATION_MODEL_IDS.items():
    try:
        translators[_pair] = pipeline("translation", model=_model_id)
    except Exception as _load_error:
        logging.error(
            f"Failed to load translation model '{_model_id}' for {_pair}: {_load_error}",
            exc_info=True,
        )

if len(translators) == len(_TRANSLATION_MODEL_IDS):
    logging.info("Translation models loaded successfully.")
elif translators:
    logging.warning(
        f"Only {len(translators)} of {len(_TRANSLATION_MODEL_IDS)} translation models loaded: "
        f"{sorted(translators)}"
    )
else:
    logging.error("Failed to load translation models: no pipeline could be loaded.")
# --- Core Logic Functions --- | |
def transcribe_audio(audio_bytes):
    """Transcribe an encoded audio file (given as bytes) to text.

    The bytes are written to a temporary file, decoded with soundfile,
    down-mixed to mono, resampled to 16 kHz and passed through the ASR
    processor and model.

    Args:
        audio_bytes: Raw contents of an audio file.

    Returns:
        A ``(text, error)`` tuple: ``(text, None)`` on success and
        ``(None, message)`` when the model is unavailable or any step fails.
    """
    if not asr_model or not asr_processor:
        return None, "ASR model or processor is not available."

    # Sampling rate declared to the processor; the audio must really be at
    # this rate, otherwise the transcription is computed on distorted input.
    target_rate = 16000
    audio_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            # Record the path before writing so a failed write is still cleaned up.
            audio_path = tmp_file.name
            tmp_file.write(audio_bytes)

        audio_input, sample_rate = sf.read(audio_path)
        if audio_input.ndim > 1:
            # Average the channels to get a mono signal.
            audio_input = audio_input.mean(axis=1)

        if sample_rate != target_rate:
            logging.info(f"Resampling audio from {sample_rate} Hz to {target_rate} Hz")
            audio_input = _resample_linear(audio_input, sample_rate, target_rate)

        input_features = asr_processor(
            audio_input, sampling_rate=target_rate, return_tensors="pt"
        ).input_features
        # forced_decoder_ids is cleared at load time, so generate() can take
        # the task argument without conflicts.
        predicted_ids = asr_model.generate(input_features, task="transcribe")
        text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        logging.info(f"ASR transcribed text: '{text}'")
        return text, None
    except Exception as e:
        logging.error(f"ASR transcription failed: {e}", exc_info=True)
        return None, str(e)
    finally:
        # Single cleanup point for the temp file, on success and failure alike.
        if audio_path is not None and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError as cleanup_error:
                logging.warning(f"Could not remove temporary file {audio_path}: {cleanup_error}")


def _resample_linear(samples, source_rate, target_rate):
    """Resample a mono signal to ``target_rate`` by linear interpolation."""
    import numpy as np  # local import: numpy is not imported at file level

    if source_rate == target_rate or len(samples) == 0:
        return samples
    target_len = max(1, int(round(len(samples) * target_rate / source_rate)))
    source_times = np.arange(len(samples)) / source_rate
    target_times = np.arange(target_len) / target_rate
    # np.interp clamps to the last sample for times past the end of the source.
    return np.interp(target_times, source_times, samples)
def translate_text(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Strategies are tried in order, each falling through to the next on error:
      1. DeepL, only for zh -> ja and en -> ja, when a DeepL client exists.
      2. A direct Hugging Face pipeline registered as "<source>-<target>".
      3. A pivot through English (source -> en -> target) when neither
         language is English and both pipelines are registered.

    Args:
        text: The text to translate.
        source_lang: Source language code as used in the pipeline keys (e.g. 'en').
        target_lang: Target language code as used in the pipeline keys (e.g. 'ja').

    Returns:
        A ``(translated_text, error)`` tuple: ``(text, None)`` on success,
        ``(None, message)`` otherwise. The message lists the errors of the
        strategies that were attempted, or states that no model is available
        when no strategy applied at all.
    """
    # Errors from strategies that were attempted but raised; reported at the
    # end so a runtime failure is not mistaken for a missing model.
    failures = []

    # Priority 1: Use DeepL for specific, high-quality pairs if available
    if deepl_translator and target_lang == 'ja' and source_lang in ('zh', 'en'):
        try:
            dl_source_lang = "ZH" if source_lang == 'zh' else "EN"
            logging.info(f"Attempting DeepL translation for {source_lang} -> {target_lang}")
            result = deepl_translator.translate_text(text, source_lang=dl_source_lang, target_lang="JA")
            return result.text, None
        except Exception as e:
            logging.error(f"DeepL failed: {e}. Falling back to HF models.")
            failures.append(f"DeepL: {e}")

    # Priority 2: Try direct HF translation
    model_key = f"{source_lang}-{target_lang}"
    translator = translators.get(model_key)
    if translator:
        try:
            logging.info(f"Attempting direct HF translation for {model_key}")
            translated_text = translator(text, max_length=512)[0]['translation_text']
            return translated_text, None
        except Exception as e:
            # Don't return here, allow fallback to pivot
            logging.error(f"Direct HF translation for {model_key} failed: {e}", exc_info=True)
            failures.append(f"{model_key}: {e}")

    # Priority 3: Try pivot translation via English
    if source_lang != 'en' and target_lang != 'en':
        translator_to_en = translators.get(f"{source_lang}-en")
        translator_from_en = translators.get(f"en-{target_lang}")
        if translator_to_en and translator_from_en:
            try:
                logging.info(f"Attempting pivot translation for {source_lang} -> en -> {target_lang}")
                # Step 1: Source to English
                english_text = translator_to_en(text, max_length=512)[0]['translation_text']
                logging.info(f"Pivot step (to en) result: '{english_text}'")
                # Step 2: English to Target
                final_text = translator_from_en(english_text, max_length=512)[0]['translation_text']
                logging.info(f"Pivot step (from en) result: '{final_text}'")
                return final_text, None
            except Exception as e:
                logging.error(f"Pivot translation failed: {e}", exc_info=True)
                failures.append(f"pivot via en: {e}")

    # If all else fails
    if failures:
        logging.warning(f"All translation attempts failed for {source_lang} -> {target_lang}")
        details = "; ".join(failures)
        return None, f"Translation failed for {source_lang} to {target_lang}: {details}"

    logging.warning(f"No translation path found for {source_lang} -> {target_lang}")
    return None, f"No model available for {source_lang} to {target_lang}"
# --- FastAPI App ---
app = FastAPI()


# Register the handler; without a route decorator FastAPI never serves it.
@app.get("/")
def root():
    """Return a static status payload showing that the API process is up."""
    return {"status": "ok", "message": "Translator API is running."}
# Register the handler under the path its own log messages refer to.
@app.post("/api/asr")
async def api_asr(request: Request):
    """Transcribe base64-encoded audio sent in a JSON request body.

    Expects a JSON object with an 'audio_base64' field.

    Returns:
        200 with ``{"text": ...}`` on success;
        400 when 'audio_base64' is missing or is not decodable base64;
        500 when transcription fails or an unexpected error occurs.
    """
    try:
        body = await request.json()
        # A JSON body that is not an object (e.g. a list) has no fields.
        audio_b64 = body.get('audio_base64') if isinstance(body, dict) else None
        if not audio_b64:
            logging.error("Request is missing 'audio_base64'")
            return JSONResponse(status_code=400, content={"error": "No audio_base64 found in request"})

        try:
            audio_bytes = base64.b64decode(audio_b64)
        except (ValueError, TypeError) as decode_error:
            # binascii.Error is a ValueError; bad client input is a 400, not a 500.
            logging.error(f"Invalid base64 audio payload: {decode_error}")
            return JSONResponse(status_code=400, content={"error": f"Invalid audio_base64: {decode_error}"})

        text, error = transcribe_audio(audio_bytes)
        if error:
            logging.error(f"ASR transcription function returned an error: {error}")
            return JSONResponse(status_code=500, content={"error": f"ASR Error: {error}"})

        response_data = {"text": text}
        logging.info(f"Returning ASR response: {response_data}")
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error(f"Critical error in /api/asr endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})
# Register the handler under the path its own log messages refer to.
@app.post("/api/translate")
async def api_translate(request: Request):
    """Translate text described by a JSON request body.

    Expects a JSON object with 'text', 'source_lang' and 'target_lang'.

    Returns:
        200 with ``{"translated_text": ...}`` on success;
        400 when the body is not a JSON object or a field is missing/empty;
        500 when translation fails or an unexpected error occurs.
    """
    missing_params = {"error": "Missing parameters: text, source_lang, or target_lang"}
    try:
        body = await request.json()
        if not isinstance(body, dict):
            # A non-object body (e.g. a list) cannot carry the parameters.
            return JSONResponse(status_code=400, content=missing_params)

        text = body.get('text')
        source_lang = body.get('source_lang')
        target_lang = body.get('target_lang')
        if not all([text, source_lang, target_lang]):
            return JSONResponse(status_code=400, content=missing_params)

        translated_text, error = translate_text(text, source_lang, target_lang)
        if error:
            return JSONResponse(status_code=500, content={"error": f"Translation Error: {error}"})

        response_data = {"translated_text": translated_text}
        logging.info(f"Returning translation response: {response_data}")
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error(f"Error in /api/translate endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})
# --- Main Execution ---
if __name__ == "__main__":
    # Start the ASGI server only when this file is executed as a script.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )