import os

# HF_HOME must be set before any Hugging Face library is imported,
# otherwise the default cache location has already been fixed.
os.environ['HF_HOME'] = '/app/.cache'

import base64
import logging
import tempfile

import deepl
import numpy as np
import soundfile as sf
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from transformers import AutoProcessor, pipeline

# --- Basic Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Load environment variables and initialize DeepL ---
load_dotenv()
DEEPL_AUTH_KEY = os.getenv("DEEPL_AUTH_KEY")
deepl_translator = None
if DEEPL_AUTH_KEY:
    try:
        deepl_translator = deepl.Translator(DEEPL_AUTH_KEY)
        logging.info("DeepL translator initialized successfully.")
    except Exception as e:
        logging.error(f"Error initializing DeepL translator: {e}")
else:
    logging.warning("DEEPL_AUTH_KEY not found. DeepL will be unavailable.")

# --- Load Models ---
logging.info("Loading all models...")

# ASR Model
asr_model_id = "openai/whisper-base"
asr_model = None
asr_processor = None
try:
    asr_model = ORTModelForSpeechSeq2Seq.from_pretrained(asr_model_id, provider="CPUExecutionProvider")
    asr_processor = AutoProcessor.from_pretrained(asr_model_id)

    # CRITICAL FIX: The model's default config carries a 'forced_decoder_ids'
    # entry that clashes with recent library versions: the attribute must
    # exist, but it must be None, or generate() raises a conflict error.
    if hasattr(asr_model.config, 'forced_decoder_ids'):
        logging.info("Found conflicting 'forced_decoder_ids' in model config. Setting to None.")
        asr_model.config.forced_decoder_ids = None
    if hasattr(asr_model.generation_config, 'forced_decoder_ids'):
        logging.info("Found conflicting 'forced_decoder_ids' in generation_config. Setting to None.")
        asr_model.generation_config.forced_decoder_ids = None
    logging.info("ASR model and processor loaded and configured successfully.")
except Exception as e:
    logging.error(f"Fatal error loading ASR model: {e}", exc_info=True)

# Translation Pipelines
translators = {}
try:
    translators = {
        "en-zh": pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh"),
        "zh-en": pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en"),
        "en-ja": pipeline("translation", model="staka/fugumt-en-ja"),
        "ja-en": pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en"),
        "en-ko": pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-en-ko"),
        "ko-en": pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en"),
    }
    logging.info("Translation models loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load translation models: {e}")


# --- Core Logic Functions ---
def transcribe_audio(audio_bytes):
    if not asr_model or not asr_processor:
        return None, "ASR model or processor is not available."
    audio_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_file.write(audio_bytes)
            audio_path = tmp_file.name

        audio_input, sample_rate = sf.read(audio_path)
        if audio_input.ndim > 1:  # downmix multi-channel audio to mono
            audio_input = audio_input.mean(axis=1)

        # Whisper's feature extractor expects 16 kHz input, so resample if the
        # file uses another rate. Linear interpolation is a crude but
        # dependency-free resampler; a proper resampler (e.g. librosa or
        # torchaudio) would give better quality.
        target_sr = 16000
        if sample_rate != target_sr:
            duration = len(audio_input) / sample_rate
            audio_input = np.interp(
                np.linspace(0.0, duration, int(duration * target_sr), endpoint=False),
                np.linspace(0.0, duration, len(audio_input), endpoint=False),
                audio_input,
            )

        input_features = asr_processor(
            audio_input, sampling_rate=target_sr, return_tensors="pt"
        ).input_features

        # With forced_decoder_ids cleared in the config, we can now safely
        # let generate() handle the task without conflicts.
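        # generate() returns decoder token IDs; batch_decode() maps them back
        # to text. task="transcribe" keeps the output in the spoken language
        # (task="translate" would instead force English output).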
        predicted_ids = asr_model.generate(input_features, task="transcribe")
        text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        logging.info(f"ASR transcribed text: '{text}'")
        return text, None
    except Exception as e:
        logging.error(f"ASR transcription failed: {e}", exc_info=True)
        return None, str(e)
    finally:
        # Clean up the temp file on both the success and failure paths.
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)


def translate_text(text, source_lang, target_lang):
    # Priority 1: Use DeepL for specific, high-quality pairs if available
    if deepl_translator and ((source_lang == 'zh' and target_lang == 'ja') or
                             (source_lang == 'en' and target_lang == 'ja')):
        try:
            dl_source_lang = "ZH" if source_lang == 'zh' else "EN"
            logging.info(f"Attempting DeepL translation for {source_lang} -> {target_lang}")
            result = deepl_translator.translate_text(text, source_lang=dl_source_lang, target_lang="JA")
            return result.text, None
        except Exception as e:
            logging.error(f"DeepL failed: {e}. Falling back to HF models.")

    # Priority 2: Try a direct HF translation model
    model_key = f"{source_lang}-{target_lang}"
    translator = translators.get(model_key)
    if translator:
        try:
            logging.info(f"Attempting direct HF translation for {model_key}")
            translated_text = translator(text, max_length=512)[0]['translation_text']
            return translated_text, None
        except Exception as e:
            logging.error(f"Direct HF translation for {model_key} failed: {e}", exc_info=True)
            # Don't return here; allow fallback to pivot translation.

    # Priority 3: Try pivot translation via English
    if source_lang != 'en' and target_lang != 'en':
        to_en_key = f"{source_lang}-en"
        from_en_key = f"en-{target_lang}"
        translator_to_en = translators.get(to_en_key)
        translator_from_en = translators.get(from_en_key)
        if translator_to_en and translator_from_en:
            try:
                logging.info(f"Attempting pivot translation for {source_lang} -> en -> {target_lang}")
                # Step 1: Source to English
                english_text = translator_to_en(text, max_length=512)[0]['translation_text']
                logging.info(f"Pivot step (to en) result: '{english_text}'")
                # Step 2: English to Target
                final_text = translator_from_en(english_text, max_length=512)[0]['translation_text']
                logging.info(f"Pivot step (from en) result: '{final_text}'")
                return final_text, None
            except Exception as e:
                logging.error(f"Pivot translation failed: {e}", exc_info=True)

    # If all else fails
    logging.warning(f"No translation path found for {source_lang} -> {target_lang}")
    return None, f"No model available for {source_lang} to {target_lang}"


# --- FastAPI App ---
app = FastAPI()


@app.get("/")
def root():
    return {"status": "ok", "message": "Translator API is running."}


@app.post("/api/asr")
async def api_asr(request: Request):
    try:
        body = await request.json()
        audio_b64 = body.get('audio_base64')
        if not audio_b64:
            logging.error("Request is missing 'audio_base64'")
            return JSONResponse(status_code=400, content={"error": "No audio_base64 found in request"})

        audio_bytes = base64.b64decode(audio_b64)
        text, error = transcribe_audio(audio_bytes)
        if error:
            logging.error(f"ASR transcription function returned an error: {error}")
            return JSONResponse(status_code=500, content={"error": f"ASR Error: {error}"})

        response_data = {"text": text}
        logging.info(f"Returning ASR response: {response_data}")
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error(f"Critical error in /api/asr endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/api/translate")
async def api_translate(request: Request):
    try:
        body = await request.json()
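        # Expected JSON payload (keys match the lookups below), e.g.:
        #   {"text": "Hello", "source_lang": "en", "target_lang": "ja"}
        # Language codes use the short two-letter form that keys the
        # translators dict (en, zh, ja, ko).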
        text = body.get('text')
        source_lang = body.get('source_lang')
        target_lang = body.get('target_lang')
        if not all([text, source_lang, target_lang]):
            return JSONResponse(status_code=400, content={"error": "Missing parameters: text, source_lang, or target_lang"})

        translated_text, error = translate_text(text, source_lang, target_lang)
        if error:
            return JSONResponse(status_code=500, content={"error": f"Translation Error: {error}"})

        response_data = {"translated_text": translated_text}
        logging.info(f"Returning translation response: {response_data}")
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error(f"Error in /api/translate endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})


# --- Main Execution ---
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
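
# --- Example requests (illustrative; assumes a local server on port 7860) ---
# Transcription:
#   curl -X POST http://localhost:7860/api/asr \
#        -H "Content-Type: application/json" \
#        -d '{"audio_base64": "<base64-encoded audio file>"}'
# Translation:
#   curl -X POST http://localhost:7860/api/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello", "source_lang": "en", "target_lang": "ja"}'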