import librosa
import numpy as np
import torch

from model.DiffSynthSampler import DiffSynthSampler
from webUI.natural_language_guided.utils import encodeBatch2GradioOutput_STFT
import mido
import torchaudio.transforms as transforms
from tqdm import tqdm


# def pitch_shift_audio(waveform, sample_rate, n_steps, device='cpu', n_fft=1024, hop_length=None):
#     # Convert numpy arrays to torch.Tensor
#     if isinstance(waveform, np.ndarray):
#         waveform = torch.from_numpy(waveform)
#
#     # Default hop_length to a quarter of n_fft (a reasonable default) to reduce the memory
#     # overhead of the STFT
#     if hop_length is None:
#         hop_length = n_fft // 4
#
#     # Move the waveform to the target device
#     waveform = waveform.to(device, dtype=torch.float32)
#
#     # Build the pitch_shift transform on the same device
#     pitch_shift = transforms.PitchShift(
#         sample_rate=sample_rate,
#         n_steps=n_steps,
#         n_fft=n_fft,
#         hop_length=hop_length
#     ).to(device)
#
#     # Apply the transform, move the result back to the CPU and convert to numpy
#     shifted_waveform = pitch_shift(waveform).detach().cpu().numpy()
#
#     return shifted_waveform


def pitch_shift_librosa(waveform, sample_rate, total_steps, step_size=4, n_fft=4096, hop_length=None):
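    """Pitch-shift a waveform by `total_steps` semitones using librosa.

    The shift is applied in increments of at most `step_size` semitones, which
    tends to produce fewer artifacts than a single large shift.
    """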
    # librosa operates on numpy arrays
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.detach().cpu().numpy()

    # Default hop_length to a quarter of n_fft
    if hop_length is None:
        hop_length = n_fft // 4

    # Shift incrementally, at most step_size semitones per pass, to reduce artifacts.
    # The sign of total_steps is handled explicitly so downward shifts work as well.
    current_waveform = waveform
    remaining = float(total_steps)
    while abs(remaining) > 1e-6:
        step = np.sign(remaining) * min(step_size, abs(remaining))
        current_waveform = librosa.effects.pitch_shift(
            current_waveform, sr=sample_rate, n_steps=step,
            n_fft=n_fft, hop_length=hop_length
        )
        remaining -= step

    return current_waveform




class NoteEvent:
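    """A single MIDI note: pitch, velocity, onset time and duration (times in ticks)."""
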
    def __init__(self, note, velocity, start_time, duration):
        self.note = note
        self.velocity = velocity
        self.start_time = start_time  # In ticks
        self.duration = duration      # In ticks

    def __str__(self):
        return f"Note {self.note}, velocity {self.velocity}, start_time {self.start_time}, duration {self.duration}"


class Track:
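    """A parsed MIDI track: a tempo map plus a list of NoteEvents, used for synthesis."""
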
    def __init__(self, track, ticks_per_beat):
        self.tempo_events = self._parse_tempo_events(track)
        self.events = self._parse_note_events(track)
        self.ticks_per_beat = ticks_per_beat

    def _parse_tempo_events(self, track):
        # Build a tempo map as (delta ticks, tempo in effect from that point on).
        # Every message contributes its delta time so tick positions stay aligned
        # with the note events parsed below.
        tempo_events = []
        current_tempo = 500000  # Default MIDI tempo: 120 BPM = 500000 microseconds per beat
        for msg in track:
            if msg.type == 'set_tempo':
                current_tempo = msg.tempo
            tempo_events.append((msg.time, current_tempo))
        return tempo_events

    def _parse_note_events(self, track):
        events = []
        start_time = 0
        active_notes = {}  # note number -> (onset time in ticks, note-on velocity)
        for msg in track:
            start_time += msg.time
            if msg.type == 'note_on' and msg.velocity > 0:
                active_notes[msg.note] = (start_time, msg.velocity)
            elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                # A note_on with velocity 0 is the conventional note-off
                if msg.note in active_notes:
                    note_on_time, velocity = active_notes.pop(msg.note)
                    events.append(NoteEvent(msg.note, velocity, note_on_time, start_time - note_on_time))
        return events

    def synthesize_track(self, diffSynthSampler, sample_rate=16000):
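        """Render this track into a mono float32 buffer at `sample_rate` Hz.

        One sample is generated per distinct note duration and cached in
        `duration_note_mapping`; each cached sample is treated as MIDI note 52 and
        pitch-shifted by (event.note - 52) semitones before being mixed in at the
        note's onset. Note that the loop below only renders the first 25 events.
        """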
        track_audio = np.zeros(int(self._get_total_time() * sample_rate), dtype=np.float32)
        current_tempo = 500000  # Start with default MIDI tempo 120 BPM
        duration_note_mapping = {}

        for event in tqdm(self.events[:25]):
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            start_time_sec = event.start_time * seconds_per_tick
            duration_sec = event.duration * seconds_per_tick
            # Enforce a minimum duration so very short notes still produce a usable sample
            duration_sec = max(duration_sec, 0.75)
            start_sample = int(start_time_sec * sample_rate)
            if not (str(duration_sec) in duration_note_mapping):
                note_sample = diffSynthSampler(event.velocity, duration_sec)
                duration_note_mapping[str(duration_sec)] = note_sample / np.max(np.abs(note_sample))

            # note_audio = pyrb.pitch_shift(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            # note_audio = pitch_shift_audio(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            note_audio = pitch_shift_librosa(duration_note_mapping[str(duration_sec)], sample_rate, event.note - 52)
            # _get_total_time() only estimates the buffer size, so clip notes that run past the end
            end_sample = min(start_sample + len(note_audio), len(track_audio))
            if end_sample > start_sample:
                track_audio[start_sample:end_sample] += note_audio[:end_sample - start_sample]

        return track_audio

    def _get_tempo_at(self, time_tick):
        current_tempo = 500000  # Start with default MIDI tempo 120 BPM
        elapsed_ticks = 0

        for tempo_change in self.tempo_events:
            if elapsed_ticks + tempo_change[0] > time_tick:
                return current_tempo
            elapsed_ticks += tempo_change[0]
            current_tempo = tempo_change[1]

        return current_tempo

    def _get_total_time(self):
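        """Approximate the track length in seconds by summing every note's duration."""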
        total_time = 0
        current_tempo = 500000  # Start with default MIDI tempo 120 BPM

        for event in self.events:
            current_tempo = self._get_tempo_at(event.start_time)
            seconds_per_tick = mido.tick2second(1, self.ticks_per_beat, current_tempo)
            total_time += event.duration * seconds_per_tick

        return total_time


class DiffSynth:
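    """Renders MIDI files to audio by sampling individual notes from a latent diffusion model.

    `_update_instruments` turns each entry of `instruments_configs` into a sampler closure;
    `get_music` then synthesizes every track of a mido MIDI file with its assigned
    instrument and mixes the per-track signals.
    """
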
    def __init__(self, instruments_configs, noise_prediction_model, VAE_quantizer, VAE_decoder, text_encoder, CLAP_tokenizer, device,
                 model_sample_rate=16000, timesteps=1000, channels=4, freq_resolution=512, time_resolution=256, VAE_scale=4, squared=False):

        self.noise_prediction_model = noise_prediction_model
        self.VAE_quantizer = VAE_quantizer
        self.VAE_decoder = VAE_decoder
        self.device = device
        self.model_sample_rate = model_sample_rate
        self.timesteps = timesteps
        self.channels = channels
        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        self.height = int(freq_resolution/VAE_scale)
        self.VAE_scale = VAE_scale
        self.squared = squared
        self.text_encoder = text_encoder
        self.CLAP_tokenizer = CLAP_tokenizer

        # instruments_configs maps instrument name (str) -> config dict; the keys used below are
        # 'sample_steps', 'sampler', 'noising_strength', 'latent_representation', 'attack' and 'before_release'
        self.instruments_configs = instruments_configs
        self.diffSynthSamplers = {}
        self._update_instruments()


    def _update_instruments(self):
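        """Build a sampler closure for every configured instrument.

        Each closure maps (velocity, duration_sec) to an audio signal: it inpaints a
        latent representation around the frozen attack and release regions, quantizes
        the result with the VAE quantizer, and decodes it back to a waveform.
        """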

        def diffSynthSamplerWrapper(instruments_config):

            def diffSynthSampler(velocity, duration_sec, sample_rate=16000):

                condition = self.text_encoder.get_text_features(**self.CLAP_tokenizer([""], padding=True, return_tensors="pt")).to(self.device)
                sample_steps = instruments_config['sample_steps']
                sampler = instruments_config['sampler']
                noising_strength = instruments_config['noising_strength']
                latent_representation = instruments_config['latent_representation']
                attack = instruments_config['attack']
                before_release = instruments_config['before_release']

                assert sample_rate == self.model_sample_rate, "sample_rate != model_sample_rate"

                width = int(self.time_resolution * ((duration_sec + 1) / 4) / self.VAE_scale)

                mySampler = DiffSynthSampler(self.timesteps, height=128, channels=4, noise_strategy="repeat", mute=True)
                mySampler.respace(list(np.linspace(0, self.timesteps - 1, sample_steps, dtype=np.int32)))

                # mask = 1 marks latent frames to freeze: the attack and the pre-release tail
                # are kept from the guide representation during inpainting
                latent_mask = torch.zeros((1, 1, self.height, width), dtype=torch.float32).to(self.device)
                latent_mask[:, :, :, :int(self.time_resolution * (attack / 4) / self.VAE_scale)] = 1.0
                latent_mask[:, :, :, -int(self.time_resolution * ((before_release+1) / 4) / self.VAE_scale):] = 1.0

                latent_representations, _ = \
                    mySampler.inpaint_sample(model=self.noise_prediction_model, shape=(1, self.channels, self.height, width),
                                            noising_strength=noising_strength, condition=condition,
                                            guide_img=latent_representation, mask=latent_mask, return_tensor=True,
                                            sampler=sampler,
                                            use_dynamic_mask=True, end_noise_level_ratio=0.0,
                                            mask_flexivity=1.0)


                latent_representations = latent_representations[-1]

                quantized_latent_representations, _, (_, _, _) = self.VAE_quantizer(latent_representations)
                # Todo: remove hard-coding

                flipped_log_spectrums, flipped_phases, rec_signals, _, _, _ = encodeBatch2GradioOutput_STFT(self.VAE_decoder,
                                                                  quantized_latent_representations,
                                                                  resolution=(
                                                                      512,
                                                                      width * self.VAE_scale),
                                                                  original_STFT_batch=None,
                                                            )


                return rec_signals[0]

            return diffSynthSampler

        for key in self.instruments_configs.keys():
            self.diffSynthSamplers[key] = diffSynthSamplerWrapper(self.instruments_configs[key])

    def get_music(self, mid, instrument_names, sample_rate=16000):
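        """Synthesize a mido MIDI file, one instrument per track, and mix the tracks."""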
        tracks = [Track(t, mid.ticks_per_beat) for t in mid.tracks]
        assert len(tracks) <= len(instrument_names), f"len(tracks) = {len(tracks)} > {len(instrument_names)} = len(instrument_names)"

        track_audios = [track.synthesize_track(self.diffSynthSamplers[instrument_names[i]], sample_rate=sample_rate) for i, track in enumerate(tracks)]

        # Pad every track to the length of the longest one so they can be mixed together
        max_length = max(len(audio) for audio in track_audios)
        full_audio = np.zeros(max_length, dtype=np.float32)  # initialize the mix buffer with zeros
        for audio in track_audios:
            # Shorter tracks are zero-padded at the end
            padded_audio = np.pad(audio, (0, max_length - len(audio)), 'constant')
            full_audio += padded_audio  # mix this track into the output

        return full_audio
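

# Example usage (a minimal sketch): it assumes the model components below are constructed
# elsewhere and that `instruments_configs` contains an entry for every instrument name
# passed in; all concrete names and file paths here are hypothetical.
#
#   mid = mido.MidiFile("song.mid")
#   diff_synth = DiffSynth(instruments_configs, noise_prediction_model, VAE_quantizer,
#                          VAE_decoder, text_encoder, CLAP_tokenizer, device="cuda")
#   full_audio = diff_synth.get_music(mid, instrument_names=["piano", "strings"])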