import librosa
import numpy as np
import torch
from model.DiffSynthSampler import DiffSynthSampler
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
import mido
import torchaudio.transforms as transforms
from tqdm import tqdm
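# The commented-out function below is an unused torchaudio-based alternative to
# pitch_shift_librosa, kept for reference.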
# def pitch_shift_audio(waveform, sample_rate, n_steps, device='cpu', n_fft=1024, hop_length=None):
# # Convert numpy array inputs to torch.Tensor
# if isinstance(waveform, np.ndarray):
# waveform = torch.from_numpy(waveform)
#
# # Default hop_length to a quarter of n_fft to limit the memory cost of the STFT
# if hop_length is None:
# hop_length = n_fft // 4
#
# # Move the waveform to the requested device
# waveform = waveform.to(device, dtype=torch.float32)
#
# # Build the PitchShift transform on the same device
# pitch_shift = transforms.PitchShift(
# sample_rate=sample_rate,
# n_steps=n_steps,
# n_fft=n_fft,
# hop_length=hop_length
# ).to(device)
#
# # Apply the transform, move the result back to the CPU, and convert to a numpy array
# shifted_waveform = pitch_shift(waveform).detach().cpu().numpy()
#
# return shifted_waveform
def pitch_shift_librosa(waveform, sample_rate, total_steps, step_size=4, n_fft=4096, hop_length=None):
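    """Pitch-shift a waveform by total_steps semitones using librosa.

    The shift is applied in passes of at most step_size semitones, which helps limit
    the artifacts of a single large shift.
    """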
    # librosa expects a numpy array
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.numpy()
    # Default hop_length to a quarter of n_fft
    if hop_length is None:
        hop_length = n_fft // 4
    # Shift incrementally, step_size semitones at a time; handle both upward
    # (positive) and downward (negative) shifts.
    current_waveform = waveform
    num_steps = int(np.ceil(abs(total_steps) / step_size))
    direction = 1 if total_steps >= 0 else -1
    for i in range(num_steps):
        step = direction * min(step_size, abs(total_steps) - i * step_size)  # the final pass must not overshoot total_steps
        current_waveform = librosa.effects.pitch_shift(
            current_waveform, sr=sample_rate, n_steps=step,
            n_fft=n_fft, hop_length=hop_length
        )
    return current_waveform
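# Example (hypothetical values): shift a 16 kHz signal up a perfect fifth (7 semitones)
# shifted = pitch_shift_librosa(audio, sample_rate=16000, total_steps=7)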
class NoteEvent:
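    """A single MIDI note: pitch, velocity, start time and duration (times in ticks)."""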
def __init__(self, note, velocity, start_time, duration):
self.note = note
self.velocity = velocity
self.start_time = start_time # In ticks
self.duration = duration # In ticks
def __str__(self):
return f"Note {self.note}, velocity {self.velocity}, start_time {self.start_time}, duration {self.duration}"
class Track:
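    """A parsed MIDI track: its note events plus the tempo information needed to
    convert ticks to seconds."""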
def __init__(self, track, ticks_per_beat):
self.tempo_events = self._parse_tempo_events(track)
self.events = self._parse_note_events(track)
self.ticks_per_beat = ticks_per_beat
def _parse_tempo_events(self, track):
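        # Returns (delta_ticks, tempo) pairs, one per tempo change or channel message,
        # which _get_tempo_at walks to find the tempo in effect at a given tick.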
        tempo_events = []
        current_tempo = 500000  # Default MIDI tempo is 120 BPM, i.e. 500000 microseconds per beat
        for msg in track:
            if msg.type == 'set_tempo':
                current_tempo = msg.tempo
                tempo_events.append((msg.time, current_tempo))
            elif not msg.is_meta:
                tempo_events.append((msg.time, current_tempo))
        return tempo_events
def _parse_note_events(self, track):
        events = []
        start_time = 0
        active_notes = {}  # note -> (start_tick, velocity) for notes currently sounding
        for msg in track:
            if not msg.is_meta:
                start_time += msg.time
                if msg.type == 'note_on' and msg.velocity > 0:
                    active_notes[msg.note] = (start_time, msg.velocity)
                elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                    if msg.note in active_notes:
                        note_on_time, velocity = active_notes.pop(msg.note)
                        duration = start_time - note_on_time
                        events.append(NoteEvent(msg.note, velocity, note_on_time, duration))
        return events
def synthesize_track(self, diffSynthSampler, sample_rate=16000):
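        """Render each note event with diffSynthSampler, pitch-shift it to the note's
        pitch, and mix everything into one buffer at sample_rate."""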
track_audio = np.zeros(int(self._get_total_time() * sample_rate), dtype=np.float32)
current_tempo = 500000 # Start with default MIDI tempo 120 BPM
duration_note_mapping = {}
        for event in tqdm(self.events[:25]):  # note: only the first 25 events are synthesized
current_tempo = self._get_tempo_at(event.start_time)
seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
start_time_sec = event.start_time * seconds_per_tick
            # Enforce a minimum duration so very short notes still produce usable audio
            duration_sec = event.duration * seconds_per_tick
            duration_sec = max(duration_sec, 0.75)
start_sample = int(start_time_sec * sample_rate)
            if str(duration_sec) not in duration_note_mapping:
                note_sample = diffSynthSampler(event.velocity, duration_sec)
                # Cache the peak-normalized note so each distinct duration is only synthesized once
                duration_note_mapping[str(duration_sec)] = note_sample / np.max(np.abs(note_sample))
# note_audio = pyrb.pitch_shift(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
# note_audio = pitch_shift_audio(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            note_audio = pitch_shift_librosa(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)  # shift from the sample's reference pitch (MIDI note 52)
            end_sample = min(start_sample + len(note_audio), len(track_audio))
            # Clip notes that would run past the end of the allocated buffer
            track_audio[start_sample:end_sample] += note_audio[:end_sample - start_sample]
return track_audio
def _get_tempo_at(self, time_tick):
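        # Accumulate delta ticks over the parsed tempo events and return the tempo
        # in effect at time_tick.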
current_tempo = 500000 # Start with default MIDI tempo 120 BPM
elapsed_ticks = 0
for tempo_change in self.tempo_events:
if elapsed_ticks + tempo_change[0] > time_tick:
return current_tempo
elapsed_ticks += tempo_change[0]
current_tempo = tempo_change[1]
return current_tempo
def _get_total_time(self):
        # Track length in seconds: the latest note end time, plus headroom for the
        # minimum note duration enforced in synthesize_track.
        total_time = 0
        for event in self.events:
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            end_time = (event.start_time + event.duration) * seconds_per_tick
            total_time = max(total_time, end_time)
        return total_time + 1.0
class DiffSynth:
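    """Turns MIDI input into audio: each instrument has a diffusion-sampler closure that
    generates one reference note per duration, which is then pitch-shifted and mixed."""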
def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):
self.noise_prediction_model = noise_prediction_model
self.VAE_quantizer = VAE_quantizer
self.VAE_decoder = VAE_decoder
self.device = device
self.model_sample_rate = model_sample_rate
self.timesteps = timesteps
self.channels = channels
self.freq_resolution = freq_resolution
self.time_resolution = time_resolution
self.height = int(freq_resolution/VAE_scale)
self.VAE_scale = VAE_scale
self.squared = squared
self.text_encoder = text_encoder
self.CLAP_tokenizer = CLAP_tokenizer
        # instruments_configs maps instrument name (str) to a config dict; the keys used here
        # include 'sample_steps', 'sampler', 'noising_strength', 'latent_representation',
        # 'attack' and 'before_release'.
self.instruments_configs = instruments_configs
self.diffSynthSamplers = {}
self._update_instruments()
def _update_instruments(self):
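        # Build one sampler closure per instrument from its config; get_music looks
        # these up by instrument name.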
def diffSynthSamplerWrapper(instruments_config):
def diffSynthSampler(velocity, duration_sec, sample_rate=16000):
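                # Generate a single note of the requested duration. velocity is accepted
                # but not currently used; the text condition is the embedding of an
                # empty prompt.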
condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
sample_steps = instruments_config['sample_steps']
sampler = instruments_config['sampler']
noising_strength = instruments_config['noising_strength']
latent_representation = instruments_config['latent_representation']
attack = instruments_config['attack']
before_release = instruments_config['before_release']
assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"
width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)
                mySampler = DiffSynthSampler(self.timesteps, height=self.height, channels=self.channels, noise_strategy="repeat", mute=True)
mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))
                # mask = 1 marks latent regions to keep (freeze) from the guide sample:
                # the attack and the tail before release
latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
latent_mask[:, :, :, -int(self.time_resolution * ((before_release+1) / 4) / self.VAE_scale):] = 1.0
                latent_representations, _ = mySampler.inpaint_sample(
                    model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
                    noising_strength=noising_strength, condition=condition,
                    guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                    sampler=sampler, use_dynamic_mask=True, end_noise_level_ratio=0.0,
                    mask_flexivity=1.0)
                latent_representations = latent_representations[-1]  # keep only the final denoising step
quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
# Todo: remove hard-coding
                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(
                    self.VAE_decoder, quantized_latent_representations,
                    resolution=(512, width * self.VAE_scale),
                    original_STFT_batch=None,
                )
return rec_signals[0]
return diffSynthSampler
for key in self.instruments_configs.keys():
self.diffSynthSamplers[key] = diffSynthSamplerWrapper(self.instruments_configs[key])
def get_music(self, mid, instrument_names, sample_rate=16000):
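        """Synthesize every MIDI track with the sampler registered under the matching
        instrument name, then mix the tracks into a single mono signal."""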
tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
assert len(tracks) <= len(instrument_names), f"len(tracks) = {len(tracks)} > {len(instrument_names)} = len(instrument_names)"
track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]
        # Pad every track to the length of the longest one so they can be mixed together
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # start from silence
        for audio in track_audios:
            # Zero-pad shorter tracks up to max_length
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio  # mix this track into the output
        return full_audio
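# Example usage (hypothetical setup; models, tokenizer and configs must be built elsewhere):
# mid = mido.MidiFile("song.mid")
# diff_synth = DiffSynth(instruments_configs, noise_prediction_model, VAE_quantizer,
#                        VAE_decoder, text_encoder, CLAP_tokenizer, device="cuda")
# audio = diff_synth.get_music(mid, instrument_names=["piano", "violin"])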