# xche_audio / app.py
# Last commit: 15a7075 "Add application file" by Yarik (4.43 kB)
import os
import sys
import time
import threading
import tempfile
import ctypes
import gc
from typing import Optional
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from torch import no_grad, package
from pydub import AudioSegment
import uvicorn
from accentor import accentification, stress_replace_and_shift
import argparse
from passlib.context import CryptContext
app = FastAPI()

# Intended Hugging Face cache location.
# NOTE(review): huggingface_hub is imported above and resolves its cache paths
# from HF_HOME when it is imported, so this assignment likely does not move its
# default download cache - confirm against the installed version.
os.environ["HF_HOME"] = "/app/.cache"

# Keyword arguments forwarded to the synthesizer's tts() call on every request.
tts_kwargs = dict(speaker_name="uk", language_name="uk")

# Extracts the bearer token from requests; clients obtain it from /token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class User(BaseModel):
    """User record built from an entry of ``fake_users_db``."""

    # Login name; /token also hands this value out as the access token.
    username: str
    # Despite the name, this holds the bcrypt *hash* stored in the DB, not plaintext.
    password: str
# passlib context (bcrypt) shared by the two password helpers below.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def get_password_hash(password):
    """Hash *password* with the context's bcrypt scheme and return the hash string."""
    hashed = pwd_context.hash(password)
    return hashed


def verify_password(plain_password, hashed_password):
    """Check *plain_password* against a hash produced by ``get_password_hash``."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
# Load username and password from environment variables. The username doubles
# as the API key: /token returns it as the access token and check_api_token
# looks bearer tokens up as usernames.
username = os.getenv("XCHE_API_KEY")
password = os.getenv("XCHE_PASSWORD")
if not username or not password:
    # Fail fast with an actionable message instead of passlib's opaque
    # TypeError on None (or silently accepting an empty password).
    raise RuntimeError(
        "XCHE_API_KEY and XCHE_PASSWORD environment variables must both be set "
        "to non-empty values"
    )

# In-memory storage for simplicity; in production use a database.
fake_users_db = {
    username: {
        "username": username,
        # Hashed once at startup; only the hash is stored in the DB entry.
        "password": get_password_hash(password),
    }
}
def get_user(db, username: str):
    """Return the ``User`` stored under *username* in *db*, or None if absent.

    Each value in *db* is a dict of ``User`` field values.
    """
    if username not in db:
        return None
    return User(**db[username])
def authenticate_user(fake_db, username: str, password: str):
    """Validate a username/password pair against *fake_db*.

    Returns:
        The matching ``User`` when the user exists and *password* verifies
        against its stored hash; otherwise ``False``.
    """
    candidate = get_user(fake_db, username)
    # Short-circuit: the (slow) hash check only runs for an existing user.
    if candidate and verify_password(password, candidate.password):
        return candidate
    return False
@app.post("/token")
def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange form-encoded username/password credentials for a bearer token.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    worker threadpool. The bcrypt check inside ``authenticate_user`` is slow
    by design and would otherwise block the event loop for all other requests.

    NOTE(review): the issued token is the username itself, and
    ``check_api_token`` accepts any token equal to a known username. Anyone who
    knows the username (XCHE_API_KEY) therefore never needs the password.
    Confirm this is the intended API-key design.

    Raises:
        HTTPException: 400 when the credentials do not match.
    """
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=400,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"access_token": user.username, "token_type": "bearer"}
def check_api_token(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency that maps the bearer token to a ``User``.

    The token is looked up directly as a username in ``fake_users_db``.

    Raises:
        HTTPException: 403 when no user matches the token.
    """
    found = get_user(fake_users_db, token)
    if found:
        return found
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")
def trim_memory():
    """Best-effort memory release: collect garbage, then ask glibc to trim its heap.

    ``gc.collect()`` runs first so the memory it frees is already back in
    malloc's free lists when ``malloc_trim`` tries to return it to the OS.
    On platforms without glibc (``libc.so.6`` missing, or no ``malloc_trim``
    symbol) only the garbage collection happens.
    """
    gc.collect()
    try:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    except (OSError, AttributeError):
        # Not glibc (e.g. musl/macOS): no heap trimming available; the
        # collection above is still done, so there is nothing else to do.
        pass
def init_models():
    """Download the Lada Ukrainian VITS voice and load it from its torch.package archive.

    Returns:
        dict mapping the voice name ``"lada"`` to the unpickled synthesizer.
    """
    # huggingface_hub resolves its default cache from HF_HOME when it is
    # imported, which happens before this module assigns os.environ["HF_HOME"].
    # Pass the cache dir explicitly (HF_HOME/hub, the library's default layout)
    # so the download actually lands under the configured HF_HOME.
    hf_home = os.environ.get("HF_HOME")
    cache_dir = os.path.join(hf_home, "hub") if hf_home else None
    model_path = hf_hub_download(
        "theodotus/tts-vits-lada-uk", "model.pt", cache_dir=cache_dir
    )
    # SECURITY: load_pickle unpickles the downloaded archive, which can run
    # arbitrary code - only ever point this at a trusted repository.
    importer = package.PackageImporter(model_path)
    return {"lada": importer.load_pickle("tts_models", "model")}
# Serializes all use of the shared accentor/TTS models. The endpoint below is a
# plain ``def`` that FastAPI runs in a worker threadpool, and these models are
# not known to be safe for concurrent use.
_model_lock = threading.Lock()

# Base URL embedded in returned download links. Defaults to the original
# Hugging Face Space; set XCHE_PUBLIC_URL when deploying elsewhere.
_PUBLIC_BASE_URL = os.getenv(
    "XCHE_PUBLIC_URL", "https://pro100sata-xche-audio.hf.space"
).rstrip("/")


def _synthesize_to_wav_file(synt, text):
    """Render *text* with *synt* into a new ``.wav`` temp file and return its path.

    The file is created with ``delete=False`` so it outlives this call. If
    synthesis or saving fails, the partial file is removed before the
    exception propagates, so failed requests do not leak temp files.
    """
    wav_fp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        with wav_fp, no_grad():
            wav_data = synt.tts(text, **tts_kwargs)
            synt.save_wav(wav_data, wav_fp)
    except BaseException:
        try:
            os.remove(wav_fp.name)
        except OSError:
            # Best-effort cleanup; never mask the original synthesis error.
            pass
        raise
    return wav_fp.name


@app.post("/create_audio")
def tts(request: str, user: User = Depends(check_api_token)):
    """Synthesize Ukrainian speech for the text *request* and return a download link.

    Declared as a plain ``def`` so the CPU-heavy accentuation and synthesis run
    in FastAPI's threadpool instead of blocking the event loop.

    Args:
        request: Text to speak (a query parameter).
        user: Authenticated user; only present to enforce ``check_api_token``.

    Returns:
        JSON with ``audio_url`` pointing at /download_audio for the generated
        file. The file is deleted after 300 seconds.
    """
    print(request)
    with _model_lock:
        accented_text = accentification(request)
        plussed_text = stress_replace_and_shift(accented_text)
        wav_path = _synthesize_to_wav_file(models["lada"], plussed_text)
    # daemon=True: a pending 5-minute cleanup must not keep the process alive
    # at shutdown. Files still pending then are simply left in the temp dir.
    threading.Thread(
        target=delete_file_after_delay, args=(wav_path, 300), daemon=True
    ).start()
    return JSONResponse(
        content={
            "audio_url": f"{_PUBLIC_BASE_URL}/download_audio?audio_path={wav_path}"
        }
    )
@app.get("/download_audio")
async def download_audio(audio_path: str):
    """Stream a WAV file previously produced by /create_audio.

    The endpoint is unauthenticated. It therefore only serves regular ``.wav``
    files located directly in the system temp directory, which is where
    /create_audio's ``NamedTemporaryFile`` writes them. Any other path returns
    404 rather than exposing arbitrary server files (e.g. ``/etc/passwd``).
    Files already removed by the delayed cleanup also return 404.

    Args:
        audio_path: Absolute path taken from the ``audio_url`` returned by /create_audio.

    Raises:
        HTTPException: 404 if the path is outside the temp directory, is not a
            ``.wav`` file, or no longer exists.
    """
    # realpath collapses ".." segments and symlinks so the containment check
    # below cannot be bypassed with a crafted path.
    temp_dir = os.path.realpath(tempfile.gettempdir())
    resolved_path = os.path.realpath(audio_path)
    if (
        os.path.dirname(resolved_path) != temp_dir
        or not resolved_path.endswith(".wav")
        or not os.path.isfile(resolved_path)
    ):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(resolved_path, media_type="audio/wav")
# Load the TTS model(s) once at import time; the /create_audio handler reads the
# synthesizer from this module-level dict as models["lada"].
models = init_models()
def delete_file_after_delay(file_path: str, delay: int):
    """Sleep *delay* seconds, then delete *file_path* if it is still there.

    Meant to run on a background thread. Deletion is attempted directly
    (EAFP) rather than after an ``exists`` check, so a file that disappears
    in between cannot raise FileNotFoundError in the worker thread.
    """
    time.sleep(delay)
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone - nothing left to clean up.
        pass
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for running the server directly; parses on construction.

    After ``__init__``, the parsed namespace is available as ``self.args`` with
    ``server`` (bind host) and ``port`` attributes.
    """

    def __init__(self, *args, argv=None, **kwargs):
        """Build the parser and immediately parse the command line.

        Args:
            *args, **kwargs: Forwarded to ``argparse.ArgumentParser``.
            argv: Explicit argument list to parse; defaults to ``sys.argv[1:]``.
        """
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s", "--server", type=str, default="0.0.0.0",
            help="Host/IP address the TTS API server binds to",
        )
        self.add_argument(
            "-p", "--port", type=int, default=7860,
            help="Port the TTS API server listens on",
        )
        self.args = self.parse_args(sys.argv[1:] if argv is None else argv)
if __name__ == "__main__":
    # Read --server/--port from the command line and serve the app with uvicorn.
    cli = ArgParser()
    uvicorn.run(
        app,
        host=cli.args.server,
        port=cli.args.port,
        reload=False,
    )