Spaces:
Running
on
Zero
Running
on
Zero
Alex Ergasti
committed on
Commit
·
9634dc8
1
Parent(s):
ab0a826
Update app.py
Browse files
app.py
CHANGED
@@ -6,8 +6,10 @@ torch.backends.cudnn.allow_tf32 = True
|
|
6 |
import os
|
7 |
import spaces
|
8 |
from diffusers.models import AutoencoderKL
|
9 |
-
from models import FLAV_models
|
10 |
-
|
|
|
|
|
11 |
from diffusion.rectified_flow import RectifiedFlow
|
12 |
|
13 |
from diffusers.training_utils import EMAModel
|
@@ -27,7 +29,7 @@ AUDIO_T_PER_FRAME = 1600 // 160
|
|
27 |
vae = None
|
28 |
model = None
|
29 |
vocoder = None
|
30 |
-
audio_scale = 3.
|
31 |
|
32 |
|
33 |
def setup_models():
|
@@ -37,25 +39,13 @@ def setup_models():
|
|
37 |
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
|
38 |
vae.eval()
|
39 |
|
40 |
-
model = FLAV_models["FLAV-B/1"](
|
41 |
-
latent_size= 256//8,
|
42 |
-
in_channels = 4,
|
43 |
-
num_classes = 0,
|
44 |
-
predict_frames = 10,
|
45 |
-
causal_attn = True,
|
46 |
-
)
|
47 |
-
|
48 |
-
ckpt_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="aist-ema.pth")
|
49 |
-
|
50 |
-
state_dict = torch.load(ckpt_path, map_location="cpu")
|
51 |
|
52 |
-
|
53 |
-
ema.load_state_dict(state_dict)
|
54 |
-
ema.copy_to(model.parameters())
|
55 |
|
56 |
-
|
57 |
-
vocoder_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/vocoder.pt")
|
58 |
|
|
|
|
|
59 |
vocoder_path = vocoder_path.replace("vocoder.pt", "")
|
60 |
vocoder = Generator.from_pretrained(vocoder_path)
|
61 |
|
|
|
6 |
import os
|
7 |
import spaces
|
8 |
from diffusers.models import AutoencoderKL
|
9 |
+
from models import FLAV
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
import torch
|
12 |
+
|
13 |
from diffusion.rectified_flow import RectifiedFlow
|
14 |
|
15 |
from diffusers.training_utils import EMAModel
|
|
|
29 |
vae = None
|
30 |
model = None
|
31 |
vocoder = None
|
32 |
+
audio_scale = 3.5009668382765917
|
33 |
|
34 |
|
35 |
def setup_models():
|
|
|
39 |
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
|
40 |
vae.eval()
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
+
model_ckpt = "MaverickAlex/R-FLAV-B-1-AIST" # MaverickAlex/R-FLAV-B-1-LS
|
|
|
|
|
44 |
|
45 |
+
model = FLAV.from_pretrained(model_ckpt)
|
|
|
46 |
|
47 |
+
hf_hub_download(repo_id=model_ckpt, filename="vocoder/config.json")
|
48 |
+
vocoder_path = hf_hub_download(repo_id=model_ckpt, filename="vocoder/vocoder.pt")
|
49 |
vocoder_path = vocoder_path.replace("vocoder.pt", "")
|
50 |
vocoder = Generator.from_pretrained(vocoder_path)
|
51 |
|