"""FastAPI service that synthesizes Ukrainian speech with the "lada" TTS model.

Endpoints:
    POST /token           -- exchange API key (``username`` field) + password
                             for a bearer token.
    POST /create_audio    -- synthesize speech for the given text (auth required)
                             and return a download URL.
    GET  /download_audio  -- download a WAV file produced by /create_audio.
"""
import argparse
import contextlib
import ctypes
import gc
import os
import sys
import tempfile
import threading
import time
from typing import Optional
from urllib.parse import urlencode

# huggingface_hub resolves HF_HOME into module-level constants when it is first
# imported, so the cache directory must be configured before any third-party
# import below; assigning it afterwards has no effect on hf_hub_download.
os.environ["HF_HOME"] = "/app/.cache"

import uvicorn
from accentor import accentification, stress_replace_and_shift
from fastapi import Depends, FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from huggingface_hub import hf_hub_download
from passlib.context import CryptContext
from pydantic import BaseModel
from pydub import AudioSegment
from torch import no_grad, package

app = FastAPI()

tts_kwargs = {
    "speaker_name": "uk",
    "language_name": "uk",
}

# Public base URL used to build download links; overridable for other deployments.
AUDIO_BASE_URL = os.getenv(
    "XCHE_AUDIO_BASE_URL", "https://pro100sata-xche-audio.hf.space"
)

# Seconds a generated WAV file is kept on disk before it is deleted.
AUDIO_TTL_SECONDS = 300

# Private directory for generated audio. /download_audio only serves files that
# live directly inside it, so arbitrary paths on the server cannot be read.
AUDIO_DIR = os.path.realpath(tempfile.mkdtemp(prefix="xche_audio_"))

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


class Auth(BaseModel):
    """Credential record stored in the in-memory credential table."""

    api_key: str
    password: str  # bcrypt hash produced by get_password_hash, not plain text


# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def get_password_hash(password):
    """Return the bcrypt hash of *password*."""
    return pwd_context.hash(password)


def verify_password(plain_password, hashed_password):
    """Return True if *plain_password* matches *hashed_password*."""
    return pwd_context.verify(plain_password, hashed_password)


api_key = os.getenv("XCHE_API_KEY")
password = os.getenv("XCHE_PASSWORD")
if api_key is None or password is None:
    # Fail fast with a clear message instead of hashing None inside passlib
    # or keying the credential table on None.
    raise RuntimeError(
        "XCHE_API_KEY and XCHE_PASSWORD environment variables must be set"
    )

# Single-entry credential table keyed by API key; the password is hashed once
# at startup.
fake_data_db = {
    api_key: {
        "api_key": api_key,
        "password": get_password_hash(password),
    }
}


def get_api_key(db, api_key: str) -> Optional[Auth]:
    """Return the Auth record for *api_key* from *db*, or None if unknown."""
    if api_key in db:
        return Auth(**db[api_key])
    return None


def authenticate(fake_db, api_key: str, password: str):
    """Return the Auth record if *api_key*/*password* are valid, else False."""
    user = get_api_key(fake_db, api_key)
    if not user:
        return False
    if not verify_password(password, user.password):
        return False
    return user


@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange API key and password for a bearer token.

    ``OAuth2PasswordRequestForm`` exposes the submitted identifier as
    ``username`` (it has no ``api_key`` attribute), so clients send the API key
    in the standard ``username`` form field.

    Raises:
        HTTPException: 400 if the API key or password is wrong.
    """
    user = authenticate(fake_data_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=400,
            detail="Incorrect API KEY or Password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # NOTE(review): the issued token is the API key itself, so anyone holding
    # the key can call protected endpoints without knowing the password.
    return {"access_token": user.api_key, "token_type": "bearer"}


def check_api_token(token: str = Depends(oauth2_scheme)):
    """Dependency: resolve the bearer token to an Auth record.

    Raises:
        HTTPException: 403 if the token is not a known API key.
    """
    user = get_api_key(fake_data_db, token)
    if not user:
        raise HTTPException(status_code=403, detail="Invalid or missing API KEY")
    return user


def trim_memory():
    """Ask glibc (libc.so.6) to release free heap pages, then run the GC."""
    libc = ctypes.CDLL("libc.so.6")
    libc.malloc_trim(0)
    gc.collect()


def init_models():
    """Download the lada model package from the Hugging Face Hub and load it.

    Returns:
        dict mapping the model name ``"lada"`` to the loaded synthesizer.
    """
    models = {}
    model_path = hf_hub_download("theodotus/tts-vits-lada-uk", "model.pt")
    # NOTE: torch.package unpickles code from the downloaded archive; only load
    # repositories you trust.
    importer = package.PackageImporter(model_path)
    models["lada"] = importer.load_pickle("tts_models", "model")
    return models


def delete_file_after_delay(file_path: str, delay: int):
    """Sleep *delay* seconds, then delete *file_path* if it still exists."""
    time.sleep(delay)
    # EAFP: avoids the exists()/remove() race if the file vanished meanwhile.
    with contextlib.suppress(FileNotFoundError):
        os.remove(file_path)


def _schedule_deletion(file_path: str, delay: int) -> None:
    """Delete *file_path* after *delay* seconds on a background daemon thread."""
    # daemon=True so pending deletions do not hold up interpreter shutdown for
    # up to `delay` seconds.
    threading.Thread(
        target=delete_file_after_delay, args=(file_path, delay), daemon=True
    ).start()


# Serializes use of the shared model now that synthesis runs in worker threads;
# the synthesizer is not known to be safe for concurrent calls.
_synth_lock = threading.Lock()


def _synthesize_to_file(text: str) -> str:
    """Run accentuation + TTS on *text* and write a WAV into AUDIO_DIR.

    Returns:
        Absolute path of the written WAV file.
    """
    accented_text = accentification(text)
    plussed_text = stress_replace_and_shift(accented_text)
    synt = models["lada"]
    with _synth_lock, no_grad():
        wav_data = synt.tts(plussed_text, **tts_kwargs)
        # The file is created only after synthesis succeeded; if writing fails
        # it is removed so no orphan files accumulate.
        wav_fp = tempfile.NamedTemporaryFile(
            suffix=".wav", dir=AUDIO_DIR, delete=False
        )
        try:
            with wav_fp:
                synt.save_wav(wav_data, wav_fp)
        except BaseException:
            with contextlib.suppress(FileNotFoundError):
                os.remove(wav_fp.name)
            raise
    return wav_fp.name


@app.post("/create_audio")
async def tts(request: str, api_key: Auth = Depends(check_api_token)):
    """Synthesize *request* to speech and return a URL to download the WAV.

    The CPU-bound work runs in FastAPI's threadpool so it does not block the
    event loop (and thus every other request) while a clip is generated.
    ``api_key`` is unused in the body; the dependency enforces authentication.
    The file is deleted after AUDIO_TTL_SECONDS.
    """
    audio_path = await run_in_threadpool(_synthesize_to_file, request)
    _schedule_deletion(audio_path, AUDIO_TTL_SECONDS)
    query = urlencode({"audio_path": audio_path})
    return JSONResponse(
        content={"audio_url": f"{AUDIO_BASE_URL}/download_audio?{query}"}
    )


@app.get("/download_audio")
async def download_audio(audio_path: str):
    """Serve a WAV file previously generated by /create_audio.

    Only regular files located directly in AUDIO_DIR are served. Any other
    path (including ``..`` traversal or symlinks out of the directory) and
    expired/deleted files yield 404.

    Raises:
        HTTPException: 404 if the path is not a live generated audio file.
    """
    resolved = os.path.realpath(audio_path)
    if os.path.dirname(resolved) != AUDIO_DIR or not os.path.isfile(resolved):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(resolved, media_type='audio/wav')


models = init_models()


class ArgParser(argparse.ArgumentParser):
    """Command-line options for the server; parses sys.argv on construction."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=7860,
            help="Server Port for HF LLM Chat API",
        )
        self.args = self.parse_args(sys.argv[1:])


if __name__ == "__main__":
    args = ArgParser().args
    uvicorn.run(app, host=args.server, port=args.port, reload=False)