"""Speech-to-text and translation API.

Exposes JSON endpoints via FastAPI:

* ``POST /api/asr`` - transcribe base64-encoded audio with a Whisper model
  running on ONNX Runtime.
* ``POST /api/translate`` - translate text with DeepL (for selected language
  pairs) or Hugging Face translation pipelines, pivoting through English when
  no direct model is available.

All models are loaded once, at import time.
"""

import base64
import logging
import math
import os
import tempfile

import deepl
import numpy as np
import soundfile as sf
import torch  # noqa: F401  (kept from the original file; not referenced directly)
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from transformers import AutoProcessor, pipeline

# --- Basic Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Sampling rate (Hz) declared to the ASR processor; decoded audio is resampled
# to this rate before feature extraction.
ASR_SAMPLE_RATE = 16000

# --- Load environment variables and initialize DeepL ---
load_dotenv()
DEEPL_AUTH_KEY = os.getenv("DEEPL_AUTH_KEY")
deepl_translator = None
if DEEPL_AUTH_KEY:
    try:
        deepl_translator = deepl.Translator(DEEPL_AUTH_KEY)
        logging.info("DeepL translator initialized successfully.")
    except Exception as e:
        logging.error("Error initializing DeepL translator: %s", e)
else:
    logging.warning("DEEPL_AUTH_KEY not found. DeepL will be unavailable.")

# --- Load Models ---
logging.info("Loading all models...")

# ASR Model
asr_model_id = "openai/whisper-base"
asr_model = None
asr_processor = None
try:
    asr_model = ORTModelForSpeechSeq2Seq.from_pretrained(asr_model_id, provider="CPUExecutionProvider")
    asr_processor = AutoProcessor.from_pretrained(asr_model_id)

    # The model's default config carries a 'forced_decoder_ids' value that
    # clashes with recent library versions: the attribute has to exist, but it
    # has to be None to avoid a conflict when `generate` is given a task.
    if hasattr(asr_model.config, 'forced_decoder_ids'):
        logging.info("Found conflicting 'forced_decoder_ids' in model config. Setting to None.")
        asr_model.config.forced_decoder_ids = None
    if hasattr(asr_model.generation_config, 'forced_decoder_ids'):
        logging.info("Found conflicting 'forced_decoder_ids' in generation_config. Setting to None.")
        asr_model.generation_config.forced_decoder_ids = None

    logging.info("ASR model and processor loaded and configured successfully.")
except Exception as e:
    logging.error("Fatal error loading ASR model: %s", e, exc_info=True)

# Translation Pipelines
TRANSLATION_MODEL_IDS = {
    "en-zh": "Helsinki-NLP/opus-mt-en-zh",
    "zh-en": "Helsinki-NLP/opus-mt-zh-en",
    "en-ja": "staka/fugumt-en-ja",
    "ja-en": "Helsinki-NLP/opus-mt-ja-en",
    "en-ko": "Helsinki-NLP/opus-mt-tc-big-en-ko",
    "ko-en": "Helsinki-NLP/opus-mt-ko-en",
}

# Each pipeline is loaded on its own so that one failing model does not leave
# the service without any translation model at all.
translators = {}
for _pair, _model_id in TRANSLATION_MODEL_IDS.items():
    try:
        translators[_pair] = pipeline("translation", model=_model_id)
    except Exception as e:
        logging.error("Failed to load translation models: %s (pair %s, model %s)", e, _pair, _model_id)

if len(translators) == len(TRANSLATION_MODEL_IDS):
    logging.info("Translation models loaded successfully.")
else:
    logging.warning(
        "Loaded %d of %d translation models.", len(translators), len(TRANSLATION_MODEL_IDS)
    )


# --- Core Logic Functions ---
def _resample_audio(audio, source_rate, target_rate):
    """Return 1-D ``audio`` resampled from ``source_rate`` to ``target_rate`` Hz.

    Uses SciPy's polyphase resampler when SciPy is installed and falls back to
    plain linear interpolation otherwise. The input is returned unchanged when
    the rates already match or the signal is empty.
    """
    if source_rate == target_rate or audio.size == 0:
        return audio

    try:
        from scipy.signal import resample_poly
    except ImportError:
        resample_poly = None

    if resample_poly is not None:
        divisor = math.gcd(int(source_rate), int(target_rate))
        return resample_poly(audio, int(target_rate) // divisor, int(source_rate) // divisor)

    # Fallback: sample the signal at the new time grid by linear interpolation.
    sample_count = audio.shape[0]
    target_count = max(1, int(round(sample_count * target_rate / source_rate)))
    source_times = np.arange(sample_count) / source_rate
    target_times = np.arange(target_count) / target_rate
    return np.interp(target_times, source_times, audio)


def transcribe_audio(audio_bytes):
    """Transcribe an encoded audio file with the ASR model.

    Args:
        audio_bytes: Raw bytes of an audio file readable by ``soundfile``.

    Returns:
        ``(text, None)`` on success, or ``(None, error_message)`` when the
        models are unavailable or any step of the transcription fails.
    """
    if asr_model is None or asr_processor is None:
        return None, "ASR model or processor is not available."

    audio_path = None
    try:
        # soundfile is given a path, so the payload is spooled to a temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            audio_path = tmp_file.name
            tmp_file.write(audio_bytes)

        audio_input, sample_rate = sf.read(audio_path)
        if audio_input.ndim > 1:
            # Mix multi-channel audio down to mono.
            audio_input = audio_input.mean(axis=1)

        # The processor is told the audio is ASR_SAMPLE_RATE Hz, so the samples
        # must really be at that rate.
        audio_input = _resample_audio(audio_input, sample_rate, ASR_SAMPLE_RATE)

        input_features = asr_processor(
            audio_input, sampling_rate=ASR_SAMPLE_RATE, return_tensors="pt"
        ).input_features

        # With forced_decoder_ids cleared in the configs, `generate` can be
        # given the task directly without conflicts.
        predicted_ids = asr_model.generate(input_features, task="transcribe")
        text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        logging.info("ASR transcribed text: '%s'", text)
        return text, None
    except Exception as e:
        logging.error("ASR transcription failed: %s", e, exc_info=True)
        return None, str(e)
    finally:
        # Always remove the temp file, on success and on failure alike.
        if audio_path is not None and os.path.exists(audio_path):
            os.remove(audio_path)


def translate_text(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Strategies are tried in order, each falling through to the next on failure:

    1. DeepL, for zh->ja and en->ja, when a DeepL translator is configured.
    2. A direct Hugging Face pipeline for the ``"<source>-<target>"`` pair.
    3. A pivot through English using two Hugging Face pipelines.

    Returns:
        ``(translated_text, None)`` on success, or ``(None, error_message)``
        when no strategy produced a translation.
    """
    # Priority 1: Use DeepL for specific, high-quality pairs if available
    if deepl_translator and target_lang == 'ja' and source_lang in ('zh', 'en'):
        try:
            dl_source_lang = "ZH" if source_lang == 'zh' else "EN"
            logging.info("Attempting DeepL translation for %s -> %s", source_lang, target_lang)
            result = deepl_translator.translate_text(text, source_lang=dl_source_lang, target_lang="JA")
            return result.text, None
        except Exception as e:
            logging.error("DeepL failed: %s. Falling back to HF models.", e)

    # Priority 2: Try direct HF translation
    model_key = f"{source_lang}-{target_lang}"
    translator = translators.get(model_key)
    if translator:
        try:
            logging.info("Attempting direct HF translation for %s", model_key)
            translated_text = translator(text, max_length=512)[0]['translation_text']
            return translated_text, None
        except Exception as e:
            # Don't return here, allow fallback to pivot
            logging.error("Direct HF translation for %s failed: %s", model_key, e, exc_info=True)

    # Priority 3: Try pivot translation via English
    if source_lang != 'en' and target_lang != 'en':
        translator_to_en = translators.get(f"{source_lang}-en")
        translator_from_en = translators.get(f"en-{target_lang}")
        if translator_to_en and translator_from_en:
            try:
                logging.info("Attempting pivot translation for %s -> en -> %s", source_lang, target_lang)
                # Step 1: Source to English
                english_text = translator_to_en(text, max_length=512)[0]['translation_text']
                logging.info("Pivot step (to en) result: '%s'", english_text)
                # Step 2: English to Target
                final_text = translator_from_en(english_text, max_length=512)[0]['translation_text']
                logging.info("Pivot step (from en) result: '%s'", final_text)
                return final_text, None
            except Exception as e:
                logging.error("Pivot translation failed: %s", e, exc_info=True)

    # If all else fails
    logging.warning("No translation path found for %s -> %s", source_lang, target_lang)
    return None, f"No model available for {source_lang} to {target_lang}"


# --- FastAPI App ---
app = FastAPI()


async def _read_json_object(request):
    """Parse the request body as a JSON object.

    Returns:
        ``(body, None)`` when the body is a JSON object, otherwise
        ``(None, response)`` where ``response`` is a 400 ``JSONResponse``.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return None, JSONResponse(status_code=400, content={"error": "Request body must be valid JSON"})
    if not isinstance(body, dict):
        return None, JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})
    return body, None


@app.get("/")
def root():
    """Health-check endpoint."""
    return {"status": "ok", "message": "Translator API is running."}


@app.post("/api/asr")
async def api_asr(request: Request):
    """Transcribe the base64 audio in the ``audio_base64`` field of the JSON body.

    Responds with ``{"text": ...}`` on success, 400 for a malformed request and
    500 when transcription fails.
    """
    try:
        body, bad_request = await _read_json_object(request)
        if bad_request is not None:
            return bad_request

        audio_b64 = body.get('audio_base64')
        if not audio_b64:
            logging.error("Request is missing 'audio_base64'")
            return JSONResponse(status_code=400, content={"error": "No audio_base64 found in request"})

        try:
            audio_bytes = base64.b64decode(audio_b64)
        except (ValueError, TypeError) as e:
            # binascii.Error is a ValueError; non-string payloads raise TypeError.
            logging.error("Request has invalid 'audio_base64': %s", e)
            return JSONResponse(status_code=400, content={"error": "audio_base64 is not valid base64"})

        # NOTE: inference runs synchronously inside this coroutine.
        text, error = transcribe_audio(audio_bytes)
        if error is not None:
            logging.error("ASR transcription function returned an error: %s", error)
            return JSONResponse(status_code=500, content={"error": f"ASR Error: {error}"})

        response_data = {"text": text}
        logging.info("Returning ASR response: %s", response_data)
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error("Critical error in /api/asr endpoint: %s", e, exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})


@app.post("/api/translate")
async def api_translate(request: Request):
    """Translate the ``text`` field from ``source_lang`` to ``target_lang``.

    Responds with ``{"translated_text": ...}`` on success, 400 for a malformed
    request and 500 when translation fails.
    """
    try:
        body, bad_request = await _read_json_object(request)
        if bad_request is not None:
            return bad_request

        text = body.get('text')
        source_lang = body.get('source_lang')
        target_lang = body.get('target_lang')
        if not all([text, source_lang, target_lang]):
            return JSONResponse(
                status_code=400,
                content={"error": "Missing parameters: text, source_lang, or target_lang"},
            )

        # NOTE: inference runs synchronously inside this coroutine.
        translated_text, error = translate_text(text, source_lang, target_lang)
        if error is not None:
            return JSONResponse(status_code=500, content={"error": f"Translation Error: {error}"})

        response_data = {"translated_text": translated_text}
        logging.info("Returning translation response: %s", response_data)
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error("Error in /api/translate endpoint: %s", e, exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})


# --- Main Execution ---
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)