import librosa
import numpy as np
import torch
from model.DiffSynthSampler import DiffSynthSampler
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
import mido
import torchaudio.transforms as transforms
from tqdm import tqdm
# def pitch_shift_audio(waveform, sample_rate, n_steps, device='cpu', n_fft=1024, hop_length=None):
#     # Convert numpy arrays to torch.Tensor
#     if isinstance(waveform, np.ndarray):
#         waveform = torch.from_numpy(waveform)
#
#     # Default hop_length to a quarter of n_fft to reduce the memory cost of the STFT
#     if hop_length is None:
#         hop_length = n_fft // 4
#
#     # Move the waveform to the target device
#     waveform = waveform.to(device, dtype=torch.float32)
#
#     # Build the pitch-shift transform on the same device
#     pitch_shift = transforms.PitchShift(
#         sample_rate=sample_rate,
#         n_steps=n_steps,
#         n_fft=n_fft,
#         hop_length=hop_length
#     ).to(device)
#
#     # Apply the transform, move the result back to the CPU, and return a numpy array
#     shifted_waveform = pitch_shift(waveform).detach().cpu().numpy()
#
#     return shifted_waveform
def pitch_shift_librosa(waveform, sample_rate, total_steps, step_size=4, n_fft=4096, hop_length=None):
    # librosa expects a numpy array as input
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.numpy()

    # Default hop_length to a quarter of n_fft if not provided
    if hop_length is None:
        hop_length = n_fft // 4

    # Shift the pitch incrementally, at most step_size semitones per pass, so each
    # phase-vocoder pass only covers a moderate interval. The sign is carried
    # separately so that negative total_steps (downward shifts) also work.
    current_waveform = waveform
    sign = 1 if total_steps >= 0 else -1
    remaining = abs(total_steps)
    num_steps = int(np.ceil(remaining / step_size))
    for i in range(num_steps):
        step = min(step_size, remaining - i * step_size)  # the last pass must not overshoot total_steps
        current_waveform = librosa.effects.pitch_shift(
            current_waveform, sr=sample_rate, n_steps=sign * step,
            n_fft=n_fft, hop_length=hop_length
        )
    return current_waveform
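# A minimal usage sketch: shift a one-second 440 Hz test tone up a fifth
# (7 semitones). The tone and sample rate here are illustrative only.
# >>> sr = 16000
# >>> t = np.linspace(0, 1.0, sr, endpoint=False)
# >>> tone = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
# >>> shifted = pitch_shift_librosa(tone, sr, total_steps=7)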
class NoteEvent:
    def __init__(self, note, velocity, start_time, duration):
        self.note = note
        self.velocity = velocity
        self.start_time = start_time  # In ticks
        self.duration = duration  # In ticks

    def __str__(self):
        return f"Note {self.note}, velocity {self.velocity}, start_time {self.start_time}, duration {self.duration}"
class Track:
    def __init__(self, track, ticks_per_beat):
        self.tempo_events = self._parse_tempo_events(track)
        self.events = self._parse_note_events(track)
        self.ticks_per_beat = ticks_per_beat

    def _parse_tempo_events(self, track):
        # Record (delta_ticks, tempo) pairs so _get_tempo_at can walk the timeline
        tempo_events = []
        current_tempo = 500000  # Default MIDI tempo is 120 BPM, i.e. 500000 microseconds per beat
        for msg in track:
            if msg.type == 'set_tempo':
                current_tempo = msg.tempo  # track the latest tempo change
                tempo_events.append((msg.time, current_tempo))
            elif not msg.is_meta:
                tempo_events.append((msg.time, current_tempo))
        return tempo_events
    def _parse_note_events(self, track):
        events = []
        start_time = 0
        note_on_info = {}  # note number -> (onset tick, note-on velocity)
        for msg in track:
            if not msg.is_meta:
                start_time += msg.time
                if msg.type == 'note_on' and msg.velocity > 0:
                    note_on_info[msg.note] = (start_time, msg.velocity)
                elif msg.type in ('note_off', 'note_on') and msg.note in note_on_info:
                    # A note_on with velocity 0 is the conventional note-off
                    note_on_time, velocity = note_on_info.pop(msg.note)
                    events.append(NoteEvent(msg.note, velocity, note_on_time, start_time - note_on_time))
        return events
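    # A minimal parsing sketch (illustrative values): a note_on / velocity-0 pair
    # 480 ticks apart parses into a single NoteEvent lasting 480 ticks.
    # >>> msgs = [mido.Message('note_on', note=60, velocity=100, time=0),
    # ...         mido.Message('note_on', note=60, velocity=0, time=480)]
    # >>> parsed = Track(mido.MidiTrack(msgs), ticks_per_beat=480)
    # >>> print(parsed.events[0])
    # Note 60, velocity 100, start_time 0, duration 480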
    def synthesize_track(self, diffSynthSampler, sample_rate=16000):
        track_audio = np.zeros(int(self._get_total_time() * sample_rate), dtype=np.float32)
        duration_note_mapping = {}  # cache one synthesized, normalized note per distinct duration
        for event in tqdm(self.events[:25]):  # only the first 25 note events are rendered (presumably a demo-time limit)
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            start_time_sec = event.start_time * seconds_per_tick
            # Enforce a minimum note duration of 0.75 s
            duration_sec = max(event.duration * seconds_per_tick, 0.75)
            start_sample = int(start_time_sec * sample_rate)
            if str(duration_sec) not in duration_note_mapping:
                note_sample = diffSynthSampler(event.velocity, duration_sec)
                duration_note_mapping[str(duration_sec)] = note_sample / np.max(np.abs(note_sample))
            # The cached note is treated as MIDI note 52 and pitch-shifted to the target note
            # note_audio = pyrb.pitch_shift(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            # note_audio = pitch_shift_audio(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            note_audio = pitch_shift_librosa(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            # Clip the note if it extends past the end of the pre-allocated buffer
            end_sample = min(start_sample + len(note_audio), len(track_audio))
            track_audio[start_sample:end_sample] += note_audio[:end_sample - start_sample]
        return track_audio
    def _get_tempo_at(self, time_tick):
        # Walk the (delta_ticks, tempo) list until the requested tick is passed
        current_tempo = 500000  # Start with the default MIDI tempo (120 BPM)
        elapsed_ticks = 0
        for tempo_change in self.tempo_events:
            if elapsed_ticks + tempo_change[0] > time_tick:
                return current_tempo
            elapsed_ticks += tempo_change[0]
            current_tempo = tempo_change[1]
        return current_tempo
    def _get_total_time(self):
        # Return the time (in seconds) at which the last note ends, so the
        # output buffer can be sized up front
        total_time = 0
        for event in self.events:
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            total_time = max(total_time, (event.start_time + event.duration) * seconds_per_tick)
        return total_time
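# A worked example of the tick-to-seconds conversion used above: at the default
# tempo (500000 microseconds per beat) with 480 ticks per beat, one tick lasts
# 500000 / 480 / 1e6 seconds:
# >>> mido.tick2second(1, 480, 500000)
# 0.0010416666666666667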
class DiffSynth:
    def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
                 model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):
        self.noise_prediction_model = noise_prediction_model
        self.VAE_quantizer = VAE_quantizer
        self.VAE_decoder = VAE_decoder
        self.device = device
        self.model_sample_rate = model_sample_rate
        self.timesteps = timesteps
        self.channels = channels
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        self.height = int(freq_resolution / VAE_scale)
        self.VAE_scale = VAE_scale
        self.squared = squared
        self.text_encoder = text_encoder
        self.CLAP_tokenizer = CLAP_tokenizer

        # instruments_configs maps an instrument name (str) to a config dict with the keys
        # 'sample_steps', 'sampler', 'noising_strength', 'latent_representation', 'attack'
        # and 'before_release' (see _update_instruments below)
        self.instruments_configs = instruments_configs
        self.diffSynthSamplers = {}
        self._update_instruments()
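    # A sketch of one instruments_configs entry, inferred from the keys read in
    # _update_instruments; the values below are illustrative placeholders only.
    # >>> instruments_configs = {
    # ...     'piano': {
    # ...         'sample_steps': 20,
    # ...         'sampler': 'ddim',                       # sampler identifier passed to inpaint_sample
    # ...         'noising_strength': 0.7,
    # ...         'latent_representation': guide_latent,   # guide latent tensor, (1, channels, height, width)
    # ...         'attack': 0.5,                           # seconds frozen at the note onset
    # ...         'before_release': 0.5,                   # seconds frozen before the release tail
    # ...     },
    # ... }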
    def _update_instruments(self):

        def diffSynthSamplerWrapper(instruments_config):

            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):
                # Unconditional (empty-prompt) CLAP text embedding used as the condition
                condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
                sample_steps = instruments_config['sample_steps']
                sampler = instruments_config['sampler']
                noising_strength = instruments_config['noising_strength']
                latent_representation = instruments_config['latent_representation']
                attack = instruments_config['attack']
                before_release = instruments_config['before_release']

                assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"

                # The latent time axis appears to cover 4 s of audio at full resolution
                # (hence the /4); one extra second is added for the release tail
                width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)

                mySampler = DiffSynthSampler(self.timesteps, height=self.height, channels=self.channels, noise_strategy="repeat", mute=True)
                mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))

                # mask = 1 freezes a latent frame to the guide: the attack and the
                # (before_release + 1 s) tail are kept, the sustain in between is regenerated
                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
                latent_mask[:, :, :, -int(self.time_resolution * ((before_release + 1) / 4) / self.VAE_scale):] = 1.0

                latent_representations, _ = \
                    mySampler.inpaint_sample(model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
                                             noising_strength=noising_strength, condition=condition,
                                             guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                                             sampler=sampler,
                                             use_dynamic_mask=True, end_noise_level_ratio=0.0,
                                             mask_flexivity=1.0)

                latent_representations = latent_representations[-1]  # keep only the final denoising step

                quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(self.VAE_decoder,
                                                                                                            quantized_latent_representations,
                                                                                                            resolution=(self.freq_resolution,
                                                                                                                        width * self.VAE_scale),
                                                                                                            original_STFT_batch=None,
                                                                                                            )
                return rec_signals[0]

            return diffSynthSampler

        for key, config in self.instruments_configs.items():
            self.diffSynthSamplers[key] = diffSynthSamplerWrapper(config)
    def get_music(self, mid, instrument_names, sample_rate=16000):
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
        assert len(tracks) <= len(instrument_names), f"len(tracks) = {len(tracks)} > {len(instrument_names)} = len(instrument_names)"
        track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]

        # Pad every track to the length of the longest one so they can be mixed
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # start from silence
        for audio in track_audios:
            # Shorter tracks are zero-padded at the end
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio  # mix the track in
        return full_audio
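# A minimal end-to-end usage sketch, assuming `diff_synth` is an already-built
# DiffSynth instance whose instruments_configs contains the (hypothetical)
# keys 'piano' and 'bass', and 'song.mid' is a two-track MIDI file:
# >>> mid = mido.MidiFile('song.mid')
# >>> full_audio = diff_synth.get_music(mid, ['piano', 'bass'], sample_rate=16000)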