Alex Ergasti commited on
Commit
9634dc8
·
1 Parent(s): ab0a826

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -19
app.py CHANGED
@@ -6,8 +6,10 @@ torch.backends.cudnn.allow_tf32 = True
6
  import os
7
  import spaces
8
  from diffusers.models import AutoencoderKL
9
- from models import FLAV_models
10
-
 
 
11
  from diffusion.rectified_flow import RectifiedFlow
12
 
13
  from diffusers.training_utils import EMAModel
@@ -27,7 +29,7 @@ AUDIO_T_PER_FRAME = 1600 // 160
27
  vae = None
28
  model = None
29
  vocoder = None
30
- audio_scale = 3.50
31
 
32
 
33
  def setup_models():
@@ -37,25 +39,13 @@ def setup_models():
37
  vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
38
  vae.eval()
39
 
40
- model = FLAV_models["FLAV-B/1"](
41
- latent_size= 256//8,
42
- in_channels = 4,
43
- num_classes = 0,
44
- predict_frames = 10,
45
- causal_attn = True,
46
- )
47
-
48
- ckpt_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="aist-ema.pth")
49
-
50
- state_dict = torch.load(ckpt_path, map_location="cpu")
51
 
52
- ema = EMAModel(model.parameters())
53
- ema.load_state_dict(state_dict)
54
- ema.copy_to(model.parameters())
55
 
56
- hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/config.json")
57
- vocoder_path = hf_hub_download(repo_id="MaverickAlex/R-FLAV", filename="vocoder-aist/vocoder.pt")
58
 
 
 
59
  vocoder_path = vocoder_path.replace("vocoder.pt", "")
60
  vocoder = Generator.from_pretrained(vocoder_path)
61
 
 
6
  import os
7
  import spaces
8
  from diffusers.models import AutoencoderKL
9
+ from models import FLAV
10
+ from huggingface_hub import hf_hub_download
11
+ import torch
12
+
13
  from diffusion.rectified_flow import RectifiedFlow
14
 
15
  from diffusers.training_utils import EMAModel
 
29
  vae = None
30
  model = None
31
  vocoder = None
32
+ audio_scale = 3.5009668382765917
33
 
34
 
35
  def setup_models():
 
39
  vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema")
40
  vae.eval()
41
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ model_ckpt = "MaverickAlex/R-FLAV-B-1-AIST" # MaverickAlex/R-FLAV-B-1-LS
 
 
44
 
45
+ model = FLAV.from_pretrained(model_ckpt)
 
46
 
47
+ hf_hub_download(repo_id=model_ckpt, filename="vocoder/config.json")
48
+ vocoder_path = hf_hub_download(repo_id=model_ckpt, filename="vocoder/vocoder.pt")
49
  vocoder_path = vocoder_path.replace("vocoder.pt", "")
50
  vocoder = Generator.from_pretrained(vocoder_path)
51