import librosa
import numpy as np
import torch
from model.DiffSynthSampler import DiffSynthSampler
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
import mido
import torchaudio.transforms as transforms
from tqdm import tqdm
# def pitch_shift_audio(waveform, sample_rate, n_steps, device='cpu', n_fft=1024, hop_length=None):
#     # Convert numpy input to a torch.Tensor
#     if isinstance(waveform, np.ndarray):
#         waveform = torch.from_numpy(waveform)
#
#     # Default hop_length to a quarter of n_fft (a reasonable default) to
#     # reduce the memory overhead of the STFT
#     if hop_length is None:
#         hop_length = n_fft // 4
#
#     # Move the waveform to the target device
#     waveform = waveform.to(device, dtype=torch.float32)
#
#     # Build the pitch_shift transform on the same device
#     pitch_shift = transforms.PitchShift(
#         sample_rate=sample_rate,
#         n_steps=n_steps,
#         n_fft=n_fft,
#         hop_length=hop_length
#     ).to(device)
#
#     # Run the transform, move the result back to CPU, and return a numpy array
#     shifted_waveform = pitch_shift(waveform).detach().cpu().numpy()
#
#     return shifted_waveform
def pitch_shift_librosa(waveform, sample_rate, total_steps, step_size=4, n_fft=4096, hop_length=None):
    # librosa expects a numpy array as input
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.numpy()

    # Default hop_length to a quarter of n_fft
    if hop_length is None:
        hop_length = n_fft // 4

    # Shift iteratively, at most step_size semitones per pass; the sign of
    # total_steps is respected, so downward shifts (negative steps) work too
    current_waveform = waveform
    remaining = total_steps
    while remaining != 0:
        step = int(np.sign(remaining)) * min(step_size, abs(remaining))
        current_waveform = librosa.effects.pitch_shift(
            current_waveform, sr=sample_rate, n_steps=step,
            n_fft=n_fft, hop_length=hop_length
        )
        remaining -= step
    return current_waveform
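

# A minimal usage sketch for pitch_shift_librosa, kept as a comment so the
# module stays import-only (the sine-wave input is illustrative, not part of
# the pipeline):
#
#   sr = 16000
#   t = np.linspace(0, 1.0, sr, endpoint=False)
#   tone = np.sin(2 * np.pi * 220.0 * t).astype(np.float32)
#   shifted = pitch_shift_librosa(tone, sr, total_steps=7)   # up a perfect fifth
#   lowered = pitch_shift_librosa(tone, sr, total_steps=-3)  # down three semitones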
class NoteEvent:
def __init__(self, note, velocity, start_time, duration):
self.note = note
self.velocity = velocity
self.start_time = start_time # In ticks
self.duration = duration # In ticks
def __str__(self):
return f"Note {self.note}, velocity {self.velocity}, start_time {self.start_time}, duration {self.duration}"
class Track:
def __init__(self, track, ticks_per_beat):
self.tempo_events = self._parse_tempo_events(track)
self.events = self._parse_note_events(track)
self.ticks_per_beat = ticks_per_beat
    def _parse_tempo_events(self, track):
        tempo_events = []
        current_tempo = 500000  # Default MIDI tempo is 120 BPM, i.e. 500000 microseconds per beat
        for msg in track:
            if msg.type == 'set_tempo':
                current_tempo = msg.tempo  # Remember the new tempo for subsequent events
                tempo_events.append((msg.time, current_tempo))
            elif not msg.is_meta:
                tempo_events.append((msg.time, current_tempo))
        return tempo_events
    def _parse_note_events(self, track):
        events = []
        start_time = 0
        pending = {}  # note number -> (onset tick, note-on velocity)
        for msg in track:
            if not msg.is_meta:
                start_time += msg.time
                if msg.type == 'note_on' and msg.velocity > 0:
                    pending[msg.note] = (start_time, msg.velocity)
                elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                    # A note_on with velocity 0 is the conventional alternative encoding of note_off
                    if msg.note in pending:
                        note_on_time, velocity = pending.pop(msg.note)
                        duration = start_time - note_on_time
                        events.append(NoteEvent(msg.note, velocity, note_on_time, duration))
        return events
def synthesize_track(self, diffSynthSampler, sample_rate=16000):
track_audio = np.zeros(int(self._get_total_time() * sample_rate), dtype=np.float32)
current_tempo = 500000 # Start with default MIDI tempo 120 BPM
duration_note_mapping = {}
        for event in tqdm(self.events[:25]):  # NOTE: only the first 25 events are rendered
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            start_time_sec = event.start_time * seconds_per_tick
            # Enforce a minimum note duration so very short notes still render
            duration_sec = max(event.duration * seconds_per_tick, 0.75)
            start_sample = int(start_time_sec * sample_rate)
            # Render one peak-normalized note per distinct duration, then reuse it
            if not (str(duration_sec) in duration_note_mapping):
                note_sample = diffSynthSampler(event.velocity, duration_sec)
                duration_note_mapping[str(duration_sec)] = note_sample / np.max(np.abs(note_sample))
            # Pitch-shift the cached note to the event's pitch (52 is the reference note)
            # note_audio = pyrb.pitch_shift(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            # note_audio = pitch_shift_audio(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            note_audio = pitch_shift_librosa(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            # Clip to the track buffer so long or late notes cannot overflow it
            end_sample = min(start_sample + len(note_audio), len(track_audio))
            track_audio[start_sample:end_sample] += note_audio[:end_sample - start_sample]
        return track_audio
def _get_tempo_at(self, time_tick):
current_tempo = 500000 # Start with default MIDI tempo 120 BPM
elapsed_ticks = 0
for tempo_change in self.tempo_events:
if elapsed_ticks + tempo_change[0] > time_tick:
return current_tempo
elapsed_ticks += tempo_change[0]
current_tempo = tempo_change[1]
return current_tempo
    def _get_total_time(self):
        # The track length is the latest note end time (onset + duration),
        # using the same 0.75 s minimum duration applied during synthesis
        total_time = 0
        for event in self.events:
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            start_sec = event.start_time * seconds_per_tick
            duration_sec = max(event.duration * seconds_per_tick, 0.75)
            total_time = max(total_time, start_sec + duration_sec)
        return total_time
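

# A minimal sketch of parsing a MIDI track, kept as a comment; "example.mid"
# is a hypothetical path:
#
#   mid = mido.MidiFile("example.mid")
#   track = Track(mid.tracks[0], mid.ticks_per_beat)
#   for ev in track.events[:5]:
#       print(ev)  # uses NoteEvent.__str__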
class DiffSynth:
def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):
self.noise_prediction_model = noise_prediction_model
self.VAE_quantizer = VAE_quantizer
self.VAE_decoder = VAE_decoder
self.device = device
self.model_sample_rate = model_sample_rate
self.timesteps = timesteps
self.channels = channels
self.freq_resolution = freq_resolution
self.time_resolution = time_resolution
self.height = int(freq_resolution/VAE_scale)
self.VAE_scale = VAE_scale
self.squared = squared
self.text_encoder = text_encoder
self.CLAP_tokenizer = CLAP_tokenizer
        # instruments_configs maps an instrument name (str) to its sampling
        # config; _update_instruments reads the keys: sample_steps, sampler,
        # noising_strength, latent_representation, attack, before_release
self.instruments_configs = instruments_configs
self.diffSynthSamplers = {}
self._update_instruments()
def _update_instruments(self):
def diffSynthSamplerWrapper(instruments_config):
            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):
                # NOTE: velocity is currently unused; notes are peak-normalized by the
                # caller instead. The text condition is an empty CLAP prompt.
                condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
sample_steps = instruments_config['sample_steps']
sampler = instruments_config['sampler']
noising_strength = instruments_config['noising_strength']
latent_representation = instruments_config['latent_representation']
attack = instruments_config['attack']
before_release = instruments_config['before_release']
assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"
width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)
mySampler = DiffSynthSampler(self.timesteps, height=128, channels=4, noise_strategy="repeat", mute=True)
mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))
                # mask = 1 means "freeze": keep the attack segment and the pre-release
                # tail of the guide latent fixed, and regenerate only the middle
                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
                latent_mask[:, :, :, -int(self.time_resolution * ((before_release + 1) / 4) / self.VAE_scale):] = 1.0
latent_representations, _ = \
mySampler.inpaint_sample(model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
noising_strength=noising_strength, condition=condition,
guide_img=latent_representation, mask=latent_mask, return_tensor=True,
sampler=sampler,
use_dynamic_mask=True, end_noise_level_ratio=0.0,
mask_flexivity=1.0)
latent_representations = latent_representations[-1]
quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
# Todo: remove hard-coding
flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(self.VAE_decoder,
quantized_latent_representations,
resolution=(
512,
width * self.VAE_scale),
original_STFT_batch=None,
)
return rec_signals[0]
return diffSynthSampler
for key in self.instruments_configs.keys():
self.diffSynthSamplers[key] = diffSynthSamplerWrapper(self.instruments_configs[key])
def get_music(self, mid, instrument_names, sample_rate=16000):
tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
assert len(tracks) <= len(instrument_names), f"len(tracks) = {len(tracks)} > {len(instrument_names)} = len(instrument_names)"
track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]
        # Pad every track to the length of the longest one so they can be mixed
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # Initialize the mix buffer with zeros
        for audio in track_audios:
            # Shorter tracks are zero-padded at the end
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio  # Mix this track into the output
        return full_audio
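

# A minimal end-to-end usage sketch, kept as a comment; the model objects,
# instrument configs, and "song.mid" path are hypothetical placeholders:
#
#   diff_synth = DiffSynth(instruments_configs, noise_prediction_model,
#                          VAE_quantizer, VAE_decoder, text_encoder,
#                          CLAP_tokenizer, device="cuda")
#   mid = mido.MidiFile("song.mid")
#   names = ["synth_lead"] * len(mid.tracks)  # one instrument name per track
#   music = diff_synth.get_music(mid, names, sample_rate=16000)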