|
import errno
import os
import shutil
import traceback
from typing import List, Optional

import torch
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
|
|
|
|
|
# Absolute path of the directory that contains this module.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Templates and static assets are looked up in the parent of BASE_DIR,
# not inside BASE_DIR itself.
_PARENT_DIR = os.path.dirname(BASE_DIR)
TEMPLATE_DIR = os.path.join(_PARENT_DIR, "templates")
STATIC_DIR = os.path.join(_PARENT_DIR, "static")

# Fixed absolute directory used for temporary copies of uploaded documents.
UPLOAD_DIR = "/app/uploads"
|
|
|
# Application instance that every route in this module is registered on.
app = FastAPI()

# StaticFiles validates its directory when it is constructed and raises if the
# directory is missing. That would abort the import of this module before any
# later start-up code could create the directory, so make sure it exists first.
if not os.path.isdir(STATIC_DIR):
    os.makedirs(STATIC_DIR, exist_ok=True)

# Serve the contents of STATIC_DIR under the /static URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# Template renderer for the HTML pages stored in TEMPLATE_DIR.
templates = Jinja2Templates(directory=TEMPLATE_DIR)
|
|
|
|
|
|
|
|
|
# Hugging Face model identifier and the directory its files are cached in.
MODEL_NAME = "google/flan-t5-small"
CACHE_DIR = "/app/.cache"

# Both names stay None if loading fails, so the module still imports and the
# translation code can detect that no model is available.
tokenizer = None
model = None

try:
    print("--- Loading Model ---")

    print(f"Loading tokenizer for {MODEL_NAME} using AutoTokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

    print(f"Loading model for {MODEL_NAME} using AutoModelForSeq2SeqLM...")
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

    print("--- Model Loaded Successfully ---")
except Exception as load_error:
    # Deliberately non-fatal: report the failure and keep the module importable.
    print("--- ERROR Loading Model ---")
    print(f"Error loading model or tokenizer {MODEL_NAME}: {load_error}")
    traceback.print_exc()
|
|
|
|
|
|
|
# Language codes that are spelled out as English language names in the prompt.
_LANGUAGE_NAMES = {
    "en": "English",
    "fr": "French",
    "es": "Spanish",
    "de": "German",
    "zh": "Chinese",
    "ru": "Russian",
    "ja": "Japanese",
    "hi": "Hindi",
    "pt": "Portuguese",
    "tr": "Turkish",
    "ko": "Korean",
    "it": "Italian",
}


def translate_text_internal(text: str, source_lang: str, target_lang: str = "ar") -> str:
    """Translate *text* into Arabic by prompting the loaded seq2seq model.

    Args:
        text: Source text to translate.
        source_lang: Language code of *text*. Known codes are replaced by their
            English language name in the prompt; any other value is inserted
            into the prompt unchanged.
        target_lang: Only written to the log. The prompt always asks for
            Modern Standard Arabic, whatever this value is.

    Returns:
        The decoded first sequence produced by the model.

    Raises:
        HTTPException: 503 when the model or tokenizer is not loaded, 500 when
            tokenization, generation or decoding fails.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Translation service is unavailable (model not loaded).")

    source_lang_name = _LANGUAGE_NAMES.get(source_lang, source_lang)

    prompt = (
        f"Translate the following {source_lang_name} text into Modern Standard Arabic (Fusha).\n"
        "Focus on conveying the meaning elegantly using proper Balagha (Arabic eloquence).\n"
        "Adapt any cultural references or idioms appropriately rather than translating literally.\n"
        "Ensure the translation reads naturally to a native Arabic speaker.\n"
        "\n"
        "Text to translate:\n"
        f"{text}"
    )

    print(f"Translation Request - Source Lang: {source_lang} ({source_lang_name}), Target Lang: {target_lang}")
    print("Using Enhanced Prompt for Balagha and Cultural Sensitivity")

    try:
        # truncation=True cuts the prompt to the tokenizer's maximum input
        # length, so very long texts are only partially translated.
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

        # Deterministic beam search. top_k / top_p are intentionally not passed:
        # they only take effect when sampling is enabled (do_sample=True) and
        # would otherwise be ignored.
        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=5,
            length_penalty=1.0,
            early_stopping=True,
        )

        translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        print(f"Raw Translation Output: {translated_text}")
        return translated_text

    except Exception as e:
        print(f"Error during model generation: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Translation failed during generation: {e}") from e
|
|
|
|
|
async def extract_text_from_file(file: UploadFile) -> str:
    """Save an upload to a temporary file and return the text it contains.

    Supported extensions are .txt (read as UTF-8), .docx (via python-docx) and
    .pdf (via PyMuPDF). The temporary copy is removed again before this
    function returns or raises.

    Args:
        file: The uploaded file. Its content is obtained with ``file.read()``,
            so the upload stream must not have been consumed beforehand.

    Returns:
        The extracted text, which may be an empty string.

    Raises:
        HTTPException: 400 for an unsupported extension, 501 when the parser
            library for the format is not installed, 500 for I/O or parsing
            failures.
    """
    os.makedirs(UPLOAD_DIR, exist_ok=True)

    # Drop any directory components supplied by the client. A missing filename
    # becomes an empty name, which is then rejected as an unsupported type.
    safe_filename = os.path.basename(file.filename or "")
    temp_file_path = os.path.join(UPLOAD_DIR, f"temp_{safe_filename}")
    print(f"Attempting to save uploaded file to: {temp_file_path}")
    extracted_text = ""

    try:
        with open(temp_file_path, "wb") as buffer:
            content = await file.read()
            buffer.write(content)
        print(f"File saved successfully to: {temp_file_path}")

        file_extension = os.path.splitext(safe_filename)[1].lower()

        if file_extension == '.txt':
            with open(temp_file_path, 'r', encoding='utf-8') as f:
                extracted_text = f.read()

        elif file_extension == '.docx':
            try:
                import docx
            except ImportError as e:
                raise HTTPException(status_code=501, detail="DOCX processing requires 'python-docx' library, which is not installed.") from e
            try:
                document = docx.Document(temp_file_path)
                # Join paragraphs with a real newline character.
                extracted_text = '\n'.join(para.text for para in document.paragraphs)
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Error reading DOCX file: {e}") from e

        elif file_extension == '.pdf':
            try:
                import fitz  # provided by the PyMuPDF package
            except ImportError as e:
                raise HTTPException(status_code=501, detail="PDF processing requires 'PyMuPDF' library, which is not installed.") from e
            try:
                pdf = fitz.open(temp_file_path)
                try:
                    extracted_text = "".join(page.get_text() for page in pdf)
                finally:
                    # Release the document handle even if a page fails to parse.
                    pdf.close()
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"Error reading PDF file: {e}") from e

        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_extension}")

        print(f"Extracted text length: {len(extracted_text)}")
        return extracted_text

    except IOError as e:
        print(f"IOError saving/reading file {temp_file_path}: {e}")
        if e.errno == errno.EACCES:
            raise HTTPException(status_code=500, detail=f"Permission denied writing to {temp_file_path}. Check container permissions for {UPLOAD_DIR}.") from e
        raise HTTPException(status_code=500, detail=f"Error saving/accessing uploaded file: {e}") from e

    except HTTPException:
        # Already a well-formed HTTP error; pass it through untouched.
        raise

    except Exception as e:
        print(f"Error processing file {file.filename}: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred processing the document: {e}") from e

    finally:
        # Best-effort cleanup: failing to delete the temp file must not mask
        # the result or the error raised above.
        if os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
                print(f"Temporary file removed: {temp_file_path}")
            except OSError as e:
                print(f"Error removing temporary file {temp_file_path}: {e}")
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render index.html, answering with HTTP 500 if the template is missing."""
    # Checked in order: the template directory first, then the file inside it.
    required_paths = (
        (TEMPLATE_DIR, f"Template directory not found at {TEMPLATE_DIR}"),
        (os.path.join(TEMPLATE_DIR, "index.html"), f"index.html not found in {TEMPLATE_DIR}"),
    )
    for required_path, problem in required_paths:
        if not os.path.exists(required_path):
            raise HTTPException(status_code=500, detail=problem)

    return templates.TemplateResponse("index.html", {"request": request})
|
|
|
@app.post("/translate/text")
async def translate_text_endpoint(
    text: str = Form(...),
    source_lang: str = Form(...),
    target_lang: str = Form("ar"),
):
    """Translate text submitted as a form field.

    Responds with JSON holding ``translated_text`` and the ``source_lang``
    that was submitted. Empty text and any target language other than "ar"
    are rejected with HTTP 400; unexpected failures become HTTP 500.
    """
    # Guard clauses: validate the form input before touching the model.
    if not text:
        raise HTTPException(status_code=400, detail="No text provided for translation.")
    if target_lang != "ar":
        raise HTTPException(status_code=400, detail="Currently, only translation to Arabic (ar) is supported via this endpoint.")

    try:
        translation = translate_text_internal(text, source_lang, target_lang)
        payload = {"translated_text": translation, "source_lang": source_lang}
        return JSONResponse(content=payload)
    except HTTPException:
        # Errors raised by the translation helper already carry a status code.
        raise
    except Exception as exc:
        print(f"Unexpected error in /translate/text: {exc}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during text translation: {exc}")
|
|
|
|
|
@app.post("/translate/document")
async def translate_document_endpoint(
    file: UploadFile = File(...),
    source_lang: str = Form(...),
    target_lang: str = Form("ar")
):
    """Translate the text extracted from an uploaded document.

    The upload is handed to ``extract_text_from_file`` untouched. That helper
    reads the upload stream itself and manages its own temporary file, so this
    endpoint must not read or copy the stream first; doing so would leave the
    helper with an already exhausted stream and nothing to extract.

    Args:
        file: The uploaded document.
        source_lang: Language code of the document, passed on to the
            translation helper and echoed in the response.
        target_lang: Must be "ar"; any other value is rejected.

    Returns:
        A JSON response with ``original_filename``, ``detected_source_lang``
        (the submitted ``source_lang``) and ``translated_text``.

    Raises:
        HTTPException: 400 for an unsupported target language or a document
            without extractable text, 500 when the upload directory cannot be
            created or an unexpected error occurs. Errors raised by the
            extraction and translation helpers are passed through unchanged.
    """
    if target_lang != "ar":
        raise HTTPException(status_code=400, detail="Currently, only document translation to Arabic (ar) is supported.")

    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)
    except OSError as e:
        raise HTTPException(status_code=500, detail=f"Could not create upload directory: {e}") from e

    try:
        extracted_text = await extract_text_from_file(file)

        # Whitespace-only output (e.g. a document without a text layer) is
        # treated the same as no text at all.
        if not extracted_text or not extracted_text.strip():
            raise HTTPException(status_code=400, detail="Could not extract any text from the document.")

        translated_text = translate_text_internal(extracted_text, source_lang, target_lang)

        return JSONResponse(content={
            "original_filename": file.filename,
            "detected_source_lang": source_lang,
            "translated_text": translated_text
        })

    except HTTPException:
        # Already a well-formed HTTP error; pass it through untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred processing the document: {e}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: report the configured paths, create whatever is
    # missing on disk, then start uvicorn with auto-reload.
    import uvicorn

    print(f"Template Directory: {TEMPLATE_DIR}")
    print(f"Static Directory: {STATIC_DIR}")
    print(f"Upload Directory: {UPLOAD_DIR}")

    for required_dir in (TEMPLATE_DIR, STATIC_DIR, UPLOAD_DIR):
        if not os.path.exists(required_dir):
            os.makedirs(required_dir)

    # Write a minimal placeholder page so the root route has something to render.
    index_path = os.path.join(TEMPLATE_DIR, "index.html")
    if not os.path.exists(index_path):
        with open(index_path, "w") as placeholder:
            placeholder.write("<html><body><h1>Placeholder Frontend</h1></body></html>")

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
|