import torch
import torchaudio
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
import numpy as np
import json
import os
from safetensors.torch import load_file
# Imports from the jamify library
from jam.model.cfm import CFM
from jam.model.dit import DiT
from jam.model.vae import StableAudioOpenVAE
from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
from muq import MuQMuLan
# Helper functions adapted from jamify/src/jam/infer.py
def get_negative_style_prompt(device, file_path):
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)
    return vocal_style.half()


def normalize_audio(audio):
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    return audio
class Jamify:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = torch.device(device)

        # Download the model checkpoint, config, and auxiliary files from the Hugging Face Hub
        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")

        # Config and data files shipped with the model repository
        config_path = os.path.join(model_repo_path, "jam_infer.yaml")
        self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
        tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
        silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")

        print("Loading configuration...")
        self.config = OmegaConf.load(config_path)
        self.config.data.train_dataset.silence_latent_path = silence_latent_path
        # Override the relative paths in the config with absolute paths into the downloaded snapshot
        self.config.data.train_dataset.tokenizer_path = tokenizer_path
        self.config.evaluation.dataset.tokenizer_path = tokenizer_path
        self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path

        print("Loading VAE model...")
        self.vae = StableAudioOpenVAE().to(self.device).eval()

        print("Loading CFM model...")
        self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path)

        print("Loading MuQ style model...")
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()

        print("Setting up dataset processor...")
        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
        enhance_webdataset_config(dataset_cfg)
        dataset_cfg.multiple_styles = False
        self.dataset_processor = DiffusionWebDataset(**dataset_cfg)

        print("Jamify model loaded successfully.")
    def _load_cfm_model(self, model_config, checkpoint_path):
        dit_config = model_config["dit"].copy()
        if "text_num_embeds" not in dit_config:
            dit_config["text_num_embeds"] = 256
        model = CFM(
            transformer=DiT(**dit_config),
            **model_config["cfm"]
        ).to(self.device)
        state_dict = load_file(checkpoint_path)
        model.load_state_dict(state_dict, strict=False)
        return model.eval()
    def _generate_style_embedding_from_audio(self, audio_path):
        waveform, sample_rate = torchaudio.load(audio_path)
        # Resample to 24 kHz and downmix to mono for the MuQ style encoder
        if sample_rate != 24000:
            resampler = torchaudio.transforms.Resample(sample_rate, 24000)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0).to(self.device)
        with torch.inference_mode():
            # Use at most the first 30 seconds of the reference audio
            style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * 30])
        return style_embedding[0]
    def _generate_style_embedding_from_prompt(self, prompt):
        with torch.inference_mode():
            style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
        return style_embedding
    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration_sec=30, steps=50):
        print("Starting prediction...")
        if reference_audio_path:
            print(f"Generating style from audio: {reference_audio_path}")
            style_embedding = self._generate_style_embedding_from_audio(reference_audio_path)
        elif style_prompt:
            print(f"Generating style from prompt: '{style_prompt}'")
            style_embedding = self._generate_style_embedding_from_prompt(style_prompt)
        else:
            print("No style provided, using zero embedding.")
            style_embedding = torch.zeros(512, device=self.device)

        print(f"Loading lyrics from: {lyrics_json_path}")
        with open(lyrics_json_path, 'r') as f:
            lrc_data = json.load(f)
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}

        # Random placeholder latent sized to the requested duration for the dataset processor
        frame_rate = 21.5  # latent frames per second
        num_frames = int(duration_sec * frame_rate)
        fake_latent = torch.randn(128, num_frames)
        sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)
print("Processing sample...")
processed_sample = self.dataset_processor.process_sample_safely(sample_tuple)
if processed_sample is None:
raise ValueError("Failed to process the provided lyrics and style.")
batch = self.dataset_processor.custom_collate_fn([processed_sample])
for key, value in batch.items():
if isinstance(value, torch.Tensor):
batch[key] = value.to(self.device)
print("Generating audio latent...")
with torch.inference_mode():
batch_size = 1
text = batch["lrc"]
style_prompt_tensor = batch["prompt"]
start_time = batch["start_time"]
duration_abs = batch["duration_abs"]
duration_rel = batch["duration_rel"]
cond = torch.zeros(batch_size, self.cfm_model.max_frames, 64).to(self.device)
pred_frames = [(0, self.cfm_model.max_frames)]
negative_style_prompt = get_negative_style_prompt(self.device, self.negative_style_prompt_path)
negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)
sample_kwargs = self.config.evaluation.sample_kwargs
sample_kwargs.steps = steps
latents, _ = self.cfm_model.sample(
cond=cond, text=text, style_prompt=style_prompt_tensor,
duration_abs=duration_abs, duration_rel=duration_rel,
negative_style_prompt=negative_style_prompt, start_time=start_time,
latent_pred_segments=pred_frames, **sample_kwargs)
latent = latents[0][0]
print("Decoding latent to audio...")
latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()
pred_audio = normalize_audio(pred_audio)
sample_rate = 44100
trim_samples = int(duration_sec * sample_rate)
if pred_audio.shape[1] > trim_samples:
pred_audio = pred_audio[:, :trim_samples]
output_path = "generated_song.mp3"
print(f"Saving audio to {output_path}")
torchaudio.save(output_path, pred_audio, sample_rate, format="mp3")
return output_path
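

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how the Jamify class above might be driven from a script.
# The lyrics file name and style prompt below are hypothetical placeholders; the
# JSON is assumed to contain word-level timed lyrics in the format expected by
# DiffusionWebDataset.
if __name__ == "__main__":
    jam = Jamify()
    song_path = jam.predict(
        reference_audio_path=None,          # or a path to a reference clip for style
        lyrics_json_path="lyrics.json",     # hypothetical word-timed lyrics file
        style_prompt="dreamy synth-pop with soft vocals",
        duration_sec=30,
        steps=50,
    )
    print(f"Generated song saved to {song_path}")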