File size: 4,758 Bytes
e5f531d
 
 
7da5a9a
 
 
e5f531d
 
 
 
 
7da5a9a
 
 
e5f531d
 
7da5a9a
e5f531d
7da5a9a
 
ca466ef
7da5a9a
065c08d
7da5a9a
66108fb
cfaeb09
e5f531d
66108fb
065c08d
 
 
 
 
 
 
e5f531d
7da5a9a
 
 
 
 
 
 
a6f4e74
 
7da5a9a
 
 
 
 
 
 
 
 
 
 
a6f4e74
dca572a
7da5a9a
a6f4e74
 
 
7da5a9a
 
 
 
a6f4e74
 
 
 
7da5a9a
a6f4e74
 
 
7da5a9a
a6f4e74
7da5a9a
a6f4e74
7da5a9a
 
 
9a18fdd
a6f4e74
7da5a9a
 
a6f4e74
7da5a9a
 
a6f4e74
7da5a9a
 
a6f4e74
 
15a7075
a6f4e74
7da5a9a
 
 
 
 
 
 
 
065c08d
7da5a9a
 
 
 
 
a6f4e74
58373da
7da5a9a
 
 
 
 
 
065c08d
7da5a9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
import ctypes
import gc
import os
import sys
import tempfile
import threading
import time

import uvicorn
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from huggingface_hub import hf_hub_download
from passlib.context import CryptContext
from pydantic import BaseModel
from torch import no_grad, package

from accentor import accentification, stress_replace_and_shift

# FastAPI application; the interactive docs (/docs) and ReDoc (/redoc) pages are disabled.
app = FastAPI(docs_url=None, redoc_url=None)

# Fixed locations for the Hugging Face cache, Stanza resources and the TTS model.
# NOTE(review): huggingface_hub is imported above, before HF_HOME is set here;
# huggingface_hub appears to read HF_HOME when it is imported, so this assignment
# may not change its defaults (init_models passes cache_dir explicitly, which does
# apply). The same ordering concern applies if `accentor` imports stanza — confirm.
os.environ["HF_HOME"] = "/app/.cache"
os.environ["STANZA_RESOURCES_DIR"] = "/app/stanza_resources"
os.environ["MODEL_DIR"] = "/app/model"

# Create the resource and model directories if they don't exist yet.
os.makedirs(os.environ["STANZA_RESOURCES_DIR"], exist_ok=True)
os.makedirs(os.environ["MODEL_DIR"], exist_ok=True)

# Download the Ukrainian Stanza models at import time (requires network access at startup).
# stanza is imported here, after the environment variables above are set,
# presumably so it picks up STANZA_RESOURCES_DIR — verify before moving this import.
import stanza

stanza.download('uk', model_dir=os.environ["STANZA_RESOURCES_DIR"])

# Keyword arguments forwarded to the model's `tts(...)` call on every request.
tts_kwargs = {
    "speaker_name": "uk",
    "language_name": "uk",
}

# Extracts the bearer token from requests; clients obtain it from the /token endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

class Auth(BaseModel):
    """Credential record stored in ``fake_data_db`` and returned on successful lookup."""

    api_key: str
    password: str  # bcrypt hash produced by get_password_hash, not the plaintext password

# Shared passlib context used by get_password_hash / verify_password (bcrypt scheme).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def get_password_hash(password):
    """Return the bcrypt hash of ``password`` computed with ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed

def verify_password(plain_password, hashed_password):
    """Check ``plain_password`` against a stored hash; truthy when they match."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches

api_key = os.getenv("XCHE_API_KEY")
password = os.getenv("XCHE_PASSWORD")

# Fail fast with an actionable message. Without this, a missing password crashes
# inside passlib with an opaque TypeError, and a missing API key produces a store
# keyed by None that no bearer token (always a str) can ever match.
if not api_key or not password:
    raise RuntimeError(
        "XCHE_API_KEY and XCHE_PASSWORD environment variables must both be set"
    )

# In-memory credential store with a single entry keyed by the API key.
# The password is hashed once at startup; only the hash is kept.
fake_data_db = {
    api_key: {
        "api_key": api_key,
        "password": get_password_hash(password),
    }
}

def get_api_key(db, api_key: str):
    """Fetch the record for ``api_key`` from ``db`` as an ``Auth`` model.

    Returns ``None`` when the key is not present in ``db``.
    """
    if api_key not in db:
        return None
    return Auth(**db[api_key])

def authenticate(fake_db, api_key: str, password: str):
    """Validate an API key / password pair against ``fake_db``.

    Returns the matching ``Auth`` record, or ``False`` if the key is unknown
    or the password does not verify against the stored hash.
    """
    record = get_api_key(fake_db, api_key)
    if record and verify_password(password, record.password):
        return record
    return False

@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange form credentials (username = API key, password) for a bearer token.

    NOTE(review): the issued access token is the API key itself, and
    check_api_token accepts it without the password, so the password check
    here does not protect endpoints from anyone who already knows the key.
    """
    credentials = authenticate(fake_data_db, form_data.username, form_data.password)
    if credentials:
        return {"access_token": credentials.api_key, "token_type": "bearer"}
    raise HTTPException(
        status_code=400,
        detail="Incorrect API KEY or Password",
        headers={"WWW-Authenticate": "Bearer"},
    )

def check_api_token(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency: resolve the bearer token to its ``Auth`` record.

    Raises HTTPException(403) if the token is not a known API key.
    NOTE(review): only the key is checked; the password is never required here.
    """
    record = get_api_key(fake_data_db, token)
    if record:
        return record
    raise HTTPException(status_code=403, detail="Invalid or missing API Key")

def trim_memory():
    """Run a garbage collection, then ask glibc to return free heap memory to the OS.

    ``malloc_trim`` only exists in glibc. Where ``libc.so.6`` or the symbol is
    unavailable (e.g. musl-based images, macOS), only the collection is performed.
    """
    # Collect first so memory released by the collector is also eligible for trimming.
    gc.collect()
    try:
        libc = ctypes.CDLL("libc.so.6")
        libc.malloc_trim(0)
    except (OSError, AttributeError):
        # Best-effort optimisation: no glibc means there is nothing to trim.
        pass

def init_models():
    """Fetch the Lada VITS package from the Hugging Face Hub and load it.

    Returns a dict mapping ``"lada"`` to the loaded model object. The download
    is cached under ``$MODEL_DIR``.

    NOTE(review): ``load_pickle`` unpickles objects from the downloaded archive;
    only load packages from a repository you trust.
    """
    checkpoint = hf_hub_download(
        "theodotus/tts-vits-lada-uk",
        "model.pt",
        cache_dir=os.environ["MODEL_DIR"],
    )
    lada = package.PackageImporter(checkpoint).load_pickle("tts_models", "model")
    return {"lada": lada}

@app.post("/create_audio")
async def tts(request: str, api_data: Auth = Depends(check_api_token)):
    """Synthesize speech for ``request`` and return a URL to download the WAV.

    The text is passed through ``accentification(..., "vocab")`` and
    ``stress_replace_and_shift`` before being rendered by ``models["lada"]``
    into a temporary ``.wav`` file. The file is deleted 300 seconds later by a
    background thread. ``api_data`` is unused in the body; declaring it enforces
    authentication via ``check_api_token``.

    Exceptions raised during synthesis propagate; the temp file is removed first.
    """
    accented_text = accentification(request, "vocab")
    plussed_text = stress_replace_and_shift(accented_text)

    synt = models["lada"]
    # delete=False: the file must outlive this request so /download_audio can
    # serve it. Close our handle immediately; save_wav writes by path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_fp:
        wav_path = wav_fp.name

    try:
        with no_grad():
            wav_data = synt.tts(plussed_text, **tts_kwargs)
            synt.save_wav(wav_data, wav_path)
    except Exception:
        # No deletion timer exists yet, so remove the empty/partial file here
        # rather than leaking it. Best-effort: the synthesis error is what matters.
        try:
            os.remove(wav_path)
        except OSError:
            pass
        raise

    # Daemon thread so a pending 5-minute cleanup never blocks interpreter shutdown.
    threading.Thread(
        target=delete_file_after_delay, args=(wav_path, 300), daemon=True
    ).start()
    return JSONResponse(content={"audio_url": f"https://pro100sata-xche-audio.hf.space/download_audio?audio_path={wav_path}"})

@app.get("/download_audio")
async def download_audio(audio_path: str):
    """Serve a WAV file previously written by ``/create_audio``.

    This endpoint is unauthenticated, so ``audio_path`` is untrusted. Only
    existing ``.wav`` files located directly in the system temp directory
    (where ``/create_audio`` writes them) are served. Anything else, including
    paths that escape via ``..`` or symlinks, gets a 404 instead of exposing
    arbitrary files on disk.
    """
    temp_dir = os.path.realpath(tempfile.gettempdir())
    resolved = os.path.realpath(audio_path)
    if (
        os.path.dirname(resolved) != temp_dir
        or not resolved.endswith(".wav")
        or not os.path.isfile(resolved)
    ):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(resolved, media_type='audio/wav')

# Load the TTS model once at import time (init_models fetches the checkpoint via
# hf_hub_download), before uvicorn.run starts serving. `tts` reads models["lada"].
models = init_models()

def delete_file_after_delay(file_path: str, delay: int):
    """Sleep ``delay`` seconds, then delete ``file_path`` if it still exists.

    Meant to run on a background thread. A file that has already disappeared
    is not an error.
    """
    time.sleep(delay)
    # EAFP instead of exists()+remove(): avoids a race where the file vanishes
    # between the check and the removal and the thread dies with FileNotFoundError.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass

class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the server's bind address and port.

    Arguments are parsed immediately on construction and stored in ``self.args``.
    Pass ``argv`` (a list of strings) to parse something other than
    ``sys.argv[1:]``, e.g. in tests.
    """

    def __init__(self, *args, argv=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s", "--server", type=str, default="0.0.0.0",
            help="Host/IP address the TTS API server binds to",
        )
        self.add_argument(
            "-p", "--port", type=int, default=7860,
            help="Port the TTS API server listens on",
        )
        self.args = self.parse_args(sys.argv[1:] if argv is None else argv)

if __name__ == "__main__":
    # Parse CLI options and serve the app in-process (no auto-reload).
    cli_args = ArgParser().args
    uvicorn.run(app, host=cli_args.server, port=cli_args.port, reload=False)