import os

# Set HF_HOME *before* importing transformers/optimum: the Hub cache location is
# resolved at import time, so setting it afterwards would not fix the PermissionError.
os.environ['HF_HOME'] = '/app/.cache'

import numpy as np
from transformers import AutoProcessor, pipeline
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
import torch
import base64
import tempfile
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn
import deepl
from dotenv import load_dotenv
import soundfile as sf
import logging
# --- Basic Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Load environment variables and initialize DeepL ---
load_dotenv()
DEEPL_AUTH_KEY = os.getenv("DEEPL_AUTH_KEY")
deepl_translator = None
if DEEPL_AUTH_KEY:
try:
deepl_translator = deepl.Translator(DEEPL_AUTH_KEY)
logging.info("DeepL translator initialized successfully.")
except Exception as e:
logging.error(f"Error initializing DeepL translator: {e}")
else:
logging.warning("DEEPL_AUTH_KEY not found. DeepL will be unavailable.")
# --- Load Models ---
logging.info("Loading all models...")
# ASR Model
asr_model_id = "openai/whisper-base"
asr_model = None
asr_processor = None
try:
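    # Load Whisper as an ONNX Runtime seq2seq model on CPU. Note: if the installed
    # optimum version does not auto-export ONNX weights for this checkpoint,
    # from_pretrained also accepts export=True to convert the PyTorch weights on the fly.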
asr_model = ORTModelForSpeechSeq2Seq.from_pretrained(asr_model_id, provider="CPUExecutionProvider")
asr_processor = AutoProcessor.from_pretrained(asr_model_id)
    # Workaround: the model's bundled config ships a 'forced_decoder_ids' value
    # that clashes with recent transformers releases, which expect the attribute
    # to exist but be None so that generate() can manage the decoder prompt itself.
if hasattr(asr_model.config, 'forced_decoder_ids'):
logging.info("Found conflicting 'forced_decoder_ids' in model config. Setting to None.")
asr_model.config.forced_decoder_ids = None
if hasattr(asr_model.generation_config, 'forced_decoder_ids'):
logging.info("Found conflicting 'forced_decoder_ids' in generation_config. Setting to None.")
asr_model.generation_config.forced_decoder_ids = None
logging.info("ASR model and processor loaded and configured successfully.")
except Exception as e:
logging.error(f"Fatal error loading ASR model: {e}", exc_info=True)
# Translation Pipelines (the `pipeline` factory is imported at the top of the file)
translators = {}
try:
translators = {
"en-zh": pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh"),
"zh-en": pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en"),
"en-ja": pipeline("translation", model="staka/fugumt-en-ja"),
"ja-en": pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en"),
"en-ko": pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-en-ko"),
"ko-en": pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en"),
}
logging.info("Translation models loaded successfully.")
except Exception as e:
logging.error(f"Failed to load translation models: {e}")
# --- Core Logic Functions ---
def transcribe_audio(audio_bytes):
if not asr_model or not asr_processor:
return None, "ASR model or processor is not available."
try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_file.write(audio_bytes)
            audio_path = tmp_file.name
        audio_input, sample_rate = sf.read(audio_path)
        if audio_input.ndim > 1:
            audio_input = audio_input.mean(axis=1)  # downmix multi-channel audio to mono
        if sample_rate != 16000:
            # Whisper expects 16 kHz input; resample by linear interpolation rather
            # than passing a mismatched sampling_rate to the processor.
            target_len = int(len(audio_input) * 16000 / sample_rate)
            audio_input = np.interp(np.linspace(0, len(audio_input), target_len, endpoint=False),
                                    np.arange(len(audio_input)), audio_input)
        input_features = asr_processor(audio_input, sampling_rate=16000, return_tensors="pt").input_features
# By setting forced_decoder_ids to None in the config, we can now safely
# let the generate function handle the task without conflicts.
predicted_ids = asr_model.generate(input_features, task="transcribe")
text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
logging.info(f"ASR transcribed text: '{text}'")
os.remove(audio_path)
return text, None
except Exception as e:
logging.error(f"ASR transcription failed: {e}", exc_info=True)
if 'audio_path' in locals() and os.path.exists(audio_path):
os.remove(audio_path)
return None, str(e)
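
# Illustrative use of transcribe_audio outside the HTTP layer (a sketch;
# "sample.wav" is a placeholder path, not a file shipped with this app):
#
#   with open("sample.wav", "rb") as f:
#       text, err = transcribe_audio(f.read())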
def translate_text(text, source_lang, target_lang):
# Priority 1: Use DeepL for specific, high-quality pairs if available
if deepl_translator and ((source_lang == 'zh' and target_lang == 'ja') or (source_lang == 'en' and target_lang == 'ja')):
try:
dl_source_lang = "ZH" if source_lang == 'zh' else "EN"
logging.info(f"Attempting DeepL translation for {source_lang} -> {target_lang}")
result = deepl_translator.translate_text(text, source_lang=dl_source_lang, target_lang="JA")
return result.text, None
except Exception as e:
logging.error(f"DeepL failed: {e}. Falling back to HF models.")
# Priority 2: Try direct HF translation
model_key = f"{source_lang}-{target_lang}"
translator = translators.get(model_key)
if translator:
try:
logging.info(f"Attempting direct HF translation for {model_key}")
translated_text = translator(text, max_length=512)[0]['translation_text']
return translated_text, None
except Exception as e:
logging.error(f"Direct HF translation for {model_key} failed: {e}", exc_info=True)
# Don't return here, allow fallback to pivot
# Priority 3: Try pivot translation via English
if source_lang != 'en' and target_lang != 'en':
to_en_key = f"{source_lang}-en"
from_en_key = f"en-{target_lang}"
translator_to_en = translators.get(to_en_key)
translator_from_en = translators.get(from_en_key)
if translator_to_en and translator_from_en:
try:
logging.info(f"Attempting pivot translation for {source_lang} -> en -> {target_lang}")
# Step 1: Source to English
english_text = translator_to_en(text, max_length=512)[0]['translation_text']
logging.info(f"Pivot step (to en) result: '{english_text}'")
# Step 2: English to Target
final_text = translator_from_en(english_text, max_length=512)[0]['translation_text']
logging.info(f"Pivot step (from en) result: '{final_text}'")
return final_text, None
except Exception as e:
logging.error(f"Pivot translation failed: {e}", exc_info=True)
# If all else fails
logging.warning(f"No translation path found for {source_lang} -> {target_lang}")
return None, f"No model available for {source_lang} to {target_lang}"
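
# Illustrative fallback chain (a sketch): a zh -> ko request has no direct pipeline
# in `translators`, so it pivots through English via "zh-en" followed by "en-ko":
#
#   text, err = translate_text("你好", "zh", "ko")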
# --- FastAPI App ---
app = FastAPI()
@app.get("/")
def root():
return {"status": "ok", "message": "Translator API is running."}
@app.post("/api/asr")
async def api_asr(request: Request):
try:
body = await request.json()
audio_b64 = body.get('audio_base64')
if not audio_b64:
logging.error("Request is missing 'audio_base64'")
return JSONResponse(status_code=400, content={"error": "No audio_base64 found in request"})
audio_bytes = base64.b64decode(audio_b64)
text, error = transcribe_audio(audio_bytes)
if error:
logging.error(f"ASR transcription function returned an error: {error}")
return JSONResponse(status_code=500, content={"error": f"ASR Error: {error}"})
response_data = {"text": text}
logging.info(f"Returning ASR response: {response_data}")
return JSONResponse(content=response_data)
except Exception as e:
logging.error(f"Critical error in /api/asr endpoint: {e}", exc_info=True)
return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/api/translate")
async def api_translate(request: Request):
try:
body = await request.json()
text = body.get('text')
source_lang = body.get('source_lang')
target_lang = body.get('target_lang')
if not all([text, source_lang, target_lang]):
return JSONResponse(status_code=400, content={"error": "Missing parameters: text, source_lang, or target_lang"})
translated_text, error = translate_text(text, source_lang, target_lang)
if error:
return JSONResponse(status_code=500, content={"error": f"Translation Error: {error}"})
response_data = {"translated_text": translated_text}
logging.info(f"Returning translation response: {response_data}")
return JSONResponse(content=response_data)
except Exception as e:
logging.error(f"Error in /api/translate endpoint: {e}", exc_info=True)
return JSONResponse(status_code=500, content={"error": str(e)})
# --- Main Execution ---
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)
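
# Example requests (illustrative; assumes the server is reachable on localhost:7860,
# "sample.wav" is a placeholder file, and GNU coreutils base64 is available):
#
#   curl -X POST http://localhost:7860/api/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hello", "source_lang": "en", "target_lang": "ja"}'
#
#   curl -X POST http://localhost:7860/api/asr \
#        -H "Content-Type: application/json" \
#        -d "{\"audio_base64\": \"$(base64 -w0 sample.wav)\"}"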