import math
import os
from typing import Tuple

# monkey patch to fix issues in msaf
import scipy
import numpy as np

scipy.inf = np.inf

import librosa
import torch
import torch.nn as nn
from muq import MuQ
from musicfm.model.musicfm_25hz import MusicFM25Hz
from omegaconf import OmegaConf
from transformers import PreTrainedModel

from configuration_songformer import SongFormerConfig
from model import Model
from model_config import ModelConfig

MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")

BEFORE_DOWNSAMPLING_FRAME_RATES = 25
AFTER_DOWNSAMPLING_FRAME_RATES = 8.333
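# NOTE: 8.333 ≈ 25 / 3, i.e. the downsampling stage appears to reduce the
# 25 Hz SSL feature rate by 3x (inferred from the two constants above).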

DATASET_LABEL = "SongForm-HX-8Class"
DATASET_IDS = [5]

TIME_DUR = 420
INPUT_SAMPLING_RATE = 24000

from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
from postprocessing.functional import postprocess_functional_structure


def rule_post_processing(msa_list):
    """Heuristic cleanup of a music structure analysis (MSA) list.

    `msa_list` is a list of (start_time, label) tuples whose final entry is
    the sentinel (end_time, "end"). Four rules are applied, each repeated
    until it no longer fires.
    """
    if len(msa_list) <= 2:
        return msa_list

    result = msa_list.copy()

    # Rule 1: a first segment shorter than 1 s takes the following segment's
    # label, and the duplicated boundary is dropped.
    while len(result) > 2:
        first_duration = result[1][0] - result[0][0]
        if first_duration < 1.0:
            result[0] = (result[0][0], result[1][1])
            result = [result[0]] + result[2:]
        else:
            break

    # Rule 2: a last segment shorter than 1 s is absorbed into the previous one.
    while len(result) > 2:
        last_label_duration = result[-1][0] - result[-2][0]
        if last_label_duration < 1.0:
            result = result[:-2] + [result[-1]]
        else:
            break

    # Rule 3: two identical labels at the start are merged when the second
    # segment begins within the first 10 s.
    while len(result) > 2:
        if result[0][1] == result[1][1] and result[1][0] <= 10.0:
            result = [(result[0][0], result[0][1])] + result[2:]
        else:
            break

    # Rule 4: a final segment of at most 10 s that repeats the previous
    # segment's label is merged into it.
    while len(result) > 2:
        last_duration = result[-1][0] - result[-2][0]
        if result[-2][1] == result[-3][1] and last_duration <= 10.0:
            result = result[:-2] + [result[-1]]
        else:
            break

    return result
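

# Example (illustrative, with made-up timestamps): a 0.4 s leading "intro"
# segment is absorbed into the following "verse" segment by rule 1:
#
#     rule_post_processing([(0.0, "intro"), (0.4, "verse"), (30.0, "end")])
#     # -> [(0.0, "verse"), (30.0, "end")]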


class SongFormerModel(PreTrainedModel):
    config_class = SongFormerConfig

    def __init__(self, config: SongFormerConfig):
        super().__init__(config)
        device = "cpu"
        root_dir = os.environ["SONGFORMER_LOCAL_DIR"]
        with open(os.path.join(root_dir, "muq_config2.json"), "r") as f:
            muq_config_file = OmegaConf.load(f)
        # self.muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter", device_map=None)
        self.muq = MuQ(muq_config_file)

        self.musicfm = MusicFM25Hz(
            is_flash=False,
            stat_path=os.path.join(root_dir, "msd_stats.json"),
            # model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
        )
        self.songformer = Model(ModelConfig())

        num_classes = config.num_classes
        dataset_id2label_mask = {}
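        # Build one boolean mask per dataset id: True marks label ids that are
        # NOT allowed for that dataset and get masked out at inference time.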
        for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
            dataset_id2label_mask[key] = np.ones(config.num_classes, dtype=bool)
            dataset_id2label_mask[key][allowed_ids] = False

        self.num_classes = num_classes
        self.dataset_id2label_mask = dataset_id2label_mask
        self.config = config

    def forward(self, input):
        with torch.no_grad():
            device = next(self.parameters()).device
            # Accept either an in-memory waveform (torch tensor / numpy array)
            # or a path to an audio file.
            if isinstance(input, (torch.Tensor, np.ndarray)):
                audio = torch.as_tensor(input).to(device)
            elif os.path.exists(input):
                wav, _ = librosa.load(input, sr=INPUT_SAMPLING_RATE)
                audio = torch.as_tensor(wav).to(device)
            else:
                raise ValueError(
                    "input must be a torch.Tensor / numpy array or a valid audio file path"
                )

            win_size = self.config.win_size
            hop_size = self.config.hop_size
            num_classes = self.config.num_classes
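            # Round the audio duration up to the next full TIME_DUR (420 s)
            # chunk so the frame buffers can hold every window's output.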
            total_len = (
                (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR
            ) * TIME_DUR + TIME_DUR
            total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)

            logits = {
                "function_logits": np.zeros([total_frames, num_classes]),
                "boundary_logits": np.zeros([total_frames]),
            }
            logits_num = {
                "function_logits": np.zeros([total_frames, num_classes]),
                "boundary_logits": np.zeros([total_frames]),
            }
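            # The two buffers above accumulate per-frame predictions over the
            # analysis windows and count how many windows covered each frame,
            # so the summed logits can be averaged once the loop finishes.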

            lens = 0
            i = 0
            while True:
                start_idx = i * INPUT_SAMPLING_RATE
                end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
                if start_idx >= audio.shape[-1]:
                    break
                # A tail shorter than ~1024 samples is too short to embed;
                # stop instead of skipping, since `i` only advances at the
                # bottom of the loop.
                if end_idx - start_idx <= 1024:
                    break
                audio_seg = audio[start_idx:end_idx]

                # MuQ embedding
                muq_output = self.muq(audio_seg.unsqueeze(0), output_hidden_states=True)
                muq_embd_420s = muq_output["hidden_states"][10]
                del muq_output
                torch.cuda.empty_cache()

                # MusicFM embedding
                _, musicfm_hidden_states = self.musicfm.get_predictions(
                    audio_seg.unsqueeze(0)
                )
                musicfm_embd_420s = musicfm_hidden_states[10]
                del musicfm_hidden_states
                torch.cuda.empty_cache()

                # Re-embed the same span in 30 s windows to obtain local
                # features that complement the full-window embeddings above.
                wrapped_muq_embd_30s = []
                wrapped_musicfm_embd_30s = []

                for idx_30s in range(i, i + hop_size, 30):
                    start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
                    end_idx_30s = min(
                        (idx_30s + 30) * INPUT_SAMPLING_RATE,
                        audio.shape[-1],
                        (i + hop_size) * INPUT_SAMPLING_RATE,
                    )
                    if start_idx_30s >= audio.shape[-1]:
                        break
                    if end_idx_30s - start_idx_30s <= 1024:
                        continue
                    wrapped_muq_embd_30s.append(
                        self.muq(
                            audio[start_idx_30s:end_idx_30s].unsqueeze(0),
                            output_hidden_states=True,
                        )["hidden_states"][10]
                    )
                    torch.cuda.empty_cache()
                    wrapped_musicfm_embd_30s.append(
                        self.musicfm.get_predictions(
                            audio[start_idx_30s:end_idx_30s].unsqueeze(0)
                        )[1][10]
                    )
                    torch.cuda.empty_cache()

                wrapped_muq_embd_30s = torch.cat(wrapped_muq_embd_30s, dim=1)
                wrapped_musicfm_embd_30s = torch.cat(wrapped_musicfm_embd_30s, dim=1)
                all_embds = [
                    wrapped_musicfm_embd_30s,
                    wrapped_muq_embd_30s,
                    musicfm_embd_420s,
                    muq_embd_420s,
                ]

                # The four embeddings may differ in length by a frame or two;
                # trim all of them to the shortest before concatenating along
                # the feature dimension.
                embd_lens = [x.shape[1] for x in all_embds]
                max_embd_len = max(embd_lens)
                min_embd_len = min(embd_lens)
                if max_embd_len - min_embd_len > 4:
                    raise ValueError(
                        f"Embedding lengths differ too much: {max_embd_len} vs {min_embd_len}"
                    )
                all_embds = [x[:, :min_embd_len, :] for x in all_embds]

                embd = torch.cat(all_embds, dim=-1)

                dataset_label = DATASET_LABEL
                dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
                msa_info, chunk_logits = self.songformer.infer(
                    input_embeddings=embd,
                    dataset_ids=dataset_ids,
                    label_id_masks=torch.Tensor(
                        self.dataset_id2label_mask[
                            DATASET_LABEL_TO_DATASET_ID[dataset_label]
                        ]
                    )
                    .to(device, dtype=bool)
                    .unsqueeze(0)
                    .unsqueeze(0),
                    with_logits=True,
                )

                start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
                end_frame = start_frame + min(
                    math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
                    chunk_logits["boundary_logits"][0].shape[0],
                )

                logits["function_logits"][start_frame:end_frame, :] += (
                    chunk_logits["function_logits"][0].detach().cpu().numpy()
                )
                logits["boundary_logits"][start_frame:end_frame] = (
                    chunk_logits["boundary_logits"][0].detach().cpu().numpy()
                )
                logits_num["function_logits"][start_frame:end_frame, :] += 1
                logits_num["boundary_logits"][start_frame:end_frame] += 1
                lens += end_frame - start_frame

                i += hop_size
            logits["function_logits"] /= logits_num["function_logits"]
            logits["boundary_logits"] /= logits_num["boundary_logits"]

            logits["function_logits"] = torch.from_numpy(
                logits["function_logits"][:lens]
            ).unsqueeze(0)
            logits["boundary_logits"] = torch.from_numpy(
                logits["boundary_logits"][:lens]
            ).unsqueeze(0)

            msa_infer_output = postprocess_functional_structure(logits, self.config)

            assert msa_infer_output[-1][-1] == "end"
            if not self.config.no_rule_post_processing:
                msa_infer_output = rule_post_processing(msa_infer_output)
            msa_json = []
            for idx in range(len(msa_infer_output) - 1):
                msa_json.append(
                    {
                        "label": msa_infer_output[idx][1],
                        "start": msa_infer_output[idx][0],
                        "end": msa_infer_output[idx + 1][0],
                    }
                )
            return msa_json

    @staticmethod
    def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
        """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""

        # ---- begin: leave MuQ submodule keys untouched ----
        if key.startswith("muq."):
            return key, False
        # ---- end ----

        # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
        # This rename is logged.
        if key.endswith("LayerNorm.beta"):
            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
        if key.endswith("LayerNorm.gamma"):
            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

        # Rename weight norm parametrizations to match changes across torch versions.
        # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
        # This rename is not logged.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            if key.endswith("weight_g"):
                return key.replace(
                    "weight_g", "parametrizations.weight.original0"
                ), True
            if key.endswith("weight_v"):
                return key.replace(
                    "weight_v", "parametrizations.weight.original1"
                ), True
        else:
            if key.endswith("parametrizations.weight.original0"):
                return key.replace(
                    "parametrizations.weight.original0", "weight_g"
                ), True
            if key.endswith("parametrizations.weight.original1"):
                return key.replace(
                    "parametrizations.weight.original1", "weight_v"
                ), True

        return key, False
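

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the published API). It
# assumes SONGFORMER_LOCAL_DIR points at a checkout containing the config
# files loaded in __init__, and "path/to/SongFormer" / "path/to/song.wav"
# are placeholders for a weights directory (or hub id) and an audio file.
if __name__ == "__main__":
    os.environ.setdefault("SONGFORMER_LOCAL_DIR", "path/to/SongFormer")

    model = SongFormerModel.from_pretrained("path/to/SongFormer")
    model.eval()

    # forward() accepts a 24 kHz mono waveform (tensor / numpy array) or an
    # audio file path and returns a list of {"label", "start", "end"} dicts.
    for seg in model("path/to/song.wav"):
        print(f"{seg['start']:8.2f} -> {seg['end']:8.2f}  {seg['label']}")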