import torch
import torchaudio
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
import numpy as np
import json
import os
from safetensors.torch import load_file

# Imports from the jamify library
from jam.model.cfm import CFM
from jam.model.dit import DiT
from jam.model.vae import StableAudioOpenVAE
from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
from muq import MuQMuLan

# Helper functions adapted from jamify/src/jam/infer.py
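# get_negative_style_prompt loads a precomputed style embedding that is passed to the
# sampler as the negative style prompt (see predict() below).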
def get_negative_style_prompt(device, file_path):
    vocal_style = np.load(file_path)
    vocal_style = torch.from_numpy(vocal_style).to(device)
    return vocal_style.half()

def normalize_audio(audio):
    audio = audio - audio.mean(-1, keepdim=True)
    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
    return audio

class Jamify:
    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = torch.device(device)
        
        # All required assets (checkpoint, config, tokenizer, negative style prompt,
        # silence latent) are fetched from the Hugging Face Hub snapshot below.
        
        print("Downloading main model checkpoint...")
        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
        self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
        
        # Auxiliary config and data files bundled with the downloaded snapshot
        config_path = os.path.join(model_repo_path, "jam_infer.yaml")
        self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
        tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
        silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")
        print("Loading configuration...")
        self.config = OmegaConf.load(config_path)
        self.config.data.train_dataset.silence_latent_path = silence_latent_path
        
        # Override the relative tokenizer paths in the config with absolute paths from the snapshot
        self.config.data.train_dataset.tokenizer_path = tokenizer_path
        self.config.evaluation.dataset.tokenizer_path = tokenizer_path
        self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path
        
        print("Loading VAE model...")
        self.vae = StableAudioOpenVAE().to(self.device).eval()
        
        print("Loading CFM model...")
        self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path)
        
        print("Loading MuQ style model...")
        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()

        print("Setting up dataset processor...")
        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
        enhance_webdataset_config(dataset_cfg)
        dataset_cfg.multiple_styles = False
        self.dataset_processor = DiffusionWebDataset(**dataset_cfg)

        print("Jamify model loaded successfully.")

    def _load_cfm_model(self, model_config, checkpoint_path):
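        # Build the CFM model around a DiT transformer and load weights from the
        # safetensors checkpoint; strict=False tolerates non-matching auxiliary keys.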
        dit_config = model_config["dit"].copy()
        if "text_num_embeds" not in dit_config:
            dit_config["text_num_embeds"] = 256
        
        model = CFM(
            transformer=DiT(**dit_config),
            **model_config["cfm"]
        ).to(self.device)
        
        state_dict = load_file(checkpoint_path)
        model.load_state_dict(state_dict, strict=False)
        return model.eval()

    def _generate_style_embedding_from_audio(self, audio_path):
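        # MuQ-MuLan expects 24 kHz mono audio; only the first 30 seconds are used
        # to compute the style embedding.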
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 24000:
            resampler = torchaudio.transforms.Resample(sample_rate, 24000)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        
        waveform = waveform.squeeze(0).to(self.device)
        
        with torch.inference_mode():
            style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * 30])
        return style_embedding[0]

    def _generate_style_embedding_from_prompt(self, prompt):
        with torch.inference_mode():
            style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
        return style_embedding

    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration_sec=30, steps=50):
        print("Starting prediction...")
        
        if reference_audio_path:
            print(f"Generating style from audio: {reference_audio_path}")
            style_embedding = self._generate_style_embedding_from_audio(reference_audio_path)
        elif style_prompt:
            print(f"Generating style from prompt: '{style_prompt}'")
            style_embedding = self._generate_style_embedding_from_prompt(style_prompt)
        else:
            print("No style provided, using zero embedding.")
            style_embedding = torch.zeros(512, device=self.device)

        print(f"Loading lyrics from: {lyrics_json_path}")
        with open(lyrics_json_path, 'r') as f:
            lrc_data = json.load(f)
        if 'word' not in lrc_data:
            lrc_data = {'word': lrc_data}

        frame_rate = 21.5
        num_frames = int(duration_sec * frame_rate)
        fake_latent = torch.randn(128, num_frames)
        
        sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)
        
        print("Processing sample...")
        processed_sample = self.dataset_processor.process_sample_safely(sample_tuple)
        if processed_sample is None:
            raise ValueError("Failed to process the provided lyrics and style.")

        batch = self.dataset_processor.custom_collate_fn([processed_sample])
        
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch[key] = value.to(self.device)

        print("Generating audio latent...")
        with torch.inference_mode():
            batch_size = 1
            text = batch["lrc"]
            style_prompt_tensor = batch["prompt"]
            start_time = batch["start_time"]
            duration_abs = batch["duration_abs"]
            duration_rel = batch["duration_rel"]
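            # Condition on an all-zero latent over the model's full frame budget and
            # request prediction of the entire segment in one pass.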
            
            cond = torch.zeros(batch_size, self.cfm_model.max_frames, 64).to(self.device)
            pred_frames = [(0, self.cfm_model.max_frames)]
            
            negative_style_prompt = get_negative_style_prompt(self.device, self.negative_style_prompt_path)
            negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

            sample_kwargs = self.config.evaluation.sample_kwargs
            sample_kwargs.steps = steps
            latents, _ = self.cfm_model.sample(
                cond=cond, text=text, style_prompt=style_prompt_tensor,
                duration_abs=duration_abs, duration_rel=duration_rel,
                negative_style_prompt=negative_style_prompt, start_time=start_time,
                latent_pred_segments=pred_frames, **sample_kwargs)
            
            latent = latents[0][0]

        print("Decoding latent to audio...")
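        # Rearrange the sampled latent into the (batch, channels, frames) layout the
        # VAE decoder expects, then decode and move the waveform to CPU.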
        latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
        pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()
        
        pred_audio = normalize_audio(pred_audio)
        
        sample_rate = 44100
        trim_samples = int(duration_sec * sample_rate)
        if pred_audio.shape[1] > trim_samples:
            pred_audio = pred_audio[:, :trim_samples]
            
        output_path = "generated_song.mp3"
        print(f"Saving audio to {output_path}")
        torchaudio.save(output_path, pred_audio, sample_rate, format="mp3")
        
        return output_path
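

# Example usage: a minimal sketch. The lyrics path and style prompt below are
# placeholders (assumptions), not files shipped with the repository; the lyrics
# JSON is expected to contain word-level timings as consumed by predict() above.
if __name__ == "__main__":
    jam = Jamify()
    output = jam.predict(
        reference_audio_path=None,             # or a path to a reference audio file
        lyrics_json_path="lyrics_timed.json",  # placeholder path to word-timed lyrics
        style_prompt="dreamy synth-pop with airy female vocals",
        duration_sec=30,
        steps=50,
    )
    print(f"Generated song saved to {output}")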