Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import time | |
import threading | |
import tempfile | |
import ctypes | |
import gc | |
from typing import Optional | |
from fastapi import FastAPI, Depends, HTTPException | |
from fastapi.responses import FileResponse, JSONResponse | |
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | |
from pydantic import BaseModel | |
from huggingface_hub import hf_hub_download | |
from torch import no_grad, package | |
from pydub import AudioSegment | |
import uvicorn | |
from accentor import accentification, stress_replace_and_shift | |
import argparse | |
from passlib.context import CryptContext | |
app = FastAPI()
# Direct the Hugging Face cache to /app/.cache.
# NOTE(review): huggingface_hub is imported above this line; if it reads
# HF_HOME only at import time, this assignment does not change where
# hf_hub_download caches files — confirm, or set it before the imports.
os.environ["HF_HOME"] = "/app/.cache"
# Keyword arguments forwarded to synt.tts() on every synthesis request.
tts_kwargs = dict(speaker_name="uk", language_name="uk")
# Bearer-token dependency consumed by check_api_token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class User(BaseModel):
    """Credential record built from a ``fake_users_db`` entry by ``get_user``.

    ``password`` holds the bcrypt hash produced by ``get_password_hash``,
    not the plaintext; ``authenticate_user`` checks it with
    ``verify_password``.
    """
    username: str
    # bcrypt hash of the password (never the plaintext).
    password: str
# passlib context restricted to bcrypt; shared by get_password_hash and
# verify_password.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
def get_password_hash(password):
    """Hash *password* with the module-level bcrypt ``pwd_context``.

    Returns the encoded hash string suitable for ``verify_password``.
    """
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password, hashed_password):
    """Check *plain_password* against a hash from ``get_password_hash``.

    Returns the result of ``pwd_context.verify`` (True on a match).
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
# Load credentials from environment variables. XCHE_API_KEY doubles as the
# username and — via login/check_api_token — as the bearer token itself.
username = os.getenv("XCHE_API_KEY")
password = os.getenv("XCHE_PASSWORD")
if not username or not password:
    # Fail fast with an actionable message instead of crashing inside
    # passlib (hashing None) or registering a user keyed by None that can
    # never authenticate.
    raise RuntimeError(
        "XCHE_API_KEY and XCHE_PASSWORD environment variables must be set"
    )
# In-memory storage for simplicity; in production use a database.
fake_users_db = {
    username: {
        "username": username,
        "password": get_password_hash(password),  # bcrypt hash, not plaintext
    }
}
def get_user(db, username: str):
    """Return *username*'s record from *db* wrapped in a ``User``.

    Returns None when *username* is not a key of *db*.
    """
    if username not in db:
        return None
    return User(**db[username])
def authenticate_user(fake_db, username: str, password: str):
    """Return the ``User`` for *username* if *password* matches, else False.

    Args:
        fake_db: Mapping of username -> user record (as read by get_user).
        username: Login name to look up.
        password: Plaintext password checked against the stored hash.
    """
    candidate = get_user(fake_db, username)
    if candidate and verify_password(password, candidate.password):
        return candidate
    return False
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange form-encoded username/password for a bearer token.

    Returns:
        dict with ``access_token`` and ``token_type`` ("bearer").

    Raises:
        HTTPException: 400 when the credentials do not match.
    """
    # NOTE(review): the issued token is just the username (the XCHE_API_KEY
    # value), and check_api_token accepts it verbatim, so holding the API key
    # alone is enough to call protected endpoints.
    # NOTE(review): bcrypt verification is CPU-bound and runs on the event
    # loop inside this async handler.
    authenticated = authenticate_user(
        fake_users_db, form_data.username, form_data.password
    )
    if authenticated:
        return {"access_token": authenticated.username, "token_type": "bearer"}
    raise HTTPException(
        status_code=400,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )
def check_api_token(token: str = Depends(oauth2_scheme)):
    """Dependency that authorizes a request by its bearer token.

    The token must equal a username in ``fake_users_db``; the matching
    ``User`` is returned.

    Raises:
        HTTPException: 403 when the token matches no user.
    """
    matched_user = get_user(fake_users_db, token)
    if matched_user:
        return matched_user
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")
def trim_memory():
    """Release unused memory back to the OS (best effort).

    A full garbage collection runs first so that objects freed by the
    collector are already back on the malloc heap. Then glibc's
    ``malloc_trim(0)`` is asked to return free heap pages. On platforms
    without glibc (macOS, Windows, musl-based images) the trim step is
    skipped instead of raising.
    """
    # Collect first: trimming before collecting would miss everything the
    # collector is about to free.
    gc.collect()
    try:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    except (OSError, AttributeError):
        # No glibc / no malloc_trim symbol: nothing more to do portably.
        pass
def init_models(repo_id="theodotus/tts-vits-lada-uk", filename="model.pt"):
    """Download the packaged TTS model from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository that holds the ``torch.package`` archive.
        filename: Name of the archive inside ``repo_id``.

    Returns:
        dict mapping model name to the loaded synthesizer; currently only
        ``"lada"``.
    """
    # huggingface_hub resolves HF_HOME / HF_HUB_CACHE once, when it is
    # imported, so an HF_HOME assigned after that import is ignored by the
    # default. Resolve the cache directory from the environment now instead.
    model_path = hf_hub_download(repo_id, filename, cache_dir=_hub_cache_dir())
    # SECURITY: torch.package unpickles the archive, which can execute
    # arbitrary code — only load packages from repositories you trust.
    importer = package.PackageImporter(model_path)
    return {"lada": importer.load_pickle("tts_models", "model")}


def _hub_cache_dir():
    """Return the Hub cache dir implied by the current environment, or None.

    Mirrors huggingface_hub's precedence: HF_HUB_CACHE, then the legacy
    HUGGINGFACE_HUB_CACHE, then ``$HF_HOME/hub``. None means "use the
    library default".
    """
    for var in ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"):
        value = os.environ.get(var)
        if value:
            return value
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        return os.path.join(os.path.expanduser(hf_home), "hub")
    return None
async def tts(request: str, user: User = Depends(check_api_token)):
    """Synthesize speech for *request* and return a URL to the WAV file.

    The text goes through ``accentification`` and
    ``stress_replace_and_shift`` before being fed to the ``lada`` model.
    The WAV is written to the system temp directory. A background thread
    deletes it after 300 seconds.

    Args:
        request: Text to synthesize.
        user: Caller resolved by ``check_api_token``; present only to
            enforce authentication.

    Returns:
        JSONResponse ``{"audio_url": ...}`` pointing at ``download_audio``.
    """
    import asyncio
    from urllib.parse import quote

    print(request)
    # Accentuation and model inference are blocking CPU work; running them
    # directly in this async handler would stall every other request on the
    # event loop, so hand them to the default thread-pool executor.
    loop = asyncio.get_running_loop()
    wav_path = await loop.run_in_executor(None, _synthesize_to_wav, request)
    # Daemon thread so pending deletions do not hold up process shutdown.
    threading.Thread(
        target=delete_file_after_delay, args=(wav_path, 300), daemon=True
    ).start()
    audio_url = (
        "https://pro100sata-xche-audio.hf.space/download_audio"
        f"?audio_path={quote(wav_path, safe='/')}"
    )
    return JSONResponse(content={"audio_url": audio_url})


# Serializes synthesis as the original on-event-loop code did, so the model
# and accentor are never driven from two threads at once.
_synthesis_lock = threading.Lock()


def _synthesize_to_wav(text: str) -> str:
    """Run the TTS pipeline for *text* and return the path of a new WAV file.

    The caller owns the returned file. If writing fails, the file is removed
    and the error re-raised.
    """
    with _synthesis_lock:
        accented_text = accentification(text)
        plussed_text = stress_replace_and_shift(accented_text)
        synt = models["lada"]
        # no_grad is thread-local, so it must be entered on this worker thread.
        with no_grad():
            wav_data = synt.tts(plussed_text, **tts_kwargs)
            wav_fp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            try:
                with wav_fp:
                    synt.save_wav(wav_data, wav_fp)
            except Exception:
                # delete=False means nobody else would clean this up.
                os.remove(wav_fp.name)
                raise
    return wav_fp.name
async def download_audio(audio_path: str):
    """Serve a WAV file previously produced by the ``tts`` endpoint.

    Args:
        audio_path: Path taken from the ``audio_url`` that ``tts`` returned.

    Raises:
        HTTPException: 404 if the path is not an existing ``.wav`` file
            directly inside the system temp directory.
    """
    # SECURITY: audio_path comes straight from the query string and this
    # handler has no auth dependency. Without this check any readable file
    # on the server is served — e.g. /proc/self/environ, which contains the
    # XCHE_API_KEY / XCHE_PASSWORD credentials.
    safe_path = _resolve_audio_path(audio_path)
    if safe_path is None:
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(safe_path, media_type='audio/wav')


def _resolve_audio_path(audio_path: str) -> Optional[str]:
    """Return the real path of *audio_path* if it is a servable TTS output.

    Only existing ``.wav`` files located directly in the system temp
    directory (where ``tempfile.NamedTemporaryFile`` creates them by
    default) qualify. Anything else yields None, including ``..``
    traversal and symlinks that resolve outside the directory.
    """
    real_path = os.path.realpath(audio_path)
    temp_dir = os.path.realpath(tempfile.gettempdir())
    if os.path.dirname(real_path) != temp_dir:
        return None
    if not real_path.endswith(".wav") or not os.path.isfile(real_path):
        return None
    return real_path
# Load the model once at import time; the tts handler reads models["lada"].
models = init_models()
def delete_file_after_delay(file_path: str, delay: int):
    """Sleep *delay* seconds, then delete *file_path* if it still exists.

    Used as a background-thread target to expire generated audio files.
    A file that is already gone is silently ignored. Other OS errors, such
    as a permission error or a directory path, still propagate.
    """
    time.sleep(delay)
    # EAFP: an exists()-then-remove() pair races with anything else deleting
    # the file in between.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the server's bind address and port.

    Arguments are parsed immediately on construction and exposed as
    ``self.args``.
    """

    def __init__(self, *args, argv=None, **kwargs):
        """Build the parser and parse the command line.

        Args:
            *args, **kwargs: Forwarded to ``argparse.ArgumentParser``.
            argv: Argument list to parse; defaults to ``sys.argv[1:]``.
        """
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s", "--server", type=str, default="0.0.0.0",
            help="Server host/IP to bind the TTS API to",
        )
        self.add_argument(
            "-p", "--port", type=int, default=7860,
            help="Server port for the TTS API",
        )
        self.args = self.parse_args(sys.argv[1:] if argv is None else argv)
if __name__ == "__main__":
    # Read -s/--server and -p/--port, then serve the FastAPI app via uvicorn.
    cli_args = ArgParser().args
    uvicorn.run(app, reload=False, host=cli_args.server, port=cli_args.port)