Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import time | |
| import threading | |
| import tempfile | |
| import ctypes | |
| import gc | |
| from typing import Optional | |
| from fastapi import FastAPI, Depends, HTTPException | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | |
| from pydantic import BaseModel | |
| from huggingface_hub import hf_hub_download | |
| from torch import no_grad, package | |
| from pydub import AudioSegment | |
| import uvicorn | |
| from accentor import accentification, stress_replace_and_shift | |
| import argparse | |
| from passlib.context import CryptContext | |
# Hugging Face cache directory for this deployment.
# NOTE(review): huggingface_hub is imported above this assignment; recent
# versions derive their default cache path from HF_HOME when the package is
# imported, so this line may not redirect hf_hub_download() unless HF_HOME is
# also set before the process starts (e.g. in the Dockerfile) — confirm.
os.environ["HF_HOME"] = "/app/.cache"

app = FastAPI()

# Extra keyword arguments passed to the synthesizer's tts() call.
tts_kwargs = dict(
    speaker_name="uk",
    language_name="uk",
)

# Bearer-token extractor; clients obtain the token from the "token" endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class Auth(BaseModel):
    """Credential record stored in ``fake_data_db`` and returned by ``get_api_key``."""

    # The client's API key (also the key of its entry in fake_data_db).
    api_key: str
    # The bcrypt hash produced by get_password_hash(), not the plain-text password.
    password: str
# Shared passlib context used by get_password_hash() / verify_password();
# bcrypt is the only enabled scheme.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
def get_password_hash(password):
    """Return the ``pwd_context`` (bcrypt) hash of *password*."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password, hashed_password):
    """Return whether *plain_password* matches *hashed_password*.

    *hashed_password* is expected to come from get_password_hash().
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
# Credentials for the single API client, read from the environment.
api_key = os.getenv("XCHE_API_KEY")
password = os.getenv("XCHE_PASSWORD")

# Fail fast on missing or empty credentials. Hashing a missing password would
# otherwise fail deep inside passlib with an opaque TypeError. An empty API key
# would become the dict key "", which an empty bearer token
# (a bare "Authorization: Bearer" header) could match.
_missing_credentials = [
    name
    for name, value in (("XCHE_API_KEY", api_key), ("XCHE_PASSWORD", password))
    if not value
]
if _missing_credentials:
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_credentials)
    )

# In-memory credential store keyed by API key; the password is stored only as
# a bcrypt hash.
fake_data_db = {
    api_key: {
        "api_key": api_key,
        "password": get_password_hash(password),
    }
}
def get_api_key(db, api_key: str):
    """Look up *api_key* in *db* and wrap its record in an ``Auth`` model.

    Returns:
        The ``Auth`` instance for a known key, or ``None`` when the key is not
        present in *db*.
    """
    if api_key not in db:
        return None
    return Auth(**db[api_key])
def authenticate(fake_db, api_key: str, password: str):
    """Validate an API key / password pair against *fake_db*.

    Returns:
        The matching ``Auth`` record on success, or ``False`` when the key is
        unknown or the password does not verify against the stored hash.
    """
    record = get_api_key(fake_db, api_key)
    # Short-circuit: the (slow) bcrypt check only runs for a known key.
    if not record or not verify_password(password, record.password):
        return False
    return record
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange API-key/password credentials for a bearer token.

    ``OAuth2PasswordRequestForm`` exposes the submitted form fields as
    ``username`` and ``password`` (it has no ``api_key`` attribute), so clients
    send their API key in the ``username`` field.

    Returns:
        ``{"access_token": <api key>, "token_type": "bearer"}``. The issued
        token is the API key itself.

    Raises:
        HTTPException: 400 when the API key is unknown or the password is wrong.
    """
    # NOTE(review): no route decorator is visible for this handler; the
    # OAuth2PasswordBearer(tokenUrl="token") setup suggests it is meant to be
    # mounted at POST /token — confirm.
    record = authenticate(fake_data_db, form_data.username, form_data.password)
    if not record:
        raise HTTPException(
            status_code=400,
            detail="Incorrect API KEY or Password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"access_token": record.api_key, "token_type": "bearer"}
def check_api_token(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency: resolve the bearer token to its ``Auth`` record.

    Raises:
        HTTPException: 403 when the token is not a known API key.
    """
    # NOTE(review): only membership of the token in fake_data_db is checked, and
    # the token equals the raw API key, so the password is never needed to call
    # protected endpoints — confirm this is intended.
    record = get_api_key(fake_data_db, token)
    if record:
        return record
    raise HTTPException(status_code=403, detail="Invalid or missing API KEY")
def trim_memory():
    """Run the garbage collector, then ask glibc to return free heap to the OS.

    Collection happens first so that memory released by the collector is
    included in the trim. The trim step is best-effort: on platforms without
    ``libc.so.6`` or without ``malloc_trim`` (e.g. musl, macOS) only the
    garbage collection is performed.
    """
    gc.collect()
    try:
        libc = ctypes.CDLL("libc.so.6")
        malloc_trim = libc.malloc_trim
    except (OSError, AttributeError):
        # Library or symbol unavailable: nothing more we can release.
        return
    malloc_trim(0)
def init_models():
    """Download and load the TTS model(s) served by this app.

    Returns:
        dict mapping a model name to the loaded synthesizer object; currently
        only ``"lada"`` (from the ``theodotus/tts-vits-lada-uk`` Hub repo).
    """
    # huggingface_hub computes its default cache directory from HF_HOME when
    # the package is imported, so an HF_HOME assigned later in the process is
    # not picked up automatically. Resolve the cache directory from the current
    # environment instead (HF_HUB_CACHE takes precedence, as in huggingface_hub).
    # None keeps the library default.
    cache_dir = os.environ.get("HF_HUB_CACHE")
    if not cache_dir and os.environ.get("HF_HOME"):
        cache_dir = os.path.join(os.environ["HF_HOME"], "hub")

    model_path = hf_hub_download(
        "theodotus/tts-vits-lada-uk", "model.pt", cache_dir=cache_dir
    )
    # SECURITY: torch.package unpickles objects (and code) from the downloaded
    # archive, so this fully trusts the contents of the model repository.
    importer = package.PackageImporter(model_path)
    return {"lada": importer.load_pickle("tts_models", "model")}
async def tts(request: str, api_key: Auth = Depends(check_api_token)):
    """Synthesize Ukrainian speech for *request* and return a download URL.

    The text is stress-marked (``accentification`` followed by
    ``stress_replace_and_shift``) and synthesized with ``models["lada"]``. The
    resulting WAV is written to a temporary file that a background thread
    deletes after 300 seconds.

    Args:
        request: Text to synthesize.
        api_key: Credentials injected by ``check_api_token``. They are unused
            in the body; the dependency exists to enforce authentication.

    Returns:
        ``JSONResponse`` with ``{"audio_url": ...}`` pointing at
        ``/download_audio``. The host defaults to the original deployment and
        can be overridden with the ``XCHE_PUBLIC_URL`` environment variable.

    Raises:
        HTTPException: 400 when *request* is empty or whitespace only.
    """
    if not request.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize must not be empty")

    # NOTE(review): stress marking and model inference run synchronously inside
    # this async handler and block the event loop while a request is processed.
    accented_text = accentification(request)
    plussed_text = stress_replace_and_shift(accented_text)
    synt = models["lada"]

    # Synthesize before creating the temp file: the file is created with
    # delete=False, so a failing synthesis would otherwise leave an orphan.
    with no_grad():
        wav_data = synt.tts(plussed_text, **tts_kwargs)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_fp:
        synt.save_wav(wav_data, wav_fp)

    # daemon=True: a non-daemon sleeper thread would keep the interpreter alive
    # for up to 300 s after the server is asked to shut down.
    threading.Thread(
        target=delete_file_after_delay,
        args=(wav_fp.name, 300),
        daemon=True,
    ).start()

    base_url = os.getenv(
        "XCHE_PUBLIC_URL", "https://pro100sata-xche-audio.hf.space"
    ).rstrip("/")
    return JSONResponse(
        content={"audio_url": f"{base_url}/download_audio?audio_path={wav_fp.name}"}
    )
async def download_audio(audio_path: str):
    """Serve a WAV file previously generated by the TTS endpoint.

    Only regular ``.wav`` files located directly in the system temporary
    directory are served. That directory is where
    ``tempfile.NamedTemporaryFile`` creates files when no ``dir`` is given.

    Raises:
        HTTPException: 404 when *audio_path* is outside the temp directory,
            is not a ``.wav`` file, or no longer exists (for example because
            the delayed cleanup already removed it).
    """
    # SECURITY: audio_path comes straight from the query string. Without this
    # check any file readable by the server could be downloaded, including
    # /proc/self/environ, which would expose the XCHE_* credentials.
    # realpath() resolves symlinks and "..", so a disguised path cannot escape.
    temp_root = os.path.realpath(tempfile.gettempdir())
    resolved = os.path.realpath(audio_path)
    if (
        os.path.dirname(resolved) != temp_root
        or not resolved.endswith(".wav")
        or not os.path.isfile(resolved)
    ):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(resolved, media_type='audio/wav')
# Load the TTS model(s) once, at import time; init_models() downloads from the
# Hugging Face Hub, so startup waits on the download. tts() reads models["lada"].
models = init_models()
def delete_file_after_delay(file_path: str, delay: int):
    """Sleep for *delay* seconds, then delete *file_path* if it still exists.

    Meant to run on a background thread. A file that has already been removed
    is silently ignored.
    """
    time.sleep(delay)
    # EAFP: a separate os.path.exists() check races with any other remover and
    # could still raise FileNotFoundError inside the worker thread.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the server's bind address and port.

    The command line (``sys.argv[1:]``) is parsed during construction and the
    resulting namespace is stored on ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s", "--server", type=str, default="0.0.0.0",
            help="Host/IP address the TTS API server binds to",
        )
        self.add_argument(
            "-p", "--port", type=int, default=7860,
            help="Port the TTS API server listens on",
        )
        self.args = self.parse_args(sys.argv[1:])
if __name__ == "__main__":
    # Read --server/--port from the command line and serve the FastAPI app.
    cli_options = ArgParser().args
    uvicorn.run(
        app,
        host=cli_options.server,
        port=cli_options.port,
        reload=False,
    )