import base64
import logging
import os
import tempfile

import deepl
import soundfile as sf
import torch  # needed for the "pt" tensors produced by the ASR processor
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from transformers import AutoProcessor, pipeline
# --- Basic Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Load environment variables and initialize DeepL ---
load_dotenv()
DEEPL_AUTH_KEY = os.getenv("DEEPL_AUTH_KEY")
deepl_translator = None
if DEEPL_AUTH_KEY:
    try:
        deepl_translator = deepl.Translator(DEEPL_AUTH_KEY)
        logging.info("DeepL translator initialized successfully.")
    except Exception as e:
        logging.error(f"Error initializing DeepL translator: {e}")
else:
    logging.warning("DEEPL_AUTH_KEY not found. DeepL will be unavailable.")
# --- Load Models ---
logging.info("Loading all models...")
# ASR Model
asr_model_id = "openai/whisper-base"
asr_model = None
asr_processor = None
try:
    asr_model = ORTModelForSpeechSeq2Seq.from_pretrained(asr_model_id, provider="CPUExecutionProvider")
    asr_processor = AutoProcessor.from_pretrained(asr_model_id)
    # The model's default config ships with 'forced_decoder_ids', which clashes
    # with recent transformers versions: the attribute may exist, but it must be
    # None so that generate() can set the task tokens itself.
    if hasattr(asr_model.config, 'forced_decoder_ids'):
        logging.info("Found conflicting 'forced_decoder_ids' in model config. Setting to None.")
        asr_model.config.forced_decoder_ids = None
    if hasattr(asr_model.generation_config, 'forced_decoder_ids'):
        logging.info("Found conflicting 'forced_decoder_ids' in generation_config. Setting to None.")
        asr_model.generation_config.forced_decoder_ids = None
    logging.info("ASR model and processor loaded and configured successfully.")
except Exception as e:
    logging.error(f"Fatal error loading ASR model: {e}", exc_info=True)
# Translation Pipelines
translators = {}
try:
    translators = {
        "en-zh": pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh"),
        "zh-en": pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en"),
        "en-ja": pipeline("translation", model="staka/fugumt-en-ja"),
        "ja-en": pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en"),
        "en-ko": pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-en-ko"),
        "ko-en": pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en"),
    }
    logging.info("Translation models loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load translation models: {e}")
# --- Core Logic Functions ---
def transcribe_audio(audio_bytes):
    if not asr_model or not asr_processor:
        return None, "ASR model or processor is not available."
    audio_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_file.write(audio_bytes)
            audio_path = tmp_file.name
        audio_input, sample_rate = sf.read(audio_path)
        # Downmix multi-channel audio to mono.
        if audio_input.ndim > 1:
            audio_input = audio_input.mean(axis=1)
        # Pass the file's actual sample rate; Whisper's feature extractor raises
        # a clear error on anything other than 16 kHz, instead of silently
        # producing garbage from mislabeled audio.
        input_features = asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_features
        # With forced_decoder_ids cleared in the config, generate() can handle
        # the task argument without conflicts.
        predicted_ids = asr_model.generate(input_features, task="transcribe")
        text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        logging.info(f"ASR transcribed text: '{text}'")
        return text, None
    except Exception as e:
        logging.error(f"ASR transcription failed: {e}", exc_info=True)
        return None, str(e)
    finally:
        if audio_path and os.path.exists(audio_path):
            os.remove(audio_path)
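
# Optional helper, not wired into transcribe_audio above: a minimal sketch for
# deployments whose clients may upload audio at rates other than the 16 kHz
# Whisper's feature extractor accepts. Assumption: naive linear interpolation
# is good enough for speech; a dedicated resampler (librosa, torchaudio) is
# preferable in production.
import numpy as np

def resample_to_16k(audio, rate, target_rate=16000):
    """Linearly resample a mono waveform to target_rate (illustrative only)."""
    if rate == target_rate:
        return audio
    target_len = int(round(len(audio) * target_rate / rate))
    old_idx = np.arange(len(audio))
    new_idx = np.linspace(0, len(audio) - 1, target_len)
    return np.interp(new_idx, old_idx, audio)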
def translate_text(text, source_lang, target_lang):
    # Priority 1: Use DeepL for specific, high-quality pairs if available
    if deepl_translator and ((source_lang == 'zh' and target_lang == 'ja') or (source_lang == 'en' and target_lang == 'ja')):
        try:
            dl_source_lang = "ZH" if source_lang == 'zh' else "EN"
            logging.info(f"Attempting DeepL translation for {source_lang} -> {target_lang}")
            result = deepl_translator.translate_text(text, source_lang=dl_source_lang, target_lang="JA")
            return result.text, None
        except Exception as e:
            logging.error(f"DeepL failed: {e}. Falling back to HF models.")
    # Priority 2: Try direct HF translation
    model_key = f"{source_lang}-{target_lang}"
    translator = translators.get(model_key)
    if translator:
        try:
            logging.info(f"Attempting direct HF translation for {model_key}")
            translated_text = translator(text, max_length=512)[0]['translation_text']
            return translated_text, None
        except Exception as e:
            logging.error(f"Direct HF translation for {model_key} failed: {e}", exc_info=True)
            # Don't return here; allow fallback to pivot translation.
    # Priority 3: Try pivot translation via English
    if source_lang != 'en' and target_lang != 'en':
        to_en_key = f"{source_lang}-en"
        from_en_key = f"en-{target_lang}"
        translator_to_en = translators.get(to_en_key)
        translator_from_en = translators.get(from_en_key)
        if translator_to_en and translator_from_en:
            try:
                logging.info(f"Attempting pivot translation for {source_lang} -> en -> {target_lang}")
                # Step 1: Source to English
                english_text = translator_to_en(text, max_length=512)[0]['translation_text']
                logging.info(f"Pivot step (to en) result: '{english_text}'")
                # Step 2: English to Target
                final_text = translator_from_en(english_text, max_length=512)[0]['translation_text']
                logging.info(f"Pivot step (from en) result: '{final_text}'")
                return final_text, None
            except Exception as e:
                logging.error(f"Pivot translation failed: {e}", exc_info=True)
    # If all else fails
    logging.warning(f"No translation path found for {source_lang} -> {target_lang}")
    return None, f"No model available for {source_lang} to {target_lang}"
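
# Example of the fallback order above: zh -> ko has no direct entry in
# `translators`, so a request pivots zh -> en (opus-mt-zh-en) and then
# en -> ko (opus-mt-tc-big-en-ko).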
# --- FastAPI App ---
app = FastAPI()
@app.get("/")
def root():
    return {"status": "ok", "message": "Translator API is running."}
@app.post("/api/asr")
async def api_asr(request: Request):
    try:
        body = await request.json()
        audio_b64 = body.get('audio_base64')
        if not audio_b64:
            logging.error("Request is missing 'audio_base64'")
            return JSONResponse(status_code=400, content={"error": "No audio_base64 found in request"})
        audio_bytes = base64.b64decode(audio_b64)
        text, error = transcribe_audio(audio_bytes)
        if error:
            logging.error(f"ASR transcription function returned an error: {error}")
            return JSONResponse(status_code=500, content={"error": f"ASR Error: {error}"})
        response_data = {"text": text}
        logging.info(f"Returning ASR response: {response_data}")
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error(f"Critical error in /api/asr endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/api/translate")
async def api_translate(request: Request):
    try:
        body = await request.json()
        text = body.get('text')
        source_lang = body.get('source_lang')
        target_lang = body.get('target_lang')
        if not all([text, source_lang, target_lang]):
            return JSONResponse(status_code=400, content={"error": "Missing parameters: text, source_lang, or target_lang"})
        translated_text, error = translate_text(text, source_lang, target_lang)
        if error:
            return JSONResponse(status_code=500, content={"error": f"Translation Error: {error}"})
        response_data = {"translated_text": translated_text}
        logging.info(f"Returning translation response: {response_data}")
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error(f"Error in /api/translate endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})
# --- Main Execution ---
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
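
# Illustrative client calls; assumes the server is reachable on localhost:7860
# and that a local file named sample.mp3 exists (both hypothetical, adjust to
# your deployment):
#
#   curl -X POST http://localhost:7860/api/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello", "source_lang": "en", "target_lang": "ja"}'
#
#   curl -X POST http://localhost:7860/api/asr \
#        -H "Content-Type: application/json" \
#        -d "{\"audio_base64\": \"$(base64 -w0 sample.mp3)\"}"
#
# (base64 -w0 is GNU coreutils; on macOS, plain `base64 < sample.mp3` works.)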