import librosa
import numpy as np
import torch
from model.DiffSynthSampler import DiffSynthSampler
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
import mido
import torchaudio.transforms as transforms
from tqdm import tqdm


# def pitch_shift_audio(waveform, sample_rate, n_steps, device='cpu', n_fft=1024, hop_length=None):
#     # Convert numpy arrays to torch.Tensor
#     if isinstance(waveform, np.ndarray):
#         waveform = torch.from_numpy(waveform)
#
#     # Default hop_length to a quarter of n_fft to limit the memory cost of the STFT
#     if hop_length is None:
#         hop_length = n_fft // 4
#
#     # Move the waveform to the target device
#     waveform = waveform.to(device, dtype=torch.float32)
#
#     # Build the pitch-shift transform on the same device
#     pitch_shift = transforms.PitchShift(
#         sample_rate=sample_rate,
#         n_steps=n_steps,
#         n_fft=n_fft,
#         hop_length=hop_length
#     ).to(device)
#
#     # Apply the transform, move the result back to the CPU and return it as a numpy array
#     shifted_waveform = pitch_shift(waveform).detach().cpu().numpy()
#
#     return shifted_waveform


def pitch_shift_librosa(waveform, sample_rate, total_steps, step_size=4, n_fft=4096, hop_length=None):
    # librosa expects a numpy array as input
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.numpy()

    # Default hop_length to a quarter of n_fft
    if hop_length is None:
        hop_length = n_fft // 4

    # Shift the pitch incrementally, at most step_size semitones per pass
    # (handles both upward and downward shifts)
    current_waveform = waveform
    num_steps = int(np.ceil(abs(total_steps) / step_size))
    direction = 1 if total_steps >= 0 else -1
    for i in range(num_steps):
        step = min(step_size, abs(total_steps) - i * step_size)  # the last pass may be smaller than step_size
        current_waveform = librosa.effects.pitch_shift(
            current_waveform,
            sr=sample_rate,
            n_steps=direction * step,
            n_fft=n_fft,
            hop_length=hop_length
        )

    return current_waveform


class NoteEvent:
    def __init__(self, note, velocity, start_time, duration):
        self.note = note
        self.velocity = velocity
        self.start_time = start_time  # In ticks
        self.duration = duration  # In ticks

    def __str__(self):
        return f"Note {self.note}, velocity {self.velocity}, start_time {self.start_time}, duration {self.duration}"


class Track:
    def __init__(self, track, ticks_per_beat):
        self.tempo_events = self._parse_tempo_events(track)
        self.events = self._parse_note_events(track)
        self.ticks_per_beat = ticks_per_beat

    def _parse_tempo_events(self, track):
        tempo_events = []
        current_tempo = 500000  # Default MIDI tempo is 120 BPM, i.e. 500000 microseconds per beat
        for msg in track:
            if msg.type == 'set_tempo':
                current_tempo = msg.tempo
                tempo_events.append((msg.time, current_tempo))
            elif not msg.is_meta:
                tempo_events.append((msg.time, current_tempo))
        return tempo_events

    def _parse_note_events(self, track):
        events = []
        start_time = 0
        pending_notes = {}  # note number -> (onset time in ticks, note-on velocity)
        for msg in track:
            if not msg.is_meta:
                start_time += msg.time
            if msg.type == 'note_on' and msg.velocity > 0:
                pending_notes[msg.note] = (start_time, msg.velocity)
            elif (msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0)) and msg.note in pending_notes:
                note_on_time, velocity = pending_notes.pop(msg.note)
                duration = start_time - note_on_time
                events.append(NoteEvent(msg.note, velocity, note_on_time, duration))
        return events

    def synthesize_track(self, diffSynthSampler, sample_rate=16000):
        track_audio = np.zeros(int(self._get_total_time() * sample_rate), dtype=np.float32)
        current_tempo = 500000  # Start with the default MIDI tempo (120 BPM)
        duration_note_mapping = {}

        # Note: only the first 25 events are synthesized
        for event in tqdm(self.events[:25]):
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            start_time_sec = event.start_time * seconds_per_tick
            # Todo: set a minimum duration
            duration_sec = event.duration * seconds_per_tick
            duration_sec = max(duration_sec, 0.75)
            start_sample = int(start_time_sec * sample_rate)

            # Cache one normalized note per duration, then pitch-shift it for each event
            if str(duration_sec) not in duration_note_mapping:
                note_sample = diffSynthSampler(event.velocity, duration_sec)
                duration_note_mapping[str(duration_sec)] = note_sample / np.max(np.abs(note_sample))
            # note_audio = pyrb.pitch_shift(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            # note_audio = pitch_shift_audio(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            note_audio = pitch_shift_librosa(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)

            # Clip to the track length in case the (minimum-length) note runs past the end of the buffer
            end_sample = min(start_sample + len(note_audio), len(track_audio))
            track_audio[start_sample:end_sample] += note_audio[:end_sample - start_sample]

        return track_audio

    def _get_tempo_at(self, time_tick):
        current_tempo = 500000  # Start with the default MIDI tempo (120 BPM)
        elapsed_ticks = 0

        for tempo_change in self.tempo_events:
            if elapsed_ticks + tempo_change[0] > time_tick:
                return current_tempo
            elapsed_ticks += tempo_change[0]
            current_tempo = tempo_change[1]

        return current_tempo

    def _get_total_time(self):
        # Track length in seconds is the latest note-off time over all events
        total_time = 0
        current_tempo = 500000  # Start with the default MIDI tempo (120 BPM)

        for event in self.events:
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            total_time = max(total_time, (event.start_time + event.duration) * seconds_per_tick)

        return total_time


class DiffSynth:
    def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
                 model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):

        self.noise_prediction_model = noise_prediction_model
        self.VAE_quantizer = VAE_quantizer
        self.VAE_decoder = VAE_decoder
        self.device = device
        self.model_sample_rate = model_sample_rate
        self.timesteps = timesteps
        self.channels = channels
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        self.height = int(freq_resolution / VAE_scale)
        self.VAE_scale = VAE_scale
        self.squared = squared
        self.text_encoder = text_encoder
        self.CLAP_tokenizer = CLAP_tokenizer

        # instruments_configs maps an instrument name (str) to its sampling configuration:
        # sample_steps, sampler, noising_strength, latent_representation, attack, before_release
        self.instruments_configs = instruments_configs
        self.diffSynthSamplers = {}
        self._update_instruments()

    def _update_instruments(self):

        def diffSynthSamplerWrapper(instruments_config):

            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):
                condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)

                sample_steps = instruments_config['sample_steps']
                sampler = instruments_config['sampler']
                noising_strength = instruments_config['noising_strength']
                latent_representation = instruments_config['latent_representation']
                attack = instruments_config['attack']
                before_release = instruments_config['before_release']

                assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"
                width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)

                mySampler = DiffSynthSampler(self.timesteps, height=128, channels=4, noise_strategy="repeat", mute=True)
                mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))

                # mask = 1 means freeze; the attack and release regions keep the guide latent
                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
                latent_mask[:, :, :, -int(self.time_resolution * ((before_release + 1) / 4) / self.VAE_scale):] = 1.0

                latent_representations, _ = \
                    mySampler.inpaint_sample(model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
                                             noising_strength=noising_strength, condition=condition,
                                             guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                                             sampler=sampler, use_dynamic_mask=True, end_noise_level_ratio=0.0,
                                             mask_flexivity=1.0)

                latent_representations = latent_representations[-1]

                quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
                # Todo: remove hard-coding
                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(self.VAE_decoder,
                                                                                                            quantized_latent_representations,
                                                                                                            resolution=(512, width * self.VAE_scale),
                                                                                                            original_STFT_batch=None)

                return rec_signals[0]

            return diffSynthSampler

        for key in self.instruments_configs.keys():
            self.diffSynthSamplers[key] = diffSynthSamplerWrapper(self.instruments_configs[key])

    def get_music(self, mid, instrument_names, sample_rate=16000):
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
        assert len(tracks) <= len(instrument_names), f"len(tracks) = {len(tracks)} > {len(instrument_names)} = len(instrument_names)"
        track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]

        # Pad every track to the length of the longest one so that they can be mixed
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # Initialize the mix buffer with zeros

        for audio in track_audios:
            # Shorter tracks are zero-padded before being added to the mix
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio

        return full_audio
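

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library). The code
# below exercises pitch_shift_librosa on a synthetic sine wave; the frequency,
# duration, and shift amount are arbitrary example values. Building a full
# DiffSynth instance additionally requires the trained model components
# (noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder,
# CLAP_tokenizer), a device, and an instruments_configs dict supplied by the
# caller, e.g. (hypothetical names and file paths):
#     diff_synth = DiffSynth(instruments_configs, noise_prediction_model,
#                            VAE_quantizer, VAE_decoder, text_encoder,
#                            CLAP_tokenizer, device)
#     full_audio = diff_synth.get_music(mido.MidiFile("song.mid"),
#                                       ["instrument_A", "instrument_B"])
if __name__ == "__main__":
    sr = 16000
    t = np.linspace(0, 1.0, sr, endpoint=False)
    sine = 0.5 * np.sin(2 * np.pi * 220.0 * t).astype(np.float32)  # 1 s of A3
    shifted = pitch_shift_librosa(sine, sr, total_steps=7)  # up a perfect fifth
    print(f"input: {len(sine)} samples, shifted: {len(shifted)} samples")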