add one-click func
Browse files- config.json +10 -0
- configuration_songformer.py +24 -0
- dataset/custom_types.py +14 -0
- dataset/label2id.py +163 -0
- model.py +527 -0
- model.safetensors +3 -0
- model_config.py +65 -0
- modeling_songformer.py +328 -0
- msd_stats.json +65 -0
- muq_config2.json +143 -0
- musicfm/.gitignore +10 -0
- musicfm/LICENSE +224 -0
- musicfm/README.md +173 -0
- musicfm/data/.gitkeep +0 -0
- musicfm/figs/Fig1.png +3 -0
- musicfm/figs/Table1.png +3 -0
- musicfm/model/__init__.py +2 -0
- musicfm/model/musicfm_25hz.py +252 -0
- musicfm/modules/__init__.py +2 -0
- musicfm/modules/conv.py +82 -0
- musicfm/modules/features.py +45 -0
- musicfm/modules/flash_conformer.py +2114 -0
- musicfm/modules/random_quantizer.py +83 -0
- postprocessing/functional.py +71 -0
- postprocessing/helpers.py +101 -0
config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"SongFormerModel"
|
| 4 |
+
],
|
| 5 |
+
"model_type": "songformer",
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_songformer.SongFormerConfig",
|
| 8 |
+
"AutoModel": "modeling_songformer.SongFormerModel"
|
| 9 |
+
}
|
| 10 |
+
}
|
configuration_songformer.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class SongFormerConfig(PretrainedConfig):
|
| 4 |
+
"""Configuration class to store the configuration of a custom model."""
|
| 5 |
+
|
| 6 |
+
model_type = "custom_model"
|
| 7 |
+
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
win_size=420,
|
| 11 |
+
hop_size=420,
|
| 12 |
+
num_classes=128,
|
| 13 |
+
no_rule_post_processing=False,
|
| 14 |
+
local_maxima_filter_size=3,
|
| 15 |
+
frame_rates=8.333,
|
| 16 |
+
**kwargs
|
| 17 |
+
):
|
| 18 |
+
super().__init__(**kwargs)
|
| 19 |
+
self.win_size = win_size
|
| 20 |
+
self.hop_size = hop_size
|
| 21 |
+
self.num_classes = num_classes
|
| 22 |
+
self.no_rule_post_processing = no_rule_post_processing
|
| 23 |
+
self.local_maxima_filter_size = local_maxima_filter_size
|
| 24 |
+
self.frame_rates = frame_rates
|
dataset/custom_types.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MsaInfo
|
| 3 |
+
A list of (timestamp, label) tuples used to represent music structure
|
| 4 |
+
analysis (MSA). The first element of the tuple is a float timestamp
|
| 5 |
+
(in seconds) and the second is a string label
|
| 6 |
+
|
| 7 |
+
Example
|
| 8 |
+
-------
|
| 9 |
+
>>> msa: MsaInfo = [(0.0, "intro"), (12.5, "verse"), (34.0, "chorus")]
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from typing import List, Tuple
|
| 13 |
+
|
| 14 |
+
MsaInfo = List[Tuple[float, str]]
|
dataset/label2id.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LABEL_TO_ID = {
|
| 2 |
+
"intro": 0,
|
| 3 |
+
"verse": 1,
|
| 4 |
+
"chorus": 2,
|
| 5 |
+
"bridge": 3,
|
| 6 |
+
"inst": 4,
|
| 7 |
+
"outro": 5,
|
| 8 |
+
"silence": 6,
|
| 9 |
+
"intchorus": 7,
|
| 10 |
+
"prechorus": 8,
|
| 11 |
+
"gtrbreak": 9,
|
| 12 |
+
"solo": 10,
|
| 13 |
+
"quietchorus": 11,
|
| 14 |
+
"bre": 12,
|
| 15 |
+
"break": 13,
|
| 16 |
+
"introverse": 14,
|
| 17 |
+
"mainriff": 15,
|
| 18 |
+
"chorushalf": 16,
|
| 19 |
+
"instintro": 17,
|
| 20 |
+
"gtr": 18,
|
| 21 |
+
"vocaloutro": 19,
|
| 22 |
+
"verse_slow": 20,
|
| 23 |
+
"fadein": 21,
|
| 24 |
+
"saxobeat": 22,
|
| 25 |
+
"transition": 23,
|
| 26 |
+
"verse1a": 24,
|
| 27 |
+
"build": 25,
|
| 28 |
+
"pre-chorus": 26,
|
| 29 |
+
"outroa": 27,
|
| 30 |
+
"bigoutro": 28,
|
| 31 |
+
"fast": 29,
|
| 32 |
+
"instrumentalverse": 30,
|
| 33 |
+
"section": 31,
|
| 34 |
+
"choruspart": 32,
|
| 35 |
+
"instbridge": 33,
|
| 36 |
+
"guitar": 34,
|
| 37 |
+
"instrumental": 35,
|
| 38 |
+
"breakdown": 36,
|
| 39 |
+
"rhythmlessintro": 37,
|
| 40 |
+
"intropt": 38,
|
| 41 |
+
"interlude": 39,
|
| 42 |
+
"postchorus": 40,
|
| 43 |
+
"postverse": 41,
|
| 44 |
+
"opening": 42,
|
| 45 |
+
"altchorus": 43,
|
| 46 |
+
"stutter": 44,
|
| 47 |
+
"oddriff": 45,
|
| 48 |
+
"synth": 46,
|
| 49 |
+
"preverse": 47,
|
| 50 |
+
"quiet": 48,
|
| 51 |
+
"raps": 49,
|
| 52 |
+
"verseinst": 50,
|
| 53 |
+
"instchorus": 51,
|
| 54 |
+
"chorus_instrumental": 52,
|
| 55 |
+
"slowverse": 53,
|
| 56 |
+
"slow": 54,
|
| 57 |
+
"worstthingever": 55,
|
| 58 |
+
"transition2a": 56,
|
| 59 |
+
"miniverse": 57,
|
| 60 |
+
"refrain": 58,
|
| 61 |
+
"introchorus": 59,
|
| 62 |
+
"drumroll": 60,
|
| 63 |
+
"guitarsolo": 61,
|
| 64 |
+
"versepart": 62,
|
| 65 |
+
"chorusinst": 63,
|
| 66 |
+
"ending": 64,
|
| 67 |
+
"no-vocal-intro": 65,
|
| 68 |
+
"no-vocal-interlude": 66,
|
| 69 |
+
"no-vocal-outro": 67,
|
| 70 |
+
"NO_LABEL": 68, # Only referring to cases without labels, this portion of labels will be ignored during the loss calculation process.
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
ID_TO_LABEL = {v: k for k, v in LABEL_TO_ID.items()}
|
| 74 |
+
|
| 75 |
+
# Reserve 64 embedding positions for dataset identifiers in the model.
|
| 76 |
+
DATASET_LABEL_TO_DATASET_ID = {
|
| 77 |
+
"SongForm-HX-7Class": 0, # Categories after rule mapping for HarmonixSet
|
| 78 |
+
"SongForm-HX-Widen": 1, # Original HarmonixSet
|
| 79 |
+
"SongForm-Private-Raw": 2,
|
| 80 |
+
"SongForm-Private": 3,
|
| 81 |
+
"SongForm-HX-Gemini-Relabeled": 4, # Rule-mapped HarmonixSet corrected by Gemini
|
| 82 |
+
"SongForm-HX-8Class": 5, # Rule-mapped (pre-chorus retained)
|
| 83 |
+
"SongForm-Hook": 6,
|
| 84 |
+
"SongForm-Gem": 7,
|
| 85 |
+
"SongForm-Gem-Only-Label": 8, # Use only segments with labels in SongForm-Gem
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
DATASET_ID_TO_DATASET_LABEL = {v: k for k, v in DATASET_LABEL_TO_DATASET_ID.items()}
|
| 89 |
+
|
| 90 |
+
DATASET_ID_ALLOWED_LABEL_IDS = {
|
| 91 |
+
0: [0, 1, 2, 3, 4, 5, 6],
|
| 92 |
+
1: [
|
| 93 |
+
0,
|
| 94 |
+
1,
|
| 95 |
+
2,
|
| 96 |
+
3,
|
| 97 |
+
4,
|
| 98 |
+
5,
|
| 99 |
+
6,
|
| 100 |
+
7,
|
| 101 |
+
8,
|
| 102 |
+
9,
|
| 103 |
+
10,
|
| 104 |
+
11,
|
| 105 |
+
12,
|
| 106 |
+
13,
|
| 107 |
+
14,
|
| 108 |
+
15,
|
| 109 |
+
16,
|
| 110 |
+
17,
|
| 111 |
+
18,
|
| 112 |
+
19,
|
| 113 |
+
20,
|
| 114 |
+
21,
|
| 115 |
+
22,
|
| 116 |
+
23,
|
| 117 |
+
24,
|
| 118 |
+
25,
|
| 119 |
+
27,
|
| 120 |
+
28,
|
| 121 |
+
29,
|
| 122 |
+
30,
|
| 123 |
+
31,
|
| 124 |
+
32,
|
| 125 |
+
33,
|
| 126 |
+
34,
|
| 127 |
+
35,
|
| 128 |
+
36,
|
| 129 |
+
37,
|
| 130 |
+
38,
|
| 131 |
+
40,
|
| 132 |
+
41,
|
| 133 |
+
42,
|
| 134 |
+
43,
|
| 135 |
+
44,
|
| 136 |
+
45,
|
| 137 |
+
46,
|
| 138 |
+
47,
|
| 139 |
+
48,
|
| 140 |
+
49,
|
| 141 |
+
50,
|
| 142 |
+
51,
|
| 143 |
+
52,
|
| 144 |
+
53,
|
| 145 |
+
54,
|
| 146 |
+
55,
|
| 147 |
+
56,
|
| 148 |
+
57,
|
| 149 |
+
58,
|
| 150 |
+
59,
|
| 151 |
+
60,
|
| 152 |
+
61,
|
| 153 |
+
62,
|
| 154 |
+
63,
|
| 155 |
+
],
|
| 156 |
+
2: [0, 1, 2, 3, 26, 39, 64, 65, 66, 67],
|
| 157 |
+
3: [0, 1, 2, 3, 4, 5, 6, 26, 39, 64, 65, 66, 67],
|
| 158 |
+
4: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 159 |
+
5: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 160 |
+
6: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 161 |
+
7: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 162 |
+
8: [0, 1, 2, 3, 4, 5, 6, 26],
|
| 163 |
+
}
|
model.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pdb
|
| 2 |
+
import scipy
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
scipy.inf = np.inf
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from dataset.custom_types import MsaInfo
|
| 12 |
+
from msaf.eval import compute_results
|
| 13 |
+
from postprocessing.functional import postprocess_functional_structure
|
| 14 |
+
from x_transformers import Encoder
|
| 15 |
+
import bisect
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Head(nn.Module):
|
| 19 |
+
def __init__(self, input_dim, output_dim, hidden_dims=None, activation="silu"):
|
| 20 |
+
super().__init__()
|
| 21 |
+
hidden_dims = hidden_dims or []
|
| 22 |
+
act_layers = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU}
|
| 23 |
+
act_layer = act_layers.get(activation.lower())
|
| 24 |
+
if not act_layer:
|
| 25 |
+
raise ValueError(f"Unsupported activation: {activation}")
|
| 26 |
+
|
| 27 |
+
dims = [input_dim] + hidden_dims + [output_dim]
|
| 28 |
+
layers = []
|
| 29 |
+
for i in range(len(dims) - 1):
|
| 30 |
+
layers.append(nn.Linear(dims[i], dims[i + 1]))
|
| 31 |
+
if i < len(dims) - 2:
|
| 32 |
+
layers.append(act_layer())
|
| 33 |
+
self.net = nn.Sequential(*layers)
|
| 34 |
+
|
| 35 |
+
def reset_parameters(self, confidence):
|
| 36 |
+
bias_value = -torch.log(torch.tensor((1 - confidence) / confidence))
|
| 37 |
+
self.net[-1].bias.data.fill_(bias_value.item())
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
batch, T, C = x.shape
|
| 41 |
+
x = x.reshape(-1, C)
|
| 42 |
+
x = self.net(x)
|
| 43 |
+
return x.reshape(batch, T, -1)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class WrapedTransformerEncoder(nn.Module):
|
| 47 |
+
def __init__(
|
| 48 |
+
self, input_dim, transformer_input_dim, num_layers=1, nhead=8, dropout=0.1
|
| 49 |
+
):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.input_dim = input_dim
|
| 52 |
+
self.transformer_input_dim = transformer_input_dim
|
| 53 |
+
|
| 54 |
+
if input_dim != transformer_input_dim:
|
| 55 |
+
self.input_proj = nn.Sequential(
|
| 56 |
+
nn.Linear(input_dim, transformer_input_dim),
|
| 57 |
+
nn.LayerNorm(transformer_input_dim),
|
| 58 |
+
nn.GELU(),
|
| 59 |
+
nn.Dropout(dropout * 0.5),
|
| 60 |
+
nn.Linear(transformer_input_dim, transformer_input_dim),
|
| 61 |
+
)
|
| 62 |
+
else:
|
| 63 |
+
self.input_proj = nn.Identity()
|
| 64 |
+
|
| 65 |
+
self.transformer = Encoder(
|
| 66 |
+
dim=transformer_input_dim,
|
| 67 |
+
depth=num_layers,
|
| 68 |
+
heads=nhead,
|
| 69 |
+
layer_dropout=dropout,
|
| 70 |
+
attn_dropout=dropout,
|
| 71 |
+
ff_dropout=dropout,
|
| 72 |
+
attn_flash=True,
|
| 73 |
+
rotary_pos_emb=True,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def forward(self, x, src_key_padding_mask=None):
|
| 77 |
+
"""
|
| 78 |
+
The input src_key_padding_mask is a B x T boolean mask, where True indicates masked positions.
|
| 79 |
+
However, in x-transformers, False indicates masked positions.
|
| 80 |
+
Therefore, it needs to be converted so that False represents masked positions.
|
| 81 |
+
"""
|
| 82 |
+
x = self.input_proj(x)
|
| 83 |
+
mask = (
|
| 84 |
+
~torch.tensor(src_key_padding_mask, dtype=torch.bool, device=x.device)
|
| 85 |
+
if src_key_padding_mask is not None
|
| 86 |
+
else None
|
| 87 |
+
)
|
| 88 |
+
return self.transformer(x, mask=mask)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def prefix_dict(d, prefix: str):
|
| 92 |
+
if prefix:
|
| 93 |
+
return d
|
| 94 |
+
return {prefix + key: value for key, value in d.items()}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class TimeDownsample(nn.Module):
|
| 98 |
+
def __init__(
|
| 99 |
+
self, dim_in, dim_out=None, kernel_size=5, stride=5, padding=0, dropout=0.1
|
| 100 |
+
):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.dim_out = dim_out or dim_in
|
| 103 |
+
assert self.dim_out % 2 == 0
|
| 104 |
+
|
| 105 |
+
self.depthwise_conv = nn.Conv1d(
|
| 106 |
+
in_channels=dim_in,
|
| 107 |
+
out_channels=dim_in,
|
| 108 |
+
kernel_size=kernel_size,
|
| 109 |
+
stride=stride,
|
| 110 |
+
padding=padding,
|
| 111 |
+
groups=dim_in,
|
| 112 |
+
bias=False,
|
| 113 |
+
)
|
| 114 |
+
self.pointwise_conv = nn.Conv1d(
|
| 115 |
+
in_channels=dim_in,
|
| 116 |
+
out_channels=self.dim_out,
|
| 117 |
+
kernel_size=1,
|
| 118 |
+
bias=False,
|
| 119 |
+
)
|
| 120 |
+
self.pool = nn.AvgPool1d(kernel_size, stride, padding=padding)
|
| 121 |
+
self.norm1 = nn.LayerNorm(self.dim_out)
|
| 122 |
+
self.act1 = nn.GELU()
|
| 123 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 124 |
+
|
| 125 |
+
if dim_in != self.dim_out:
|
| 126 |
+
self.residual_conv = nn.Conv1d(
|
| 127 |
+
dim_in, self.dim_out, kernel_size=1, bias=False
|
| 128 |
+
)
|
| 129 |
+
else:
|
| 130 |
+
self.residual_conv = None
|
| 131 |
+
|
| 132 |
+
def forward(self, x):
|
| 133 |
+
residual = x # [B, T, D_in]
|
| 134 |
+
# Convolutional module
|
| 135 |
+
x_c = x.transpose(1, 2) # [B, D_in, T]
|
| 136 |
+
x_c = self.depthwise_conv(x_c) # [B, D_in, T_down]
|
| 137 |
+
x_c = self.pointwise_conv(x_c) # [B, D_out, T_down]
|
| 138 |
+
|
| 139 |
+
# Residual module
|
| 140 |
+
res = self.pool(residual.transpose(1, 2)) # [B, D_in, T]
|
| 141 |
+
if self.residual_conv:
|
| 142 |
+
res = self.residual_conv(res) # [B, D_out, T_down]
|
| 143 |
+
x_c = x_c + res # [B, D_out, T_down]
|
| 144 |
+
x_c = x_c.transpose(1, 2) # [B, T_down, D_out]
|
| 145 |
+
x_c = self.norm1(x_c)
|
| 146 |
+
x_c = self.act1(x_c)
|
| 147 |
+
x_c = self.dropout1(x_c)
|
| 148 |
+
return x_c
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class AddFuse(nn.Module):
|
| 152 |
+
def __init__(self):
|
| 153 |
+
super(AddFuse, self).__init__()
|
| 154 |
+
|
| 155 |
+
def forward(self, x, cond):
|
| 156 |
+
return x + cond
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class TVLoss1D(nn.Module):
|
| 160 |
+
def __init__(
|
| 161 |
+
self, beta=1.0, lambda_tv=0.4, boundary_threshold=0.01, reduction_weight=0.1
|
| 162 |
+
):
|
| 163 |
+
"""
|
| 164 |
+
Args:
|
| 165 |
+
beta: Exponential parameter for TV loss (recommended 0.5~1.0)
|
| 166 |
+
lambda_tv: Overall weight for TV loss
|
| 167 |
+
boundary_threshold: Label threshold to determine if a region is a "boundary area" (e.g., 0.01)
|
| 168 |
+
reduction_weight: Scaling factor for TV penalty within boundary regions (e.g., 0.1, meaning only 10% penalty)
|
| 169 |
+
"""
|
| 170 |
+
super().__init__()
|
| 171 |
+
self.beta = beta
|
| 172 |
+
self.lambda_tv = lambda_tv
|
| 173 |
+
self.boundary_threshold = boundary_threshold
|
| 174 |
+
self.reduction_weight = reduction_weight
|
| 175 |
+
|
| 176 |
+
def forward(self, pred, target=None):
|
| 177 |
+
"""
|
| 178 |
+
Args:
|
| 179 |
+
pred: (B, T) or (B, T, 1), float boundary scores output by the model
|
| 180 |
+
target: (B, T) or (B, T, 1), ground truth labels (optional, used for spatial weighting if provided)
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
scalar: weighted TV loss
|
| 184 |
+
"""
|
| 185 |
+
if pred.dim() == 3:
|
| 186 |
+
pred = pred.squeeze(-1)
|
| 187 |
+
if target is not None and target.dim() == 3:
|
| 188 |
+
target = target.squeeze(-1)
|
| 189 |
+
|
| 190 |
+
diff = pred[:, 1:] - pred[:, :-1]
|
| 191 |
+
tv_base = torch.pow(torch.abs(diff) + 1e-8, self.beta)
|
| 192 |
+
|
| 193 |
+
if target is None:
|
| 194 |
+
return self.lambda_tv * tv_base.mean()
|
| 195 |
+
|
| 196 |
+
left_in_boundary = target[:, :-1] > self.boundary_threshold
|
| 197 |
+
right_in_boundary = target[:, 1:] > self.boundary_threshold
|
| 198 |
+
near_boundary = left_in_boundary | right_in_boundary
|
| 199 |
+
weight_mask = torch.where(
|
| 200 |
+
near_boundary,
|
| 201 |
+
self.reduction_weight * torch.ones_like(tv_base),
|
| 202 |
+
torch.ones_like(tv_base),
|
| 203 |
+
)
|
| 204 |
+
tv_weighted = (tv_base * weight_mask).mean()
|
| 205 |
+
return self.lambda_tv * tv_weighted
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class SoftmaxFocalLoss(nn.Module):
|
| 209 |
+
"""
|
| 210 |
+
Softmax Focal Loss for single-label multi-class classification.
|
| 211 |
+
Suitable for mutually exclusive classes.
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
|
| 215 |
+
super().__init__()
|
| 216 |
+
self.alpha = alpha
|
| 217 |
+
self.gamma = gamma
|
| 218 |
+
|
| 219 |
+
def forward(self, pred: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 220 |
+
"""
|
| 221 |
+
Args:
|
| 222 |
+
pred: [B, T, C], raw logits
|
| 223 |
+
targets: [B, T, C] (soft) or [B, T] (hard, dtype=long)
|
| 224 |
+
Returns:
|
| 225 |
+
loss: scalar or [B, T] depending on reduction
|
| 226 |
+
"""
|
| 227 |
+
log_probs = F.log_softmax(pred, dim=-1)
|
| 228 |
+
probs = torch.exp(log_probs)
|
| 229 |
+
|
| 230 |
+
if targets.dtype == torch.long:
|
| 231 |
+
targets_onehot = F.one_hot(targets, num_classes=pred.size(-1)).float()
|
| 232 |
+
else:
|
| 233 |
+
targets_onehot = targets
|
| 234 |
+
|
| 235 |
+
p_t = (probs * targets_onehot).sum(dim=-1)
|
| 236 |
+
p_t = p_t.clamp(min=1e-8, max=1.0 - 1e-8)
|
| 237 |
+
|
| 238 |
+
if self.alpha > 0:
|
| 239 |
+
alpha_t = self.alpha * targets_onehot + (1 - self.alpha) * (
|
| 240 |
+
1 - targets_onehot
|
| 241 |
+
)
|
| 242 |
+
alpha_weight = (alpha_t * targets_onehot).sum(dim=-1)
|
| 243 |
+
else:
|
| 244 |
+
alpha_weight = 1.0
|
| 245 |
+
|
| 246 |
+
focal_weight = (1 - p_t) ** self.gamma
|
| 247 |
+
ce_loss = -log_probs * targets_onehot
|
| 248 |
+
ce_loss = ce_loss.sum(dim=-1)
|
| 249 |
+
|
| 250 |
+
loss = alpha_weight * focal_weight * ce_loss
|
| 251 |
+
return loss
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class Model(nn.Module):
|
| 255 |
+
def __init__(self, config):
|
| 256 |
+
super().__init__()
|
| 257 |
+
self.config = config
|
| 258 |
+
|
| 259 |
+
self.input_norm = nn.LayerNorm(config.input_dim)
|
| 260 |
+
self.mixed_win_downsample = nn.Linear(config.input_dim_raw, config.input_dim)
|
| 261 |
+
self.dataset_class_prefix = nn.Embedding(
|
| 262 |
+
num_embeddings=config.num_dataset_classes,
|
| 263 |
+
embedding_dim=config.transformer_encoder_input_dim,
|
| 264 |
+
)
|
| 265 |
+
self.down_sample_conv = TimeDownsample(
|
| 266 |
+
dim_in=config.input_dim,
|
| 267 |
+
dim_out=config.transformer_encoder_input_dim,
|
| 268 |
+
kernel_size=config.down_sample_conv_kernel_size,
|
| 269 |
+
stride=config.down_sample_conv_stride,
|
| 270 |
+
dropout=config.down_sample_conv_dropout,
|
| 271 |
+
padding=config.down_sample_conv_padding,
|
| 272 |
+
)
|
| 273 |
+
self.AddFuse = AddFuse()
|
| 274 |
+
self.transformer = WrapedTransformerEncoder(
|
| 275 |
+
input_dim=config.transformer_encoder_input_dim,
|
| 276 |
+
transformer_input_dim=config.transformer_input_dim,
|
| 277 |
+
num_layers=config.num_transformer_layers,
|
| 278 |
+
nhead=config.transformer_nhead,
|
| 279 |
+
dropout=config.transformer_dropout,
|
| 280 |
+
)
|
| 281 |
+
self.boundary_TVLoss1D = TVLoss1D(
|
| 282 |
+
beta=config.boundary_tv_loss_beta,
|
| 283 |
+
lambda_tv=config.boundary_tv_loss_lambda,
|
| 284 |
+
boundary_threshold=config.boundary_tv_loss_boundary_threshold,
|
| 285 |
+
reduction_weight=config.boundary_tv_loss_reduction_weight,
|
| 286 |
+
)
|
| 287 |
+
self.label_focal_loss = SoftmaxFocalLoss(
|
| 288 |
+
alpha=config.label_focal_loss_alpha, gamma=config.label_focal_loss_gamma
|
| 289 |
+
)
|
| 290 |
+
self.boundary_head = Head(config.transformer_input_dim, 1)
|
| 291 |
+
self.function_head = Head(config.transformer_input_dim, config.num_classes)
|
| 292 |
+
|
| 293 |
+
def cal_metrics(self, gt_info: MsaInfo, msa_info: MsaInfo):
|
| 294 |
+
assert gt_info[-1][1] == "end" and msa_info[-1][1] == "end", (
|
| 295 |
+
"gt_info and msa_info should end with 'end'"
|
| 296 |
+
)
|
| 297 |
+
gt_info_labels = [label for time_, label in gt_info][:-1]
|
| 298 |
+
gt_info_inters = [time_ for time_, label in gt_info]
|
| 299 |
+
gt_info_inters = np.column_stack(
|
| 300 |
+
[np.array(gt_info_inters[:-1]), np.array(gt_info_inters[1:])]
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
msa_info_labels = [label for time_, label in msa_info][:-1]
|
| 304 |
+
msa_info_inters = [time_ for time_, label in msa_info]
|
| 305 |
+
msa_info_inters = np.column_stack(
|
| 306 |
+
[np.array(msa_info_inters[:-1]), np.array(msa_info_inters[1:])]
|
| 307 |
+
)
|
| 308 |
+
result = compute_results(
|
| 309 |
+
ann_inter=gt_info_inters,
|
| 310 |
+
est_inter=msa_info_inters,
|
| 311 |
+
ann_labels=gt_info_labels,
|
| 312 |
+
est_labels=msa_info_labels,
|
| 313 |
+
bins=11,
|
| 314 |
+
est_file="test.txt",
|
| 315 |
+
weight=0.58,
|
| 316 |
+
)
|
| 317 |
+
return result
|
| 318 |
+
|
| 319 |
+
def cal_acc(
|
| 320 |
+
self, ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3
|
| 321 |
+
):
|
| 322 |
+
ann_info_time = [
|
| 323 |
+
int(round(time_, post_digit) * (10**post_digit))
|
| 324 |
+
for time_, label in ann_info
|
| 325 |
+
]
|
| 326 |
+
est_info_time = [
|
| 327 |
+
int(round(time_, post_digit) * (10**post_digit))
|
| 328 |
+
for time_, label in est_info
|
| 329 |
+
]
|
| 330 |
+
|
| 331 |
+
common_start_time = max(ann_info_time[0], est_info_time[0])
|
| 332 |
+
common_end_time = min(ann_info_time[-1], est_info_time[-1])
|
| 333 |
+
|
| 334 |
+
time_points = {common_start_time, common_end_time}
|
| 335 |
+
time_points.update(
|
| 336 |
+
{
|
| 337 |
+
time_
|
| 338 |
+
for time_ in ann_info_time
|
| 339 |
+
if common_start_time <= time_ <= common_end_time
|
| 340 |
+
}
|
| 341 |
+
)
|
| 342 |
+
time_points.update(
|
| 343 |
+
{
|
| 344 |
+
time_
|
| 345 |
+
for time_ in est_info_time
|
| 346 |
+
if common_start_time <= time_ <= common_end_time
|
| 347 |
+
}
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
time_points = sorted(time_points)
|
| 351 |
+
total_duration, total_score = 0, 0
|
| 352 |
+
|
| 353 |
+
for idx in range(len(time_points) - 1):
|
| 354 |
+
duration = time_points[idx + 1] - time_points[idx]
|
| 355 |
+
ann_label = ann_info[
|
| 356 |
+
bisect.bisect_right(ann_info_time, time_points[idx]) - 1
|
| 357 |
+
][1]
|
| 358 |
+
est_label = est_info[
|
| 359 |
+
bisect.bisect_right(est_info_time, time_points[idx]) - 1
|
| 360 |
+
][1]
|
| 361 |
+
total_duration += duration
|
| 362 |
+
if ann_label == est_label:
|
| 363 |
+
total_score += duration
|
| 364 |
+
return total_score / total_duration
|
| 365 |
+
|
| 366 |
+
def infer_with_metrics(self, batch, prefix: str = None):
|
| 367 |
+
with torch.no_grad():
|
| 368 |
+
logits = self.forward_func(batch)
|
| 369 |
+
|
| 370 |
+
losses = self.compute_losses(logits, batch, prefix=None)
|
| 371 |
+
|
| 372 |
+
expanded_mask = batch["label_id_masks"].expand(
|
| 373 |
+
-1, logits["function_logits"].size(1), -1
|
| 374 |
+
)
|
| 375 |
+
logits["function_logits"] = logits["function_logits"].masked_fill(
|
| 376 |
+
expanded_mask, -float("inf")
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
msa_info = postprocess_functional_structure(
|
| 380 |
+
logits=logits, config=self.config
|
| 381 |
+
)
|
| 382 |
+
gt_info = batch["msa_infos"][0]
|
| 383 |
+
results = self.cal_metrics(gt_info=gt_info, msa_info=msa_info)
|
| 384 |
+
|
| 385 |
+
ret_results = {
|
| 386 |
+
"loss": losses["loss"].item(),
|
| 387 |
+
"HitRate_3P": results["HitRate_3P"],
|
| 388 |
+
"HitRate_3R": results["HitRate_3R"],
|
| 389 |
+
"HitRate_3F": results["HitRate_3F"],
|
| 390 |
+
"HitRate_0.5P": results["HitRate_0.5P"],
|
| 391 |
+
"HitRate_0.5R": results["HitRate_0.5R"],
|
| 392 |
+
"HitRate_0.5F": results["HitRate_0.5F"],
|
| 393 |
+
"PWF": results["PWF"],
|
| 394 |
+
"PWP": results["PWP"],
|
| 395 |
+
"PWR": results["PWR"],
|
| 396 |
+
"Sf": results["Sf"],
|
| 397 |
+
"So": results["So"],
|
| 398 |
+
"Su": results["Su"],
|
| 399 |
+
"acc": self.cal_acc(ann_info=gt_info, est_info=msa_info),
|
| 400 |
+
}
|
| 401 |
+
if prefix:
|
| 402 |
+
ret_results = prefix_dict(ret_results, prefix)
|
| 403 |
+
|
| 404 |
+
return ret_results
|
| 405 |
+
|
| 406 |
+
def infer(
|
| 407 |
+
self,
|
| 408 |
+
input_embeddings,
|
| 409 |
+
dataset_ids,
|
| 410 |
+
label_id_masks,
|
| 411 |
+
prefix: str = None,
|
| 412 |
+
with_logits=False,
|
| 413 |
+
):
|
| 414 |
+
with torch.no_grad():
|
| 415 |
+
input_embeddings = self.mixed_win_downsample(input_embeddings)
|
| 416 |
+
input_embeddings = self.input_norm(input_embeddings)
|
| 417 |
+
logits = self.down_sample_conv(input_embeddings)
|
| 418 |
+
|
| 419 |
+
dataset_prefix = self.dataset_class_prefix(dataset_ids)
|
| 420 |
+
dataset_prefix_expand = dataset_prefix.unsqueeze(1).expand(
|
| 421 |
+
logits.size(0), 1, -1
|
| 422 |
+
)
|
| 423 |
+
logits = self.AddFuse(x=logits, cond=dataset_prefix_expand)
|
| 424 |
+
logits = self.transformer(x=logits, src_key_padding_mask=None)
|
| 425 |
+
|
| 426 |
+
function_logits = self.function_head(logits)
|
| 427 |
+
boundary_logits = self.boundary_head(logits).squeeze(-1)
|
| 428 |
+
|
| 429 |
+
logits = {
|
| 430 |
+
"function_logits": function_logits,
|
| 431 |
+
"boundary_logits": boundary_logits,
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
expanded_mask = label_id_masks.expand(
|
| 435 |
+
-1, logits["function_logits"].size(1), -1
|
| 436 |
+
)
|
| 437 |
+
logits["function_logits"] = logits["function_logits"].masked_fill(
|
| 438 |
+
expanded_mask, -float("inf")
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
msa_info = postprocess_functional_structure(
|
| 442 |
+
logits=logits, config=self.config
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
return (msa_info, logits) if with_logits else msa_info
|
| 446 |
+
|
| 447 |
+
def compute_losses(self, outputs, batch, prefix: str = None):
|
| 448 |
+
loss = 0.0
|
| 449 |
+
losses = {}
|
| 450 |
+
|
| 451 |
+
loss_section = F.binary_cross_entropy_with_logits(
|
| 452 |
+
outputs["boundary_logits"],
|
| 453 |
+
batch["widen_true_boundaries"],
|
| 454 |
+
reduction="none",
|
| 455 |
+
)
|
| 456 |
+
loss_section += self.config.boundary_tvloss_weight * self.boundary_TVLoss1D(
|
| 457 |
+
pred=outputs["boundary_logits"],
|
| 458 |
+
target=batch["widen_true_boundaries"],
|
| 459 |
+
)
|
| 460 |
+
loss_function = F.cross_entropy(
|
| 461 |
+
outputs["function_logits"].transpose(1, 2),
|
| 462 |
+
batch["true_functions"].transpose(1, 2),
|
| 463 |
+
reduction="none",
|
| 464 |
+
)
|
| 465 |
+
# input is [B, T, C]
|
| 466 |
+
ttt = self.config.label_focal_loss_weight * self.label_focal_loss(
|
| 467 |
+
pred=outputs["function_logits"], targets=batch["true_functions"]
|
| 468 |
+
)
|
| 469 |
+
loss_function += ttt
|
| 470 |
+
|
| 471 |
+
float_masks = (~batch["masks"]).float()
|
| 472 |
+
boundary_mask = batch.get("boundary_mask", None)
|
| 473 |
+
function_mask = batch.get("function_mask", None)
|
| 474 |
+
if boundary_mask is not None:
|
| 475 |
+
boundary_mask = (~boundary_mask).float()
|
| 476 |
+
else:
|
| 477 |
+
boundary_mask = 1
|
| 478 |
+
|
| 479 |
+
if function_mask is not None:
|
| 480 |
+
function_mask = (~function_mask).float()
|
| 481 |
+
else:
|
| 482 |
+
function_mask = 1
|
| 483 |
+
|
| 484 |
+
loss_section = torch.mean(boundary_mask * float_masks * loss_section)
|
| 485 |
+
loss_function = torch.mean(function_mask * float_masks * loss_function)
|
| 486 |
+
|
| 487 |
+
loss_section *= self.config.loss_weight_section
|
| 488 |
+
loss_function *= self.config.loss_weight_function
|
| 489 |
+
|
| 490 |
+
if self.config.learn_label:
|
| 491 |
+
loss += loss_function
|
| 492 |
+
if self.config.learn_segment:
|
| 493 |
+
loss += loss_section
|
| 494 |
+
|
| 495 |
+
losses.update(
|
| 496 |
+
loss=loss,
|
| 497 |
+
loss_section=loss_section,
|
| 498 |
+
loss_function=loss_function,
|
| 499 |
+
)
|
| 500 |
+
if prefix:
|
| 501 |
+
losses = prefix_dict(losses, prefix)
|
| 502 |
+
return losses
|
| 503 |
+
|
| 504 |
+
def forward_func(self, batch):
|
| 505 |
+
input_embeddings = batch["input_embeddings"]
|
| 506 |
+
input_embeddings = self.mixed_win_downsample(input_embeddings)
|
| 507 |
+
input_embeddings = self.input_norm(input_embeddings)
|
| 508 |
+
logits = self.down_sample_conv(input_embeddings)
|
| 509 |
+
|
| 510 |
+
dataset_prefix = self.dataset_class_prefix(batch["dataset_ids"])
|
| 511 |
+
logits = self.AddFuse(x=logits, cond=dataset_prefix.unsqueeze(1))
|
| 512 |
+
src_key_padding_mask = batch["masks"]
|
| 513 |
+
logits = self.transformer(x=logits, src_key_padding_mask=src_key_padding_mask)
|
| 514 |
+
|
| 515 |
+
function_logits = self.function_head(logits)
|
| 516 |
+
boundary_logits = self.boundary_head(logits).squeeze(-1)
|
| 517 |
+
|
| 518 |
+
logits = {
|
| 519 |
+
"function_logits": function_logits,
|
| 520 |
+
"boundary_logits": boundary_logits,
|
| 521 |
+
}
|
| 522 |
+
return logits
|
| 523 |
+
|
| 524 |
+
def forward(self, batch):
|
| 525 |
+
logits = self.forward_func(batch)
|
| 526 |
+
losses = self.compute_losses(logits, batch, prefix=None)
|
| 527 |
+
return logits, losses["loss"], losses
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8dcabf4ea19973edd51b9e5794004775fa7e8de3ecfa07eb1dbce00f516ce7f7
|
| 3 |
+
size 2755035132
|
model_config.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
from transformers import PretrainedConfig
|
| 3 |
+
|
| 4 |
+
class ModelConfig(PretrainedConfig):
|
| 5 |
+
model_type = "SongFormer"
|
| 6 |
+
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
input_dim=2048,
|
| 10 |
+
input_dim_raw=4096,
|
| 11 |
+
transformer_encoder_input_dim=1024,
|
| 12 |
+
transformer_input_dim=512,
|
| 13 |
+
num_transformer_layers=4,
|
| 14 |
+
transformer_nhead=8,
|
| 15 |
+
transformer_dropout=0.1,
|
| 16 |
+
num_classes=128,
|
| 17 |
+
num_dataset_classes=64,
|
| 18 |
+
down_sample_conv_kernel_size=3,
|
| 19 |
+
down_sample_conv_stride=3,
|
| 20 |
+
down_sample_conv_dropout=0.1,
|
| 21 |
+
down_sample_conv_padding=0,
|
| 22 |
+
boundary_tv_loss_beta=0.6,
|
| 23 |
+
boundary_tv_loss_lambda=0.4,
|
| 24 |
+
boundary_tv_loss_boundary_threshold=0.01,
|
| 25 |
+
boundary_tv_loss_reduction_weight=0.1,
|
| 26 |
+
boundary_tvloss_weight=0.05,
|
| 27 |
+
label_focal_loss_alpha=0.25,
|
| 28 |
+
label_focal_loss_gamma=2.0,
|
| 29 |
+
label_focal_loss_weight=0.2,
|
| 30 |
+
loss_weight_section=0.2,
|
| 31 |
+
loss_weight_function=0.8,
|
| 32 |
+
learn_label=True,
|
| 33 |
+
learn_segment=True,
|
| 34 |
+
local_maxima_filter_size=3,
|
| 35 |
+
frame_rates=8.333,
|
| 36 |
+
**kwargs
|
| 37 |
+
):
|
| 38 |
+
super().__init__(**kwargs)
|
| 39 |
+
self.input_dim = input_dim
|
| 40 |
+
self.input_dim_raw = input_dim_raw
|
| 41 |
+
self.transformer_encoder_input_dim = transformer_encoder_input_dim
|
| 42 |
+
self.transformer_input_dim = transformer_input_dim
|
| 43 |
+
self.num_transformer_layers = num_transformer_layers
|
| 44 |
+
self.transformer_nhead = transformer_nhead
|
| 45 |
+
self.transformer_dropout = transformer_dropout
|
| 46 |
+
self.num_classes = num_classes
|
| 47 |
+
self.num_dataset_classes = num_dataset_classes
|
| 48 |
+
self.down_sample_conv_kernel_size = down_sample_conv_kernel_size
|
| 49 |
+
self.down_sample_conv_stride = down_sample_conv_stride
|
| 50 |
+
self.down_sample_conv_dropout = down_sample_conv_dropout
|
| 51 |
+
self.down_sample_conv_padding = down_sample_conv_padding
|
| 52 |
+
self.boundary_tv_loss_beta = boundary_tv_loss_beta
|
| 53 |
+
self.boundary_tv_loss_lambda = boundary_tv_loss_lambda
|
| 54 |
+
self.boundary_tv_loss_boundary_threshold = boundary_tv_loss_boundary_threshold
|
| 55 |
+
self.boundary_tv_loss_reduction_weight = boundary_tv_loss_reduction_weight
|
| 56 |
+
self.boundary_tvloss_weight = boundary_tvloss_weight
|
| 57 |
+
self.label_focal_loss_alpha = label_focal_loss_alpha
|
| 58 |
+
self.label_focal_loss_gamma = label_focal_loss_gamma
|
| 59 |
+
self.label_focal_loss_weight = label_focal_loss_weight
|
| 60 |
+
self.loss_weight_section = loss_weight_section
|
| 61 |
+
self.loss_weight_function = loss_weight_function
|
| 62 |
+
self.learn_label = learn_label
|
| 63 |
+
self.learn_segment = learn_segment
|
| 64 |
+
self.local_maxima_filter_size = local_maxima_filter_size
|
| 65 |
+
self.frame_rates = frame_rates
|
modeling_songformer.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pdb
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import PreTrainedModel
|
| 6 |
+
import argparse
|
| 7 |
+
import importlib
|
| 8 |
+
import json
|
| 9 |
+
import math
|
| 10 |
+
import multiprocessing as mp
|
| 11 |
+
import os
|
| 12 |
+
import time
|
| 13 |
+
from argparse import Namespace
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
# monkey patch to fix issues in msaf
|
| 17 |
+
import scipy
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
scipy.inf = np.inf
|
| 21 |
+
|
| 22 |
+
import librosa
|
| 23 |
+
import torch
|
| 24 |
+
from ema_pytorch import EMA
|
| 25 |
+
from loguru import logger
|
| 26 |
+
from muq import MuQ
|
| 27 |
+
from musicfm.model.musicfm_25hz import MusicFM25Hz
|
| 28 |
+
from omegaconf import OmegaConf
|
| 29 |
+
from tqdm import tqdm
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
from transformers import PreTrainedModel
|
| 33 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 34 |
+
from configuration_songformer import SongFormerConfig
|
| 35 |
+
from model_config import ModelConfig
|
| 36 |
+
|
| 37 |
+
from model import Model
|
| 38 |
+
from omegaconf import OmegaConf
|
| 39 |
+
|
| 40 |
+
# MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")
|
| 41 |
+
MUSICFM_HOME_PATH = "/home/node59_tmpdata3/cbhao/SongFormer_kaiyuan_test/github_test/SongFormer/src/SongFormer/ckpts/MusicFM"
|
| 42 |
+
|
| 43 |
+
BEFORE_DOWNSAMPLING_FRAME_RATES = 25
|
| 44 |
+
AFTER_DOWNSAMPLING_FRAME_RATES = 8.333
|
| 45 |
+
|
| 46 |
+
DATASET_LABEL = "SongForm-HX-8Class"
|
| 47 |
+
DATASET_IDS = [5]
|
| 48 |
+
|
| 49 |
+
TIME_DUR = 420
|
| 50 |
+
INPUT_SAMPLING_RATE = 24000
|
| 51 |
+
|
| 52 |
+
from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
|
| 53 |
+
from postprocessing.functional import postprocess_functional_structure
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def rule_post_processing(msa_list):
|
| 57 |
+
if len(msa_list) <= 2:
|
| 58 |
+
return msa_list
|
| 59 |
+
|
| 60 |
+
result = msa_list.copy()
|
| 61 |
+
|
| 62 |
+
while len(result) > 2:
|
| 63 |
+
first_duration = result[1][0] - result[0][0]
|
| 64 |
+
if first_duration < 1.0 and len(result) > 2:
|
| 65 |
+
result[0] = (result[0][0], result[1][1])
|
| 66 |
+
result = [result[0]] + result[2:]
|
| 67 |
+
else:
|
| 68 |
+
break
|
| 69 |
+
|
| 70 |
+
while len(result) > 2:
|
| 71 |
+
last_label_duration = result[-1][0] - result[-2][0]
|
| 72 |
+
if last_label_duration < 1.0:
|
| 73 |
+
result = result[:-2] + [result[-1]]
|
| 74 |
+
else:
|
| 75 |
+
break
|
| 76 |
+
|
| 77 |
+
while len(result) > 2:
|
| 78 |
+
if result[0][1] == result[1][1] and result[1][0] <= 10.0:
|
| 79 |
+
result = [(result[0][0], result[0][1])] + result[2:]
|
| 80 |
+
else:
|
| 81 |
+
break
|
| 82 |
+
|
| 83 |
+
while len(result) > 2:
|
| 84 |
+
last_duration = result[-1][0] - result[-2][0]
|
| 85 |
+
if result[-2][1] == result[-3][1] and last_duration <= 10.0:
|
| 86 |
+
result = result[:-2] + [result[-1]]
|
| 87 |
+
else:
|
| 88 |
+
break
|
| 89 |
+
|
| 90 |
+
return result
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class SongFormerModel(PreTrainedModel):
|
| 94 |
+
config_class = SongFormerConfig
|
| 95 |
+
|
| 96 |
+
def __init__(self, config: SongFormerConfig):
|
| 97 |
+
super().__init__(config)
|
| 98 |
+
device = "cpu"
|
| 99 |
+
|
| 100 |
+
with open("muq_config2.json", "r") as f:
|
| 101 |
+
muq_config_file = OmegaConf.load(f)
|
| 102 |
+
# self.muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter", device_map=None)
|
| 103 |
+
self.muq = MuQ(muq_config_file)
|
| 104 |
+
|
| 105 |
+
self.musicfm = MusicFM25Hz(
|
| 106 |
+
is_flash=False,
|
| 107 |
+
stat_path="msd_stats.json",
|
| 108 |
+
# model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
|
| 109 |
+
)
|
| 110 |
+
self.songformer = Model(ModelConfig())
|
| 111 |
+
|
| 112 |
+
num_classes = config.num_classes
|
| 113 |
+
dataset_id2label_mask = {}
|
| 114 |
+
for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
|
| 115 |
+
dataset_id2label_mask[key] = np.ones(config.num_classes, dtype=bool)
|
| 116 |
+
dataset_id2label_mask[key][allowed_ids] = False
|
| 117 |
+
|
| 118 |
+
self.num_classes = num_classes
|
| 119 |
+
self.dataset_id2label_mask = dataset_id2label_mask
|
| 120 |
+
self.config = config
|
| 121 |
+
|
| 122 |
+
def forward(self, input):
|
| 123 |
+
with torch.no_grad():
|
| 124 |
+
INPUT_SAMPLING_RATE = 24000
|
| 125 |
+
|
| 126 |
+
device = next(self.parameters()).device
|
| 127 |
+
# 如果为tensor或者是numpy
|
| 128 |
+
if isinstance(input, (torch.Tensor, np.ndarray)):
|
| 129 |
+
audio = torch.tensor(input).to(device)
|
| 130 |
+
elif os.path.exists(input):
|
| 131 |
+
wav, sr = librosa.load(input, sr=INPUT_SAMPLING_RATE)
|
| 132 |
+
audio = torch.tensor(wav).to(device)
|
| 133 |
+
else:
|
| 134 |
+
raise ValueError("input should be a tensor/numpy or a valid file path")
|
| 135 |
+
|
| 136 |
+
win_size = self.config.win_size
|
| 137 |
+
hop_size = self.config.hop_size
|
| 138 |
+
num_classes = self.config.num_classes
|
| 139 |
+
total_len = (
|
| 140 |
+
(audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR
|
| 141 |
+
) * TIME_DUR + TIME_DUR
|
| 142 |
+
total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)
|
| 143 |
+
|
| 144 |
+
logits = {
|
| 145 |
+
"function_logits": np.zeros([total_frames, num_classes]),
|
| 146 |
+
"boundary_logits": np.zeros([total_frames]),
|
| 147 |
+
}
|
| 148 |
+
logits_num = {
|
| 149 |
+
"function_logits": np.zeros([total_frames, num_classes]),
|
| 150 |
+
"boundary_logits": np.zeros([total_frames]),
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
lens = 0
|
| 154 |
+
i = 0
|
| 155 |
+
while True:
|
| 156 |
+
start_idx = i * INPUT_SAMPLING_RATE
|
| 157 |
+
end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
|
| 158 |
+
if start_idx >= audio.shape[-1]:
|
| 159 |
+
break
|
| 160 |
+
if end_idx - start_idx <= 1024:
|
| 161 |
+
continue
|
| 162 |
+
audio_seg = audio[start_idx:end_idx]
|
| 163 |
+
|
| 164 |
+
# MuQ embedding
|
| 165 |
+
muq_output = self.muq(audio_seg.unsqueeze(0), output_hidden_states=True)
|
| 166 |
+
muq_embd_420s = muq_output["hidden_states"][10]
|
| 167 |
+
del muq_output
|
| 168 |
+
torch.cuda.empty_cache()
|
| 169 |
+
|
| 170 |
+
# MusicFM embedding
|
| 171 |
+
_, musicfm_hidden_states = self.musicfm.get_predictions(
|
| 172 |
+
audio_seg.unsqueeze(0)
|
| 173 |
+
)
|
| 174 |
+
musicfm_embd_420s = musicfm_hidden_states[10]
|
| 175 |
+
del musicfm_hidden_states
|
| 176 |
+
torch.cuda.empty_cache()
|
| 177 |
+
|
| 178 |
+
wraped_muq_embd_30s = []
|
| 179 |
+
wraped_musicfm_embd_30s = []
|
| 180 |
+
|
| 181 |
+
for idx_30s in range(i, i + hop_size, 30):
|
| 182 |
+
start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
|
| 183 |
+
end_idx_30s = min(
|
| 184 |
+
(idx_30s + 30) * INPUT_SAMPLING_RATE,
|
| 185 |
+
audio.shape[-1],
|
| 186 |
+
(i + hop_size) * INPUT_SAMPLING_RATE,
|
| 187 |
+
)
|
| 188 |
+
if start_idx_30s >= audio.shape[-1]:
|
| 189 |
+
break
|
| 190 |
+
if end_idx_30s - start_idx_30s <= 1024:
|
| 191 |
+
continue
|
| 192 |
+
wraped_muq_embd_30s.append(
|
| 193 |
+
self.muq(
|
| 194 |
+
audio[start_idx_30s:end_idx_30s].unsqueeze(0),
|
| 195 |
+
output_hidden_states=True,
|
| 196 |
+
)["hidden_states"][10]
|
| 197 |
+
)
|
| 198 |
+
torch.cuda.empty_cache()
|
| 199 |
+
wraped_musicfm_embd_30s.append(
|
| 200 |
+
self.musicfm.get_predictions(
|
| 201 |
+
audio[start_idx_30s:end_idx_30s].unsqueeze(0)
|
| 202 |
+
)[1][10]
|
| 203 |
+
)
|
| 204 |
+
torch.cuda.empty_cache()
|
| 205 |
+
|
| 206 |
+
wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
|
| 207 |
+
wraped_musicfm_embd_30s = torch.concatenate(
|
| 208 |
+
wraped_musicfm_embd_30s, dim=1
|
| 209 |
+
)
|
| 210 |
+
all_embds = [
|
| 211 |
+
wraped_musicfm_embd_30s,
|
| 212 |
+
wraped_muq_embd_30s,
|
| 213 |
+
musicfm_embd_420s,
|
| 214 |
+
muq_embd_420s,
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
if len(all_embds) > 1:
|
| 218 |
+
embd_lens = [x.shape[1] for x in all_embds]
|
| 219 |
+
max_embd_len = max(embd_lens)
|
| 220 |
+
min_embd_len = min(embd_lens)
|
| 221 |
+
if abs(max_embd_len - min_embd_len) > 4:
|
| 222 |
+
raise ValueError(
|
| 223 |
+
f"Embedding shapes differ too much: {max_embd_len} vs {min_embd_len}"
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
for idx in range(len(all_embds)):
|
| 227 |
+
all_embds[idx] = all_embds[idx][:, :min_embd_len, :]
|
| 228 |
+
|
| 229 |
+
embd = torch.concatenate(all_embds, axis=-1)
|
| 230 |
+
|
| 231 |
+
dataset_label = DATASET_LABEL
|
| 232 |
+
dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
|
| 233 |
+
msa_info, chunk_logits = self.songformer.infer(
|
| 234 |
+
input_embeddings=embd,
|
| 235 |
+
dataset_ids=dataset_ids,
|
| 236 |
+
label_id_masks=torch.Tensor(
|
| 237 |
+
self.dataset_id2label_mask[
|
| 238 |
+
DATASET_LABEL_TO_DATASET_ID[dataset_label]
|
| 239 |
+
]
|
| 240 |
+
)
|
| 241 |
+
.to(device, dtype=bool)
|
| 242 |
+
.unsqueeze(0)
|
| 243 |
+
.unsqueeze(0),
|
| 244 |
+
with_logits=True,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
|
| 248 |
+
end_frame = start_frame + min(
|
| 249 |
+
math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
|
| 250 |
+
chunk_logits["boundary_logits"][0].shape[0],
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
logits["function_logits"][start_frame:end_frame, :] += (
|
| 254 |
+
chunk_logits["function_logits"][0].detach().cpu().numpy()
|
| 255 |
+
)
|
| 256 |
+
logits["boundary_logits"][start_frame:end_frame] = (
|
| 257 |
+
chunk_logits["boundary_logits"][0].detach().cpu().numpy()
|
| 258 |
+
)
|
| 259 |
+
logits_num["function_logits"][start_frame:end_frame, :] += 1
|
| 260 |
+
logits_num["boundary_logits"][start_frame:end_frame] += 1
|
| 261 |
+
lens += end_frame - start_frame
|
| 262 |
+
|
| 263 |
+
i += hop_size
|
| 264 |
+
logits["function_logits"] /= logits_num["function_logits"]
|
| 265 |
+
logits["boundary_logits"] /= logits_num["boundary_logits"]
|
| 266 |
+
|
| 267 |
+
logits["function_logits"] = torch.from_numpy(
|
| 268 |
+
logits["function_logits"][:lens]
|
| 269 |
+
).unsqueeze(0)
|
| 270 |
+
logits["boundary_logits"] = torch.from_numpy(
|
| 271 |
+
logits["boundary_logits"][:lens]
|
| 272 |
+
).unsqueeze(0)
|
| 273 |
+
|
| 274 |
+
msa_infer_output = postprocess_functional_structure(logits, self.config)
|
| 275 |
+
|
| 276 |
+
assert msa_infer_output[-1][-1] == "end"
|
| 277 |
+
if not self.config.no_rule_post_processing:
|
| 278 |
+
msa_infer_output = rule_post_processing(msa_infer_output)
|
| 279 |
+
msa_json = []
|
| 280 |
+
for idx in range(len(msa_infer_output) - 1):
|
| 281 |
+
msa_json.append(
|
| 282 |
+
{
|
| 283 |
+
"label": msa_infer_output[idx][1],
|
| 284 |
+
"start": msa_infer_output[idx][0],
|
| 285 |
+
"end": msa_infer_output[idx + 1][0],
|
| 286 |
+
}
|
| 287 |
+
)
|
| 288 |
+
return msa_json
|
| 289 |
+
|
| 290 |
+
@staticmethod
|
| 291 |
+
def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
|
| 292 |
+
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
|
| 293 |
+
|
| 294 |
+
# ---- begin: ignore muq ----
|
| 295 |
+
if key.startswith("muq."):
|
| 296 |
+
return key, False
|
| 297 |
+
# ---- end ---
|
| 298 |
+
|
| 299 |
+
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
|
| 300 |
+
# This rename is logged.
|
| 301 |
+
if key.endswith("LayerNorm.beta"):
|
| 302 |
+
return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
|
| 303 |
+
if key.endswith("LayerNorm.gamma"):
|
| 304 |
+
return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
|
| 305 |
+
|
| 306 |
+
# Rename weight norm parametrizations to match changes across torch versions.
|
| 307 |
+
# Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
|
| 308 |
+
# This rename is not logged.
|
| 309 |
+
if hasattr(nn.utils.parametrizations, "weight_norm"):
|
| 310 |
+
if key.endswith("weight_g"):
|
| 311 |
+
return key.replace(
|
| 312 |
+
"weight_g", "parametrizations.weight.original0"
|
| 313 |
+
), True
|
| 314 |
+
if key.endswith("weight_v"):
|
| 315 |
+
return key.replace(
|
| 316 |
+
"weight_v", "parametrizations.weight.original1"
|
| 317 |
+
), True
|
| 318 |
+
else:
|
| 319 |
+
if key.endswith("parametrizations.weight.original0"):
|
| 320 |
+
return key.replace(
|
| 321 |
+
"parametrizations.weight.original0", "weight_g"
|
| 322 |
+
), True
|
| 323 |
+
if key.endswith("parametrizations.weight.original1"):
|
| 324 |
+
return key.replace(
|
| 325 |
+
"parametrizations.weight.original1", "weight_v"
|
| 326 |
+
), True
|
| 327 |
+
|
| 328 |
+
return key, False
|
msd_stats.json
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"spec_256_cnt": 14394344256,
|
| 3 |
+
"spec_256_mean": -23.34296658431829,
|
| 4 |
+
"spec_256_std": 26.189295587132637,
|
| 5 |
+
"spec_512_cnt": 28677104448,
|
| 6 |
+
"spec_512_mean": -21.31267396860235,
|
| 7 |
+
"spec_512_std": 26.52644536245769,
|
| 8 |
+
"spec_1024_cnt": 57242624832,
|
| 9 |
+
"spec_1024_mean": -18.852271129208273,
|
| 10 |
+
"spec_1024_std": 26.443154583585663,
|
| 11 |
+
"spec_2048_cnt": 114373665600,
|
| 12 |
+
"spec_2048_mean": -15.638743433896792,
|
| 13 |
+
"spec_2048_std": 26.115825961611545,
|
| 14 |
+
"spec_4096_cnt": 228635747136,
|
| 15 |
+
"spec_4096_mean": -11.715532502794836,
|
| 16 |
+
"spec_4096_std": 25.763972210234062,
|
| 17 |
+
"melspec_256_cnt": 14282760192,
|
| 18 |
+
"melspec_256_mean": -26.962600400166156,
|
| 19 |
+
"melspec_256_std": 36.13614100912126,
|
| 20 |
+
"melspec_512_cnt": 14282760192,
|
| 21 |
+
"melspec_512_mean": -9.108344167718862,
|
| 22 |
+
"melspec_512_std": 24.71910937988429,
|
| 23 |
+
"melspec_1024_cnt": 14282760192,
|
| 24 |
+
"melspec_1024_mean": 0.37302579246531126,
|
| 25 |
+
"melspec_1024_std": 18.684082325919388,
|
| 26 |
+
"melspec_2048_cnt": 14282760192,
|
| 27 |
+
"melspec_2048_mean": 6.768444971712967,
|
| 28 |
+
"melspec_2048_std": 18.417922652295623,
|
| 29 |
+
"melspec_4096_cnt": 14282760192,
|
| 30 |
+
"melspec_4096_mean": 13.617164614990036,
|
| 31 |
+
"melspec_4096_std": 18.08552130124525,
|
| 32 |
+
"cqt_cnt": 9373061376,
|
| 33 |
+
"cqt_mean": 0.46341379757927165,
|
| 34 |
+
"cqt_std": 0.9543998080910191,
|
| 35 |
+
"mfcc_256_cnt": 1339008768,
|
| 36 |
+
"mfcc_256_mean": -11.681755459447485,
|
| 37 |
+
"mfcc_256_std": 29.183186444668316,
|
| 38 |
+
"mfcc_512_cnt": 1339008768,
|
| 39 |
+
"mfcc_512_mean": -2.540581461792183,
|
| 40 |
+
"mfcc_512_std": 31.93752185832081,
|
| 41 |
+
"mfcc_1024_cnt": 1339008768,
|
| 42 |
+
"mfcc_1024_mean": 6.606636263169779,
|
| 43 |
+
"mfcc_1024_std": 34.151644801729624,
|
| 44 |
+
"mfcc_2048_cnt": 1339008768,
|
| 45 |
+
"mfcc_2048_mean": 5.281600844245184,
|
| 46 |
+
"mfcc_2048_std": 33.12784541220003,
|
| 47 |
+
"mfcc_4096_cnt": 1339008768,
|
| 48 |
+
"mfcc_4096_mean": 4.7616569480166095,
|
| 49 |
+
"mfcc_4096_std": 32.61458906894133,
|
| 50 |
+
"chromagram_256_cnt": 1339008768,
|
| 51 |
+
"chromagram_256_mean": 55.15596556703181,
|
| 52 |
+
"chromagram_256_std": 73.91858278719991,
|
| 53 |
+
"chromagram_512_cnt": 1339008768,
|
| 54 |
+
"chromagram_512_mean": 175.73092252759895,
|
| 55 |
+
"chromagram_512_std": 248.48485148525953,
|
| 56 |
+
"chromagram_1024_cnt": 1339008768,
|
| 57 |
+
"chromagram_1024_mean": 589.2947481634608,
|
| 58 |
+
"chromagram_1024_std": 913.857929063196,
|
| 59 |
+
"chromagram_2048_cnt": 1339008768,
|
| 60 |
+
"chromagram_2048_mean": 2062.286388327397,
|
| 61 |
+
"chromagram_2048_std": 3458.92657915397,
|
| 62 |
+
"chromagram_4096_cnt": 1339008768,
|
| 63 |
+
"chromagram_4096_mean": 7673.039107997085,
|
| 64 |
+
"chromagram_4096_std": 13009.883158267234
|
| 65 |
+
}
|
muq_config2.json
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"label_rate": 25,
|
| 3 |
+
"num_codebooks": 1,
|
| 4 |
+
"codebook_dim": 16,
|
| 5 |
+
"codebook_size": 8192,
|
| 6 |
+
"features": [
|
| 7 |
+
"melspec_2048"
|
| 8 |
+
],
|
| 9 |
+
"hop_length": 240,
|
| 10 |
+
"n_mels": 128,
|
| 11 |
+
"conv_dim": 512,
|
| 12 |
+
"encoder_dim": 1024,
|
| 13 |
+
"encoder_depth": 12,
|
| 14 |
+
"mask_hop": 0.4,
|
| 15 |
+
"mask_prob": 0.6,
|
| 16 |
+
"is_flash": false,
|
| 17 |
+
"stat": {
|
| 18 |
+
"melspec_2048_cnt": 14282760192,
|
| 19 |
+
"melspec_2048_mean": 6.768444971712967,
|
| 20 |
+
"melspec_2048_std": 18.417922652295623
|
| 21 |
+
},
|
| 22 |
+
"w2v2_config": {
|
| 23 |
+
"activation_dropout": 0.1,
|
| 24 |
+
"adapter_kernel_size": 3,
|
| 25 |
+
"adapter_stride": 2,
|
| 26 |
+
"add_adapter": false,
|
| 27 |
+
"apply_spec_augment": true,
|
| 28 |
+
"architectures": [
|
| 29 |
+
"Wav2Vec2ConformerForCTC"
|
| 30 |
+
],
|
| 31 |
+
"attention_dropout": 0.1,
|
| 32 |
+
"bos_token_id": 1,
|
| 33 |
+
"classifier_proj_size": 256,
|
| 34 |
+
"codevector_dim": 768,
|
| 35 |
+
"conformer_conv_dropout": 0.1,
|
| 36 |
+
"contrastive_logits_temperature": 0.1,
|
| 37 |
+
"conv_bias": true,
|
| 38 |
+
"conv_depthwise_kernel_size": 31,
|
| 39 |
+
"conv_dim": [
|
| 40 |
+
512,
|
| 41 |
+
512,
|
| 42 |
+
512,
|
| 43 |
+
512,
|
| 44 |
+
512,
|
| 45 |
+
512,
|
| 46 |
+
512
|
| 47 |
+
],
|
| 48 |
+
"conv_kernel": [
|
| 49 |
+
10,
|
| 50 |
+
3,
|
| 51 |
+
3,
|
| 52 |
+
3,
|
| 53 |
+
3,
|
| 54 |
+
2,
|
| 55 |
+
2
|
| 56 |
+
],
|
| 57 |
+
"conv_stride": [
|
| 58 |
+
5,
|
| 59 |
+
2,
|
| 60 |
+
2,
|
| 61 |
+
2,
|
| 62 |
+
2,
|
| 63 |
+
2,
|
| 64 |
+
2
|
| 65 |
+
],
|
| 66 |
+
"ctc_loss_reduction": "sum",
|
| 67 |
+
"ctc_zero_infinity": false,
|
| 68 |
+
"diversity_loss_weight": 0.1,
|
| 69 |
+
"do_stable_layer_norm": true,
|
| 70 |
+
"eos_token_id": 2,
|
| 71 |
+
"feat_extract_activation": "gelu",
|
| 72 |
+
"feat_extract_dropout": 0.0,
|
| 73 |
+
"feat_extract_norm": "layer",
|
| 74 |
+
"feat_proj_dropout": 0.1,
|
| 75 |
+
"feat_quantizer_dropout": 0.0,
|
| 76 |
+
"final_dropout": 0.1,
|
| 77 |
+
"gradient_checkpointing": false,
|
| 78 |
+
"hidden_act": "swish",
|
| 79 |
+
"hidden_dropout": 0.1,
|
| 80 |
+
"hidden_dropout_prob": 0.1,
|
| 81 |
+
"hidden_size": 1024,
|
| 82 |
+
"initializer_range": 0.02,
|
| 83 |
+
"intermediate_size": 4096,
|
| 84 |
+
"layer_norm_eps": 1e-05,
|
| 85 |
+
"layerdrop": 0.0,
|
| 86 |
+
"mask_feature_length": 10,
|
| 87 |
+
"mask_feature_min_masks": 0,
|
| 88 |
+
"mask_feature_prob": 0.0,
|
| 89 |
+
"mask_time_length": 10,
|
| 90 |
+
"mask_time_min_masks": 2,
|
| 91 |
+
"mask_time_prob": 0.05,
|
| 92 |
+
"max_source_positions": 5000,
|
| 93 |
+
"model_type": "wav2vec2-conformer",
|
| 94 |
+
"num_adapter_layers": 3,
|
| 95 |
+
"num_attention_heads": 16,
|
| 96 |
+
"num_codevector_groups": 2,
|
| 97 |
+
"num_codevectors_per_group": 320,
|
| 98 |
+
"num_conv_pos_embedding_groups": 16,
|
| 99 |
+
"num_conv_pos_embeddings": 128,
|
| 100 |
+
"num_feat_extract_layers": 7,
|
| 101 |
+
"num_hidden_layers": 24,
|
| 102 |
+
"num_negatives": 100,
|
| 103 |
+
"output_hidden_size": 1024,
|
| 104 |
+
"pad_token_id": 0,
|
| 105 |
+
"position_embeddings_type": "rotary",
|
| 106 |
+
"proj_codevector_dim": 768,
|
| 107 |
+
"rotary_embedding_base": 10000,
|
| 108 |
+
"tdnn_dilation": [
|
| 109 |
+
1,
|
| 110 |
+
2,
|
| 111 |
+
3,
|
| 112 |
+
1,
|
| 113 |
+
1
|
| 114 |
+
],
|
| 115 |
+
"tdnn_dim": [
|
| 116 |
+
512,
|
| 117 |
+
512,
|
| 118 |
+
512,
|
| 119 |
+
512,
|
| 120 |
+
1500
|
| 121 |
+
],
|
| 122 |
+
"tdnn_kernel": [
|
| 123 |
+
5,
|
| 124 |
+
3,
|
| 125 |
+
3,
|
| 126 |
+
1,
|
| 127 |
+
1
|
| 128 |
+
],
|
| 129 |
+
"torch_dtype": "float32",
|
| 130 |
+
"transformers_version": "4.19.0.dev0",
|
| 131 |
+
"use_weighted_layer_sum": false,
|
| 132 |
+
"vocab_size": 32,
|
| 133 |
+
"xvector_output_dim": 512
|
| 134 |
+
},
|
| 135 |
+
"use_rvq_target": true,
|
| 136 |
+
"use_vq_target": false,
|
| 137 |
+
"use_encodec_target": false,
|
| 138 |
+
"rvq_ckpt_path": null,
|
| 139 |
+
"recon_loss_ratio": null,
|
| 140 |
+
"resume_checkpoint": null,
|
| 141 |
+
"rvq_n_codebooks": 8,
|
| 142 |
+
"rvq_multi_layer_num": 1
|
| 143 |
+
}
|
musicfm/.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mac
|
| 2 |
+
.DS_Store
|
| 3 |
+
|
| 4 |
+
# cache
|
| 5 |
+
*.pyc
|
| 6 |
+
|
| 7 |
+
# data
|
| 8 |
+
*.json
|
| 9 |
+
*.pt
|
| 10 |
+
|
musicfm/LICENSE
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Dual Licensing Information
|
| 2 |
+
-------------------------
|
| 3 |
+
|
| 4 |
+
This software is dual-licensed under both the MIT License and the Apache License, Version 2.0.
|
| 5 |
+
|
| 6 |
+
- The file `modules/flash_conformer.py` is distributed under the terms of the Apache License, Version 2.0.
|
| 7 |
+
- All other files and modules in this software are distributed under the terms of the MIT License.
|
| 8 |
+
|
| 9 |
+
### MIT License
|
| 10 |
+
|
| 11 |
+
Copyright 2023 ByteDance Inc.
|
| 12 |
+
|
| 13 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 16 |
+
|
| 17 |
+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
### Apache License, Version 2.0
|
| 21 |
+
|
| 22 |
+
Copyright 2018- The Hugging Face team. All rights reserved.
|
| 23 |
+
|
| 24 |
+
Apache License
|
| 25 |
+
Version 2.0, January 2004
|
| 26 |
+
http://www.apache.org/licenses/
|
| 27 |
+
|
| 28 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 29 |
+
|
| 30 |
+
1. Definitions.
|
| 31 |
+
|
| 32 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 33 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 34 |
+
|
| 35 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 36 |
+
the copyright owner that is granting the License.
|
| 37 |
+
|
| 38 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 39 |
+
other entities that control, are controlled by, or are under common
|
| 40 |
+
control with that entity. For the purposes of this definition,
|
| 41 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 42 |
+
direction or management of such entity, whether by contract or
|
| 43 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 44 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 45 |
+
|
| 46 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 47 |
+
exercising permissions granted by this License.
|
| 48 |
+
|
| 49 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 50 |
+
including but not limited to software source code, documentation
|
| 51 |
+
source, and configuration files.
|
| 52 |
+
|
| 53 |
+
"Object" form shall mean any form resulting from mechanical
|
| 54 |
+
transformation or translation of a Source form, including but
|
| 55 |
+
not limited to compiled object code, generated documentation,
|
| 56 |
+
and conversions to other media types.
|
| 57 |
+
|
| 58 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 59 |
+
Object form, made available under the License, as indicated by a
|
| 60 |
+
copyright notice that is included in or attached to the work
|
| 61 |
+
(an example is provided in the Appendix below).
|
| 62 |
+
|
| 63 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 64 |
+
form, that is based on (or derived from) the Work and for which the
|
| 65 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 66 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 67 |
+
of this License, Derivative Works shall not include works that remain
|
| 68 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 69 |
+
the Work and Derivative Works thereof.
|
| 70 |
+
|
| 71 |
+
"Contribution" shall mean any work of authorship, including
|
| 72 |
+
the original version of the Work and any modifications or additions
|
| 73 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 74 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 75 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 76 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 77 |
+
means any form of electronic, verbal, or written communication sent
|
| 78 |
+
to the Licensor or its representatives, including but not limited to
|
| 79 |
+
communication on electronic mailing lists, source code control systems,
|
| 80 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 81 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 82 |
+
excluding communication that is conspicuously marked or otherwise
|
| 83 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 84 |
+
|
| 85 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 86 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 87 |
+
subsequently incorporated within the Work.
|
| 88 |
+
|
| 89 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 90 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 91 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 92 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 93 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 94 |
+
Work and such Derivative Works in Source or Object form.
|
| 95 |
+
|
| 96 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 97 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 98 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 99 |
+
(except as stated in this section) patent license to make, have made,
|
| 100 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 101 |
+
where such license applies only to those patent claims licensable
|
| 102 |
+
by such Contributor that are necessarily infringed by their
|
| 103 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 104 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 105 |
+
institute patent litigation against any entity (including a
|
| 106 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 107 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 108 |
+
or contributory patent infringement, then any patent licenses
|
| 109 |
+
granted to You under this License for that Work shall terminate
|
| 110 |
+
as of the date such litigation is filed.
|
| 111 |
+
|
| 112 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 113 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 114 |
+
modifications, and in Source or Object form, provided that You
|
| 115 |
+
meet the following conditions:
|
| 116 |
+
|
| 117 |
+
(a) You must give any other recipients of the Work or
|
| 118 |
+
Derivative Works a copy of this License; and
|
| 119 |
+
|
| 120 |
+
(b) You must cause any modified files to carry prominent notices
|
| 121 |
+
stating that You changed the files; and
|
| 122 |
+
|
| 123 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 124 |
+
that You distribute, all copyright, patent, trademark, and
|
| 125 |
+
attribution notices from the Source form of the Work,
|
| 126 |
+
excluding those notices that do not pertain to any part of
|
| 127 |
+
the Derivative Works; and
|
| 128 |
+
|
| 129 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 130 |
+
distribution, then any Derivative Works that You distribute must
|
| 131 |
+
include a readable copy of the attribution notices contained
|
| 132 |
+
within such NOTICE file, excluding those notices that do not
|
| 133 |
+
pertain to any part of the Derivative Works, in at least one
|
| 134 |
+
of the following places: within a NOTICE text file distributed
|
| 135 |
+
as part of the Derivative Works; within the Source form or
|
| 136 |
+
documentation, if provided along with the Derivative Works; or,
|
| 137 |
+
within a display generated by the Derivative Works, if and
|
| 138 |
+
wherever such third-party notices normally appear. The contents
|
| 139 |
+
of the NOTICE file are for informational purposes only and
|
| 140 |
+
do not modify the License. You may add Your own attribution
|
| 141 |
+
notices within Derivative Works that You distribute, alongside
|
| 142 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 143 |
+
that such additional attribution notices cannot be construed
|
| 144 |
+
as modifying the License.
|
| 145 |
+
|
| 146 |
+
You may add Your own copyright statement to Your modifications and
|
| 147 |
+
may provide additional or different license terms and conditions
|
| 148 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 149 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 150 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 151 |
+
the conditions stated in this License.
|
| 152 |
+
|
| 153 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 154 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 155 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 156 |
+
this License, without any additional terms or conditions.
|
| 157 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 158 |
+
the terms of any separate license agreement you may have executed
|
| 159 |
+
with Licensor regarding such Contributions.
|
| 160 |
+
|
| 161 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 162 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 163 |
+
except as required for reasonable and customary use in describing the
|
| 164 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 165 |
+
|
| 166 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 167 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 168 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 169 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 170 |
+
implied, including, without limitation, any warranties or conditions
|
| 171 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 172 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 173 |
+
appropriateness of using or redistributing the Work and assume any
|
| 174 |
+
risks associated with Your exercise of permissions under this License.
|
| 175 |
+
|
| 176 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 177 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 178 |
+
unless required by applicable law (such as deliberate and grossly
|
| 179 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 180 |
+
liable to You for damages, including any direct, indirect, special,
|
| 181 |
+
incidental, or consequential damages of any character arising as a
|
| 182 |
+
result of this License or out of the use or inability to use the
|
| 183 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 184 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 185 |
+
other commercial damages or losses), even if such Contributor
|
| 186 |
+
has been advised of the possibility of such damages.
|
| 187 |
+
|
| 188 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 189 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 190 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 191 |
+
or other liability obligations and/or rights consistent with this
|
| 192 |
+
License. However, in accepting such obligations, You may act only
|
| 193 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 194 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 195 |
+
defend, and hold each Contributor harmless for any liability
|
| 196 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 197 |
+
of your accepting any such warranty or additional liability.
|
| 198 |
+
|
| 199 |
+
END OF TERMS AND CONDITIONS
|
| 200 |
+
|
| 201 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 202 |
+
|
| 203 |
+
To apply the Apache License to your work, attach the following
|
| 204 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 205 |
+
replaced with your own identifying information. (Don't include
|
| 206 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 207 |
+
comment syntax for the file format. We also recommend that a
|
| 208 |
+
file or class name and description of purpose be included on the
|
| 209 |
+
same "printed page" as the copyright notice for easier
|
| 210 |
+
identification within third-party archives.
|
| 211 |
+
|
| 212 |
+
Copyright [yyyy] [name of copyright owner]
|
| 213 |
+
|
| 214 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 215 |
+
you may not use this file except in compliance with the License.
|
| 216 |
+
You may obtain a copy of the License at
|
| 217 |
+
|
| 218 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 219 |
+
|
| 220 |
+
Unless required by applicable law or agreed to in writing, software
|
| 221 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 222 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 223 |
+
See the License for the specific language governing permissions and
|
| 224 |
+
limitations under the License.
|
musicfm/README.md
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MusicFM 🤖
|
| 2 |
+
[](https://opensource.org/licenses/MIT)
|
| 3 |
+
[](https://www.apache.org/licenses/LICENSE-2.0.html)
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
**A Foundation Model for Music Informatics**, ICASSP 2024 [[paper](https://arxiv.org/abs/2311.03318)]
|
| 7 |
+
|
| 8 |
+
-- Minz Won, Yun-Ning Hung, and Duc Le
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
## Quick start
|
| 12 |
+
### Download models
|
| 13 |
+
|
| 14 |
+
**MusicFM-FMA**
|
| 15 |
+
|
| 16 |
+
- Pretrained using [FMA-large](https://github.com/mdeff/fma) data
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/fma_stats.json
|
| 20 |
+
wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_fma.pt
|
| 21 |
+
```
|
| 22 |
+
⚠️ The model checkpoint prior to Feb 13, 2024, was incorrect. Please ensure to re-download these files if you've been using previous versions.
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
**MusicFM-MSD**
|
| 26 |
+
|
| 27 |
+
- Pretrained with the entire [Million Song Dataset](http://millionsongdataset.com/)
|
| 28 |
+
- This version performs better than the FMA version
|
| 29 |
+
- This version is not introduced in the paper
|
| 30 |
+
|
| 31 |
+
```
|
| 32 |
+
wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/msd_stats.json
|
| 33 |
+
wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### Get embeddings
|
| 37 |
+
```
|
| 38 |
+
HOME_PATH = "/home/dev" # path where you cloned musicfm
|
| 39 |
+
|
| 40 |
+
import os
|
| 41 |
+
import sys
|
| 42 |
+
import torch
|
| 43 |
+
|
| 44 |
+
sys.path.append(HOME_PATH)
|
| 45 |
+
from musicfm.model.musicfm_25hz import MusicFM25Hz
|
| 46 |
+
|
| 47 |
+
# dummy audio (30 seconds, 24kHz)
|
| 48 |
+
wav = (torch.rand(4, 24000 * 30) - 0.5) * 2
|
| 49 |
+
|
| 50 |
+
# load MusicFM
|
| 51 |
+
musicfm = MusicFM25Hz(
|
| 52 |
+
is_flash=False,
|
| 53 |
+
stat_path=os.path.join(HOME_PATH, "musicfm", "data", "msd_stats.json"),
|
| 54 |
+
model_path=os.path.join(HOME_PATH, "musicfm", "data", "pretrained_msd.pt"),
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# to GPUs
|
| 58 |
+
wav = wav.cuda()
|
| 59 |
+
musicfm = musicfm.cuda()
|
| 60 |
+
|
| 61 |
+
# get embeddings
|
| 62 |
+
musicfm.eval()
|
| 63 |
+
emb = musicfm.get_latent(wav, layer_ix=7)
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### Mixed precision and Flash attention
|
| 67 |
+
Suffering from memory issues? [Mixed precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html) and [Flash attention](https://arxiv.org/abs/2205.14135) will be good friends of yours!
|
| 68 |
+
|
| 69 |
+
```
|
| 70 |
+
# dummy audio (30 seconds, 24kHz)
|
| 71 |
+
wav = (torch.rand(4, 24000 * 30) - 0.5) * 2
|
| 72 |
+
|
| 73 |
+
# load MusicFM
|
| 74 |
+
musicfm = MusicFM25Hz(is_flash=True)
|
| 75 |
+
|
| 76 |
+
# to GPUs
|
| 77 |
+
wav = wav.cuda().half()
|
| 78 |
+
musicfm = musicfm.cuda().half()
|
| 79 |
+
|
| 80 |
+
# get embeddings
|
| 81 |
+
musicfm.eval()
|
| 82 |
+
emb = musicfm.get_latent(wav, layer_ix=7)
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
However, I highly recommend using `float32` for better performance in specific downstream tasks, such as beat tracking.
|
| 86 |
+
|
| 87 |
+
### Usage in downstream tasks
|
| 88 |
+
The pretrained model operates at a 25Hz frame rate, but our downstream tasks demand varying temporal resolutions. To address this, we either summarize the sequence through global average pooling or adjust the temporal resolution using adaptive average pooling.
|
| 89 |
+
|
| 90 |
+
```
|
| 91 |
+
from torch import nn
|
| 92 |
+
|
| 93 |
+
# Sequence-level representation
|
| 94 |
+
seq_emb = emb.mean(-1) # (batch, time, channel) -> (batch, channel)
|
| 95 |
+
|
| 96 |
+
# Frame-level representation
|
| 97 |
+
"""
|
| 98 |
+
n_frame = desired_temporal_resolution * sequence_length_in_sec
|
| 99 |
+
300 frames = 10Hz * 30s in this example
|
| 100 |
+
As a result, the sequence length becomes from 750 (25Hz * 30s) to 300
|
| 101 |
+
"""
|
| 102 |
+
n_frame = 300
|
| 103 |
+
token_emb = nn.AdaptiveAvgPool1d(n_frame)(emb) # (batch, time, channel) -> (batch, time', channel)
|
| 104 |
+
```
|
| 105 |
+
We share the details of our downstream evaluation as follows. The selection of input lengths and temporal resolutions is based on our prior experience with each task.
|
| 106 |
+
|
| 107 |
+
| | Beat | Chord | Structure | Key | Tagging |
|
| 108 |
+
| :--------: | :--------: | :--------: | :--------: | :--------: | :--------: |
|
| 109 |
+
| Input length | 6s | 12s | 24s | 12s | 29.1s |
|
| 110 |
+
| Temporal resolution | 50Hz | 16Hz | 8Hz | 0.5Hz | - |
|
| 111 |
+
| n_frame | 300 | 192 | 192 | 6 | 1 |
|
| 112 |
+
|
| 113 |
+
### Fine-tuning
|
| 114 |
+
You can expect better performance in downstream tasks by fine-tuning the foundation model. In this scenario, employ `musicfm.train()` and extract the final embeddings by setting `layer_ix=12`. However, when optimizing the model with the same learning rate, there's a risk of [catastrophic forgetting](https://en.wikipedia.org/wiki/Catastrophic_interference). To mitigate this issue, we utilized a learning rate of 1e-5 for the foundation model and 1e-4 for the probing layers.
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
## Results
|
| 119 |
+
|
| 120 |
+
<img src="figs/Table1.png" width="800">
|
| 121 |
+
|
| 122 |
+
\* FM1 is pretrained [MERT](https://arxiv.org/abs/2306.00107).
|
| 123 |
+
|
| 124 |
+
\*\*FM8 mirrors the [BEST-RQ](https://arxiv.org/abs/2202.01855) but with the distinction that it was trained using music data.
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
- Random tokenization generalizes well to music data.
|
| 128 |
+
|
| 129 |
+
- Frame-level classification offers a more comprehensive understanding of foundation models. While FM4 excels in music tagging, its performance in structural analysis is subpar.
|
| 130 |
+
|
| 131 |
+
- Input length used during training is critical for capturing
|
| 132 |
+
long-term contexts. Check 5s models (FM1, FM2, and FM4) and a 30s model (FM5) in downbeat tracking and structure analysis.
|
| 133 |
+
|
| 134 |
+
- Temporal resolution has less impact in our experimental setup. See FM5, FM6, and FM7.
|
| 135 |
+
|
| 136 |
+
- Model architecture makes a significant difference. Conformer (FM5) consistently outperformed BERT encoder (FM3) for across all downstream tasks.
|
| 137 |
+
|
| 138 |
+
- The influence of model size was relatively minimal (FM7 and FM8). However, we observed that FM8's performance continued to improve, which is typically indicative of underfitting. All models were trained for two weeks to ensure a fair comparison.
|
| 139 |
+
|
| 140 |
+
- Data is undeniably crucial, as in any data-driven approach. Please compare FM7 and FM9.
|
| 141 |
+
|
| 142 |
+
- Fine-tuning the foundation model further enhances downstream performance. However, we did observe a performance
|
| 143 |
+
drop in the tagging task, primarily attributed to overfitting.
|
| 144 |
+
|
| 145 |
+
## Masked token modeling
|
| 146 |
+
<img src="figs/Fig1.png" width="300">
|
| 147 |
+
|
| 148 |
+
MusicFM follows the training scheme of [BEST-RQ](https://arxiv.org/abs/2202.01855). Input audio is masked with noise, and the model predicts the masked representation. Target tokens are generated by random projection and a random codebook. Both the projection layer and codebook are **randomly initialized** and remain **non-trainable**. Isn't it fascinating?
|
| 149 |
+
|
| 150 |
+
Note that input normalization is exceptionally crucial, considering the usage of random projection. You can check the details [here](https://github.com/minzwon/musicfm/blob/d5d0f313add9f3c32c41f95521760b1a136809ed/model/musicfm_25hz.py#L148).
|
| 151 |
+
|
| 152 |
+
## Limitations
|
| 153 |
+
- Self-supervised foundation models in music, such as [JukeMIR](https://arxiv.org/abs/2107.05677), [MERT](https://arxiv.org/abs/2306.00107), and [MusicFM](https://arxiv.org/abs/2311.03318), consistently report relatively low performance in key detection. While fine-tuning the model can help bridge the performance gap, the foundation model itself does not appear to learn musical keys inherently. Further investigation is required to develop more advanced music foundation models.
|
| 154 |
+
|
| 155 |
+
- We share our model trained with the [FMA Dataset](https://github.com/mdeff/fma), which comprises 8k hours of Creative Common-licensed audio. While using a larger dataset (160k hours) can enhance performance, we've chosen to release the model trained on FMA to avoid potential licensing complications.
|
| 156 |
+
|
| 157 |
+
- Fine-tuned models for downstream tasks are not made publicly available as they are primarily used for evaluation purposes. It is expected that carefully designed backends beyond simple probing layers will improve downstream performance. I look forward to the contributions of other researchers with more expertise in each specific task.
|
| 158 |
+
|
| 159 |
+
- The downstream evaluation pipeline is not provided in this repository. Nonetheless, I believe creating a comprehensive evaluation pipeline is essential to expedite progress in music informatics research. I'm very open to discussing it together.
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
## Acknowledgement
|
| 163 |
+
We acknowledge and extend our sincere gratitude to Ju-Chiang Wang for his valuable contributions to data refinement and providing a crucial codebase for our downstream evaluation.
|
| 164 |
+
|
| 165 |
+
## Citation
|
| 166 |
+
```
|
| 167 |
+
@article{won2023musicfm,
|
| 168 |
+
title={A Foundation Model for Music Informatics},
|
| 169 |
+
author = {Won, Minz and Hung, Yun-Ning and Le, Duc},
|
| 170 |
+
journal={arXiv preprint arXiv:2311.03318},
|
| 171 |
+
year={2023}
|
| 172 |
+
}
|
| 173 |
+
```
|
musicfm/data/.gitkeep
ADDED
|
File without changes
|
musicfm/figs/Fig1.png
ADDED
|
Git LFS Details
|
musicfm/figs/Table1.png
ADDED
|
Git LFS Details
|
musicfm/model/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
musicfm/model/musicfm_25hz.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2023 ByteDance Inc.
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
|
| 6 |
+
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
| 7 |
+
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 8 |
+
#
|
| 9 |
+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 10 |
+
#
|
| 11 |
+
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 12 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 13 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
| 14 |
+
# IN THE SOFTWARE.
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import random
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
|
| 22 |
+
from musicfm.modules.random_quantizer import RandomProjectionQuantizer
|
| 23 |
+
from musicfm.modules.features import MelSTFT
|
| 24 |
+
from musicfm.modules.conv import Conv2dSubsampling
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MusicFM25Hz(nn.Module):
|
| 28 |
+
"""
|
| 29 |
+
MusicFM
|
| 30 |
+
|
| 31 |
+
Input: 128-band mel spectrogram
|
| 32 |
+
Frontend: 2-layer Residual convolution
|
| 33 |
+
Backend: 12-layer Conformer
|
| 34 |
+
Quantizer: a codebook for mel spectrogram
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
num_codebooks=1,
|
| 40 |
+
codebook_dim=16,
|
| 41 |
+
codebook_size=4096,
|
| 42 |
+
features=["melspec_2048"],
|
| 43 |
+
hop_length=240,
|
| 44 |
+
n_mels=128,
|
| 45 |
+
conv_dim=512,
|
| 46 |
+
encoder_dim=1024,
|
| 47 |
+
encoder_depth=12,
|
| 48 |
+
mask_hop=0.4,
|
| 49 |
+
mask_prob=0.6,
|
| 50 |
+
is_flash=False,
|
| 51 |
+
stat_path="./data/fma_stats.json",
|
| 52 |
+
# model_path="./data/pretrained_fma.pt",
|
| 53 |
+
):
|
| 54 |
+
super(MusicFM25Hz, self).__init__()
|
| 55 |
+
|
| 56 |
+
# global variables
|
| 57 |
+
self.hop_length = hop_length
|
| 58 |
+
self.mask_hop = mask_hop
|
| 59 |
+
self.mask_prob = mask_prob
|
| 60 |
+
self.num_codebooks = num_codebooks
|
| 61 |
+
self.codebook_size = codebook_size
|
| 62 |
+
self.features = features
|
| 63 |
+
|
| 64 |
+
# load feature mean / std stats
|
| 65 |
+
with open(stat_path, "r") as f:
|
| 66 |
+
self.stat = json.load(f)
|
| 67 |
+
|
| 68 |
+
# feature extractor
|
| 69 |
+
self.preprocessor_melspec_2048 = MelSTFT(
|
| 70 |
+
n_fft=2048, hop_length=hop_length, is_db=True
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# random quantizer
|
| 74 |
+
seed = 142
|
| 75 |
+
for feature in self.features:
|
| 76 |
+
for i in range(num_codebooks):
|
| 77 |
+
setattr(
|
| 78 |
+
self,
|
| 79 |
+
f"quantizer_{feature}_{i}",
|
| 80 |
+
RandomProjectionQuantizer(
|
| 81 |
+
n_mels * 4, codebook_dim, codebook_size, seed=seed + i
|
| 82 |
+
),
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# two residual convolution layers + one projection layer
|
| 86 |
+
self.conv = Conv2dSubsampling(
|
| 87 |
+
1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Conformer
|
| 91 |
+
if is_flash:
|
| 92 |
+
from modules.flash_conformer import (
|
| 93 |
+
Wav2Vec2ConformerEncoder,
|
| 94 |
+
Wav2Vec2ConformerConfig,
|
| 95 |
+
)
|
| 96 |
+
else:
|
| 97 |
+
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
|
| 98 |
+
Wav2Vec2ConformerEncoder,
|
| 99 |
+
Wav2Vec2ConformerConfig,
|
| 100 |
+
)
|
| 101 |
+
config = Wav2Vec2ConformerConfig.from_pretrained(
|
| 102 |
+
"facebook/wav2vec2-conformer-rope-large-960h-ft"
|
| 103 |
+
)
|
| 104 |
+
config.num_hidden_layers = encoder_depth
|
| 105 |
+
config.hidden_size = encoder_dim
|
| 106 |
+
|
| 107 |
+
self.conformer = Wav2Vec2ConformerEncoder(config)
|
| 108 |
+
|
| 109 |
+
# projection
|
| 110 |
+
self.linear = nn.Linear(encoder_dim, codebook_size)
|
| 111 |
+
|
| 112 |
+
# loss function
|
| 113 |
+
self.loss = nn.CrossEntropyLoss()
|
| 114 |
+
|
| 115 |
+
# cls token (used for sequence classification)
|
| 116 |
+
random.seed(seed)
|
| 117 |
+
self.cls_token = nn.Parameter(torch.randn(encoder_dim))
|
| 118 |
+
|
| 119 |
+
# load model
|
| 120 |
+
# if model_path:
|
| 121 |
+
# S = torch.load(model_path)["state_dict"]
|
| 122 |
+
# SS = {k[6:]: v for k, v in S.items()}
|
| 123 |
+
# self.load_state_dict(SS, strict=True)
|
| 124 |
+
|
| 125 |
+
def masking(self, x):
|
| 126 |
+
"""random masking of 400ms with given probability"""
|
| 127 |
+
mx = x.clone()
|
| 128 |
+
b, t = mx.shape
|
| 129 |
+
len_masking_raw = int(24000 * self.mask_hop)
|
| 130 |
+
len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
|
| 131 |
+
|
| 132 |
+
# get random mask indices
|
| 133 |
+
start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
|
| 134 |
+
time_domain_masked_indices = torch.nonzero(
|
| 135 |
+
start_indices.repeat_interleave(len_masking_raw, dim=1)
|
| 136 |
+
)
|
| 137 |
+
token_domain_masked_indices = torch.nonzero(
|
| 138 |
+
start_indices.repeat_interleave(len_masking_token, dim=1)
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# mask with random values
|
| 142 |
+
masking_noise = (
|
| 143 |
+
torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
|
| 144 |
+
) # 0 mean 0.1 std
|
| 145 |
+
mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
|
| 146 |
+
|
| 147 |
+
return mx, token_domain_masked_indices
|
| 148 |
+
|
| 149 |
+
@torch.no_grad()
|
| 150 |
+
def preprocessing(self, x, features):
|
| 151 |
+
"""extract classic audio features"""
|
| 152 |
+
# check precision
|
| 153 |
+
if x.dtype == torch.float16:
|
| 154 |
+
precision = 16
|
| 155 |
+
else:
|
| 156 |
+
precision = 32
|
| 157 |
+
|
| 158 |
+
out = {}
|
| 159 |
+
for key in features:
|
| 160 |
+
layer = getattr(self, "preprocessor_%s" % key)
|
| 161 |
+
out[key] = layer.float()(x.float())[..., :-1]
|
| 162 |
+
if precision == 16:
|
| 163 |
+
out[key] = out[key].half()
|
| 164 |
+
return out
|
| 165 |
+
|
| 166 |
+
def encoder(self, x):
|
| 167 |
+
"""2-layer conv + w2v-conformer"""
|
| 168 |
+
x = self.conv(x)
|
| 169 |
+
out = self.conformer(x, output_hidden_states=True)
|
| 170 |
+
hidden_emb = out["hidden_states"]
|
| 171 |
+
last_emb = out["last_hidden_state"]
|
| 172 |
+
logits = self.linear(last_emb)
|
| 173 |
+
logits = {
|
| 174 |
+
key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
|
| 175 |
+
for i, key in enumerate(self.features)
|
| 176 |
+
}
|
| 177 |
+
return logits, hidden_emb
|
| 178 |
+
|
| 179 |
+
@torch.no_grad()
|
| 180 |
+
def normalize(self, x):
|
| 181 |
+
"""normalize the input audio to have zero mean unit variance"""
|
| 182 |
+
for key in x.keys():
|
| 183 |
+
x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
@torch.no_grad()
|
| 187 |
+
def rearrange(self, x):
|
| 188 |
+
"""rearrange the batch to flatten every 4 steps"""
|
| 189 |
+
for key in x.keys():
|
| 190 |
+
if key == "chromagram":
|
| 191 |
+
x[key] = rearrange(x[key], "b f t -> b t f")
|
| 192 |
+
else:
|
| 193 |
+
x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
|
| 194 |
+
return x
|
| 195 |
+
|
| 196 |
+
@torch.no_grad()
|
| 197 |
+
def tokenize(self, x):
|
| 198 |
+
out = {}
|
| 199 |
+
for key in x.keys():
|
| 200 |
+
layer = getattr(self, "quantizer_%s" % key)
|
| 201 |
+
out[key] = layer(x[key])
|
| 202 |
+
return out
|
| 203 |
+
|
| 204 |
+
def get_targets(self, x):
|
| 205 |
+
x = self.preprocessing(x, features=self.features)
|
| 206 |
+
x = self.normalize(x)
|
| 207 |
+
x = self.rearrange(x)
|
| 208 |
+
target_tokens = self.tokenize(x)
|
| 209 |
+
return target_tokens
|
| 210 |
+
|
| 211 |
+
def get_predictions(self, x):
|
| 212 |
+
# preprocessing
|
| 213 |
+
x = self.preprocessing(x, features=["melspec_2048"])
|
| 214 |
+
x = self.normalize(x)
|
| 215 |
+
|
| 216 |
+
# encoding
|
| 217 |
+
logits, hidden_emb = self.encoder(x["melspec_2048"])
|
| 218 |
+
|
| 219 |
+
return logits, hidden_emb
|
| 220 |
+
|
| 221 |
+
def get_latent(self, x, layer_ix=12):
|
| 222 |
+
_, hidden_states = self.get_predictions(x)
|
| 223 |
+
emb = hidden_states[layer_ix]
|
| 224 |
+
return emb
|
| 225 |
+
|
| 226 |
+
def get_loss(self, logits, target_tokens, masked_indices):
|
| 227 |
+
losses = {}
|
| 228 |
+
accuracies = {}
|
| 229 |
+
for key in logits.keys():
|
| 230 |
+
masked_logits = logits[key][tuple(masked_indices.t())]
|
| 231 |
+
masked_tokens = target_tokens[key][tuple(masked_indices.t())]
|
| 232 |
+
losses[key] = self.loss(masked_logits, masked_tokens)
|
| 233 |
+
accuracies[key] = (
|
| 234 |
+
torch.sum(masked_logits.argmax(-1) == masked_tokens)
|
| 235 |
+
/ masked_tokens.numel()
|
| 236 |
+
)
|
| 237 |
+
return losses, accuracies
|
| 238 |
+
|
| 239 |
+
def forward(self, x):
|
| 240 |
+
# get target feature tokens
|
| 241 |
+
target_tokens = self.get_targets(x)
|
| 242 |
+
|
| 243 |
+
# masking
|
| 244 |
+
x, masked_indices = self.masking(x)
|
| 245 |
+
|
| 246 |
+
# forward
|
| 247 |
+
logits, hidden_emb = self.get_predictions(x)
|
| 248 |
+
|
| 249 |
+
# get loss
|
| 250 |
+
losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
|
| 251 |
+
|
| 252 |
+
return logits, hidden_emb, losses, accuracies
|
musicfm/modules/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
musicfm/modules/conv.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2023 ByteDance Inc.
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
|
| 6 |
+
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
| 7 |
+
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 8 |
+
#
|
| 9 |
+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 10 |
+
#
|
| 11 |
+
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 12 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 13 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
| 14 |
+
# IN THE SOFTWARE.
|
| 15 |
+
|
| 16 |
+
from torch import nn
|
| 17 |
+
from einops import rearrange
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Res2dModule(nn.Module):
|
| 21 |
+
def __init__(self, idim, odim, stride=(2, 2)):
|
| 22 |
+
super(Res2dModule, self).__init__()
|
| 23 |
+
self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
|
| 24 |
+
self.bn1 = nn.BatchNorm2d(odim)
|
| 25 |
+
self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
|
| 26 |
+
self.bn2 = nn.BatchNorm2d(odim)
|
| 27 |
+
self.relu = nn.ReLU()
|
| 28 |
+
|
| 29 |
+
# residual
|
| 30 |
+
self.diff = False
|
| 31 |
+
if (idim != odim) or (stride[0] > 1):
|
| 32 |
+
self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
|
| 33 |
+
self.bn3 = nn.BatchNorm2d(odim)
|
| 34 |
+
self.diff = True
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
|
| 38 |
+
if self.diff:
|
| 39 |
+
x = self.bn3(self.conv3(x))
|
| 40 |
+
out = x + out
|
| 41 |
+
out = self.relu(out)
|
| 42 |
+
return out
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Conv2dSubsampling(nn.Module):
|
| 46 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
idim (int): Input dimension.
|
| 50 |
+
hdim (int): Hidden dimension.
|
| 51 |
+
odim (int): Output dimension.
|
| 52 |
+
strides (list): Sizes of strides.
|
| 53 |
+
n_bands (int): Number of frequency bands.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
|
| 57 |
+
"""Construct an Conv2dSubsampling object."""
|
| 58 |
+
super(Conv2dSubsampling, self).__init__()
|
| 59 |
+
|
| 60 |
+
self.conv = nn.Sequential(
|
| 61 |
+
Res2dModule(idim, hdim, (2, strides[0])),
|
| 62 |
+
Res2dModule(hdim, hdim, (2, strides[1])),
|
| 63 |
+
)
|
| 64 |
+
self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
"""Subsample x.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
x (torch.Tensor): Input tensor (#batch, idim, time).
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
| 74 |
+
where time' = time // 4.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
if x.dim() == 3:
|
| 78 |
+
x = x.unsqueeze(1) # (b, c, f, t)
|
| 79 |
+
x = self.conv(x)
|
| 80 |
+
x = rearrange(x, "b c f t -> b t (c f)")
|
| 81 |
+
x = self.linear(x)
|
| 82 |
+
return x
|
musicfm/modules/features.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2023 ByteDance Inc.
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
|
| 6 |
+
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
| 7 |
+
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 8 |
+
#
|
| 9 |
+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 10 |
+
#
|
| 11 |
+
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 12 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 13 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
| 14 |
+
# IN THE SOFTWARE.
|
| 15 |
+
|
| 16 |
+
import torchaudio
|
| 17 |
+
from torch import nn
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MelSTFT(nn.Module):
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
sample_rate=24000,
|
| 24 |
+
n_fft=2048,
|
| 25 |
+
hop_length=240,
|
| 26 |
+
n_mels=128,
|
| 27 |
+
is_db=False,
|
| 28 |
+
):
|
| 29 |
+
super(MelSTFT, self).__init__()
|
| 30 |
+
|
| 31 |
+
# spectrogram
|
| 32 |
+
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
| 33 |
+
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# amplitude to decibel
|
| 37 |
+
self.is_db = is_db
|
| 38 |
+
if is_db:
|
| 39 |
+
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 40 |
+
|
| 41 |
+
def forward(self, waveform):
|
| 42 |
+
if self.is_db:
|
| 43 |
+
return self.amplitude_to_db(self.mel_stft(waveform))
|
| 44 |
+
else:
|
| 45 |
+
return self.mel_stft(waveform)
|
musicfm/modules/flash_conformer.py
ADDED
|
@@ -0,0 +1,2114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" PyTorch Wav2Vec2-Conformer model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torch.nn import CrossEntropyLoss
|
| 26 |
+
from torch.nn import functional as F
|
| 27 |
+
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
| 30 |
+
from transformers.modeling_outputs import (
|
| 31 |
+
BaseModelOutput,
|
| 32 |
+
CausalLMOutput,
|
| 33 |
+
SequenceClassifierOutput,
|
| 34 |
+
TokenClassifierOutput,
|
| 35 |
+
Wav2Vec2BaseModelOutput,
|
| 36 |
+
XVectorOutput,
|
| 37 |
+
)
|
| 38 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 39 |
+
from transformers.utils import (
|
| 40 |
+
ModelOutput,
|
| 41 |
+
add_code_sample_docstrings,
|
| 42 |
+
add_start_docstrings,
|
| 43 |
+
add_start_docstrings_to_model_forward,
|
| 44 |
+
logging,
|
| 45 |
+
replace_return_docstrings,
|
| 46 |
+
)
|
| 47 |
+
from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
_HIDDEN_STATES_START_POSITION = 2
|
| 54 |
+
|
| 55 |
+
# General docstring
|
| 56 |
+
_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
|
| 57 |
+
|
| 58 |
+
# Base docstring
|
| 59 |
+
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
|
| 60 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
|
| 61 |
+
|
| 62 |
+
# CTC docstring
|
| 63 |
+
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
| 64 |
+
_CTC_EXPECTED_LOSS = 64.21
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 68 |
+
"facebook/wav2vec2-conformer-rel-pos-large",
|
| 69 |
+
# See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass
|
| 74 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
|
| 75 |
+
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
|
| 76 |
+
"""
|
| 77 |
+
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 81 |
+
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
|
| 82 |
+
paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
|
| 83 |
+
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
| 84 |
+
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
|
| 85 |
+
projected quantized states.
|
| 86 |
+
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
| 87 |
+
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
|
| 88 |
+
target vectors for contrastive loss.
|
| 89 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 90 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 91 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
| 92 |
+
|
| 93 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 94 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 95 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 96 |
+
sequence_length)`.
|
| 97 |
+
|
| 98 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 99 |
+
heads.
|
| 100 |
+
contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 101 |
+
The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
| 102 |
+
diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 103 |
+
The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
loss: Optional[torch.FloatTensor] = None
|
| 107 |
+
projected_states: torch.FloatTensor = None
|
| 108 |
+
projected_quantized_states: torch.FloatTensor = None
|
| 109 |
+
codevector_perplexity: torch.FloatTensor = None
|
| 110 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 111 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 112 |
+
contrastive_loss: Optional[torch.FloatTensor] = None
|
| 113 |
+
diversity_loss: Optional[torch.FloatTensor] = None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
| 117 |
+
def _compute_mask_indices(
|
| 118 |
+
shape: Tuple[int, int],
|
| 119 |
+
mask_prob: float,
|
| 120 |
+
mask_length: int,
|
| 121 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 122 |
+
min_masks: int = 0,
|
| 123 |
+
) -> np.ndarray:
|
| 124 |
+
"""
|
| 125 |
+
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
| 126 |
+
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
| 127 |
+
CPU as part of the preprocessing during training.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
| 131 |
+
the first element is the batch size and the second element is the length of the axis to span.
|
| 132 |
+
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
| 133 |
+
independently generated mask spans of length `mask_length` is computed by
|
| 134 |
+
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
| 135 |
+
actual percentage will be smaller.
|
| 136 |
+
mask_length: size of the mask
|
| 137 |
+
min_masks: minimum number of masked spans
|
| 138 |
+
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
| 139 |
+
each batch dimension.
|
| 140 |
+
"""
|
| 141 |
+
batch_size, sequence_length = shape
|
| 142 |
+
|
| 143 |
+
if mask_length < 1:
|
| 144 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
| 145 |
+
|
| 146 |
+
if mask_length > sequence_length:
|
| 147 |
+
raise ValueError(
|
| 148 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
| 149 |
+
f" and `sequence_length`: {sequence_length}`"
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# epsilon is used for probabilistic rounding
|
| 153 |
+
epsilon = np.random.rand(1).item()
|
| 154 |
+
|
| 155 |
+
def compute_num_masked_span(input_length):
|
| 156 |
+
"""Given input length, compute how many spans should be masked"""
|
| 157 |
+
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
| 158 |
+
num_masked_span = max(num_masked_span, min_masks)
|
| 159 |
+
|
| 160 |
+
# make sure num masked span <= sequence_length
|
| 161 |
+
if num_masked_span * mask_length > sequence_length:
|
| 162 |
+
num_masked_span = sequence_length // mask_length
|
| 163 |
+
|
| 164 |
+
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
| 165 |
+
if input_length - (mask_length - 1) < num_masked_span:
|
| 166 |
+
num_masked_span = max(input_length - (mask_length - 1), 0)
|
| 167 |
+
|
| 168 |
+
return num_masked_span
|
| 169 |
+
|
| 170 |
+
# compute number of masked spans in batch
|
| 171 |
+
input_lengths = (
|
| 172 |
+
attention_mask.sum(-1).detach().tolist()
|
| 173 |
+
if attention_mask is not None
|
| 174 |
+
else [sequence_length for _ in range(batch_size)]
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# SpecAugment mask to fill
|
| 178 |
+
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
| 179 |
+
spec_aug_mask_idxs = []
|
| 180 |
+
|
| 181 |
+
max_num_masked_span = compute_num_masked_span(sequence_length)
|
| 182 |
+
|
| 183 |
+
if max_num_masked_span == 0:
|
| 184 |
+
return spec_aug_mask
|
| 185 |
+
|
| 186 |
+
for input_length in input_lengths:
|
| 187 |
+
# compute num of masked spans for this input
|
| 188 |
+
num_masked_span = compute_num_masked_span(input_length)
|
| 189 |
+
|
| 190 |
+
# get random indices to mask
|
| 191 |
+
spec_aug_mask_idx = np.random.choice(
|
| 192 |
+
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# pick first sampled index that will serve as a dummy index to pad vector
|
| 196 |
+
# to ensure same dimension for all batches due to probabilistic rounding
|
| 197 |
+
# Picking first sample just pads those vectors twice.
|
| 198 |
+
if len(spec_aug_mask_idx) == 0:
|
| 199 |
+
# this case can only happen if `input_length` is strictly smaller then
|
| 200 |
+
# `sequence_length` in which case the last token has to be a padding
|
| 201 |
+
# token which we can use as a dummy mask id
|
| 202 |
+
dummy_mask_idx = sequence_length - 1
|
| 203 |
+
else:
|
| 204 |
+
dummy_mask_idx = spec_aug_mask_idx[0]
|
| 205 |
+
|
| 206 |
+
spec_aug_mask_idx = np.concatenate(
|
| 207 |
+
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
| 208 |
+
)
|
| 209 |
+
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
| 210 |
+
|
| 211 |
+
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
| 212 |
+
|
| 213 |
+
# expand masked indices to masked spans
|
| 214 |
+
spec_aug_mask_idxs = np.broadcast_to(
|
| 215 |
+
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
| 216 |
+
)
|
| 217 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
| 218 |
+
|
| 219 |
+
# add offset to the starting indexes so that indexes now create a span
|
| 220 |
+
offsets = np.arange(mask_length)[None, None, :]
|
| 221 |
+
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
| 222 |
+
batch_size, max_num_masked_span * mask_length
|
| 223 |
+
)
|
| 224 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
| 225 |
+
|
| 226 |
+
# ensure that we cannot have indices larger than sequence_length
|
| 227 |
+
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
| 228 |
+
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
| 229 |
+
|
| 230 |
+
# scatter indices to mask
|
| 231 |
+
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
| 232 |
+
|
| 233 |
+
return spec_aug_mask
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
|
| 237 |
+
def _sample_negative_indices(
|
| 238 |
+
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
|
| 239 |
+
):
|
| 240 |
+
"""
|
| 241 |
+
Sample `num_negatives` vectors from feature vectors.
|
| 242 |
+
"""
|
| 243 |
+
batch_size, sequence_length = features_shape
|
| 244 |
+
|
| 245 |
+
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
| 246 |
+
sequence_length_range = np.arange(sequence_length)
|
| 247 |
+
|
| 248 |
+
# get `num_negatives` random vector indices from the same utterance
|
| 249 |
+
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
|
| 250 |
+
|
| 251 |
+
mask_time_indices = (
|
| 252 |
+
mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
for batch_idx in range(batch_size):
|
| 256 |
+
high = mask_time_indices[batch_idx].sum() - 1
|
| 257 |
+
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
|
| 258 |
+
|
| 259 |
+
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
|
| 260 |
+
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
|
| 261 |
+
# avoid sampling the same positive vector, but keep the distribution uniform
|
| 262 |
+
sampled_indices[sampled_indices >= feature_indices] += 1
|
| 263 |
+
|
| 264 |
+
# remap to actual indices
|
| 265 |
+
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
|
| 266 |
+
|
| 267 |
+
# correct for batch size
|
| 268 |
+
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
|
| 269 |
+
|
| 270 |
+
return sampled_negative_indices
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 274 |
+
class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
|
| 275 |
+
def __init__(self, config, layer_id=0):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 278 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 279 |
+
|
| 280 |
+
self.conv = nn.Conv1d(
|
| 281 |
+
self.in_conv_dim,
|
| 282 |
+
self.out_conv_dim,
|
| 283 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 284 |
+
stride=config.conv_stride[layer_id],
|
| 285 |
+
bias=config.conv_bias,
|
| 286 |
+
)
|
| 287 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 288 |
+
|
| 289 |
+
def forward(self, hidden_states):
|
| 290 |
+
hidden_states = self.conv(hidden_states)
|
| 291 |
+
hidden_states = self.activation(hidden_states)
|
| 292 |
+
return hidden_states
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 296 |
+
class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
|
| 297 |
+
def __init__(self, config, layer_id=0):
|
| 298 |
+
super().__init__()
|
| 299 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 300 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 301 |
+
|
| 302 |
+
self.conv = nn.Conv1d(
|
| 303 |
+
self.in_conv_dim,
|
| 304 |
+
self.out_conv_dim,
|
| 305 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 306 |
+
stride=config.conv_stride[layer_id],
|
| 307 |
+
bias=config.conv_bias,
|
| 308 |
+
)
|
| 309 |
+
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
| 310 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 311 |
+
|
| 312 |
+
def forward(self, hidden_states):
|
| 313 |
+
hidden_states = self.conv(hidden_states)
|
| 314 |
+
|
| 315 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 316 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 317 |
+
hidden_states = hidden_states.transpose(-2, -1)
|
| 318 |
+
|
| 319 |
+
hidden_states = self.activation(hidden_states)
|
| 320 |
+
return hidden_states
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 324 |
+
class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
|
| 325 |
+
def __init__(self, config, layer_id=0):
|
| 326 |
+
super().__init__()
|
| 327 |
+
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 328 |
+
self.out_conv_dim = config.conv_dim[layer_id]
|
| 329 |
+
|
| 330 |
+
self.conv = nn.Conv1d(
|
| 331 |
+
self.in_conv_dim,
|
| 332 |
+
self.out_conv_dim,
|
| 333 |
+
kernel_size=config.conv_kernel[layer_id],
|
| 334 |
+
stride=config.conv_stride[layer_id],
|
| 335 |
+
bias=config.conv_bias,
|
| 336 |
+
)
|
| 337 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 338 |
+
|
| 339 |
+
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
| 340 |
+
|
| 341 |
+
def forward(self, hidden_states):
|
| 342 |
+
hidden_states = self.conv(hidden_states)
|
| 343 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 344 |
+
hidden_states = self.activation(hidden_states)
|
| 345 |
+
return hidden_states
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
|
| 349 |
+
class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
|
| 350 |
+
def __init__(self, config):
|
| 351 |
+
super().__init__()
|
| 352 |
+
self.conv = nn.Conv1d(
|
| 353 |
+
config.hidden_size,
|
| 354 |
+
config.hidden_size,
|
| 355 |
+
kernel_size=config.num_conv_pos_embeddings,
|
| 356 |
+
padding=config.num_conv_pos_embeddings // 2,
|
| 357 |
+
groups=config.num_conv_pos_embedding_groups,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
if is_deepspeed_zero3_enabled():
|
| 361 |
+
import deepspeed
|
| 362 |
+
|
| 363 |
+
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
| 364 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
| 365 |
+
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
|
| 366 |
+
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
|
| 367 |
+
else:
|
| 368 |
+
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
| 369 |
+
|
| 370 |
+
self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
|
| 371 |
+
self.activation = ACT2FN[config.feat_extract_activation]
|
| 372 |
+
|
| 373 |
+
def forward(self, hidden_states):
|
| 374 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 375 |
+
|
| 376 |
+
hidden_states = self.conv(hidden_states)
|
| 377 |
+
hidden_states = self.padding(hidden_states)
|
| 378 |
+
hidden_states = self.activation(hidden_states)
|
| 379 |
+
|
| 380 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 381 |
+
return hidden_states
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
|
| 385 |
+
"""Rotary positional embedding
|
| 386 |
+
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
|
| 387 |
+
"""
|
| 388 |
+
|
| 389 |
+
def __init__(self, config):
|
| 390 |
+
super().__init__()
|
| 391 |
+
dim = config.hidden_size // config.num_attention_heads
|
| 392 |
+
base = config.rotary_embedding_base
|
| 393 |
+
|
| 394 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 395 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 396 |
+
self.cached_sequence_length = None
|
| 397 |
+
self.cached_rotary_positional_embedding = None
|
| 398 |
+
|
| 399 |
+
def forward(self, hidden_states):
|
| 400 |
+
sequence_length = hidden_states.shape[1]
|
| 401 |
+
|
| 402 |
+
if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
|
| 403 |
+
return self.cached_rotary_positional_embedding
|
| 404 |
+
|
| 405 |
+
self.cached_sequence_length = sequence_length
|
| 406 |
+
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
|
| 407 |
+
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
|
| 408 |
+
embeddings = torch.cat((freqs, freqs), dim=-1)
|
| 409 |
+
|
| 410 |
+
cos_embeddings = embeddings.cos()[:, None, None, :]
|
| 411 |
+
sin_embeddings = embeddings.sin()[:, None, None, :]
|
| 412 |
+
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
|
| 413 |
+
return self.cached_rotary_positional_embedding
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
|
| 417 |
+
"""Relative positional encoding module."""
|
| 418 |
+
|
| 419 |
+
def __init__(self, config):
|
| 420 |
+
super().__init__()
|
| 421 |
+
self.max_len = config.max_source_positions
|
| 422 |
+
self.d_model = config.hidden_size
|
| 423 |
+
self.pe = None
|
| 424 |
+
self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
|
| 425 |
+
|
| 426 |
+
def extend_pe(self, x):
|
| 427 |
+
# Reset the positional encodings
|
| 428 |
+
if self.pe is not None:
|
| 429 |
+
# self.pe contains both positive and negative parts
|
| 430 |
+
# the length of self.pe is 2 * input_len - 1
|
| 431 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
| 432 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 433 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 434 |
+
return
|
| 435 |
+
# Suppose `i` is the position of query vector and `j` is the
|
| 436 |
+
# position of key vector. We use positive relative positions when keys
|
| 437 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
| 438 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
| 439 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
| 440 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 441 |
+
div_term = torch.exp(
|
| 442 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
|
| 443 |
+
)
|
| 444 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
| 445 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
| 446 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
| 447 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
| 448 |
+
|
| 449 |
+
# Reverse the order of positive indices and concat both positive and
|
| 450 |
+
# negative indices. This is used to support the shifting trick
|
| 451 |
+
# as in https://arxiv.org/abs/1901.02860
|
| 452 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
| 453 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
| 454 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
| 455 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 456 |
+
|
| 457 |
+
def forward(self, hidden_states: torch.Tensor):
|
| 458 |
+
self.extend_pe(hidden_states)
|
| 459 |
+
start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
|
| 460 |
+
end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
|
| 461 |
+
relative_position_embeddings = self.pe[:, start_idx:end_idx]
|
| 462 |
+
|
| 463 |
+
return relative_position_embeddings
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 467 |
+
class Wav2Vec2ConformerSamePadLayer(nn.Module):
|
| 468 |
+
def __init__(self, num_conv_pos_embeddings):
|
| 469 |
+
super().__init__()
|
| 470 |
+
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
|
| 471 |
+
|
| 472 |
+
def forward(self, hidden_states):
|
| 473 |
+
if self.num_pad_remove > 0:
|
| 474 |
+
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
|
| 475 |
+
return hidden_states
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
|
| 479 |
+
class Wav2Vec2ConformerFeatureEncoder(nn.Module):
|
| 480 |
+
"""Construct the features from raw audio waveform"""
|
| 481 |
+
|
| 482 |
+
def __init__(self, config):
|
| 483 |
+
super().__init__()
|
| 484 |
+
|
| 485 |
+
if config.feat_extract_norm == "group":
|
| 486 |
+
conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
|
| 487 |
+
Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
|
| 488 |
+
for i in range(config.num_feat_extract_layers - 1)
|
| 489 |
+
]
|
| 490 |
+
elif config.feat_extract_norm == "layer":
|
| 491 |
+
conv_layers = [
|
| 492 |
+
Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
|
| 493 |
+
]
|
| 494 |
+
else:
|
| 495 |
+
raise ValueError(
|
| 496 |
+
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
| 497 |
+
)
|
| 498 |
+
self.conv_layers = nn.ModuleList(conv_layers)
|
| 499 |
+
self.gradient_checkpointing = False
|
| 500 |
+
self._requires_grad = True
|
| 501 |
+
|
| 502 |
+
def _freeze_parameters(self):
|
| 503 |
+
for param in self.parameters():
|
| 504 |
+
param.requires_grad = False
|
| 505 |
+
self._requires_grad = False
|
| 506 |
+
|
| 507 |
+
def forward(self, input_values):
|
| 508 |
+
hidden_states = input_values[:, None]
|
| 509 |
+
|
| 510 |
+
# make sure hidden_states require grad for gradient_checkpointing
|
| 511 |
+
if self._requires_grad and self.training:
|
| 512 |
+
hidden_states.requires_grad = True
|
| 513 |
+
|
| 514 |
+
for conv_layer in self.conv_layers:
|
| 515 |
+
if self._requires_grad and self.gradient_checkpointing and self.training:
|
| 516 |
+
|
| 517 |
+
def create_custom_forward(module):
|
| 518 |
+
def custom_forward(*inputs):
|
| 519 |
+
return module(*inputs)
|
| 520 |
+
|
| 521 |
+
return custom_forward
|
| 522 |
+
|
| 523 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 524 |
+
create_custom_forward(conv_layer),
|
| 525 |
+
hidden_states,
|
| 526 |
+
)
|
| 527 |
+
else:
|
| 528 |
+
hidden_states = conv_layer(hidden_states)
|
| 529 |
+
|
| 530 |
+
return hidden_states
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
|
| 534 |
+
class Wav2Vec2ConformerFeatureProjection(nn.Module):
|
| 535 |
+
def __init__(self, config):
|
| 536 |
+
super().__init__()
|
| 537 |
+
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
| 538 |
+
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
| 539 |
+
self.dropout = nn.Dropout(config.feat_proj_dropout)
|
| 540 |
+
|
| 541 |
+
def forward(self, hidden_states):
|
| 542 |
+
# non-projected hidden states are needed for quantization
|
| 543 |
+
norm_hidden_states = self.layer_norm(hidden_states)
|
| 544 |
+
hidden_states = self.projection(norm_hidden_states)
|
| 545 |
+
hidden_states = self.dropout(hidden_states)
|
| 546 |
+
return hidden_states, norm_hidden_states
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
|
| 550 |
+
class Wav2Vec2ConformerFeedForward(nn.Module):
|
| 551 |
+
def __init__(self, config):
|
| 552 |
+
super().__init__()
|
| 553 |
+
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
|
| 554 |
+
|
| 555 |
+
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 556 |
+
if isinstance(config.hidden_act, str):
|
| 557 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 558 |
+
else:
|
| 559 |
+
self.intermediate_act_fn = config.hidden_act
|
| 560 |
+
|
| 561 |
+
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 562 |
+
self.output_dropout = nn.Dropout(config.hidden_dropout)
|
| 563 |
+
|
| 564 |
+
def forward(self, hidden_states):
|
| 565 |
+
hidden_states = self.intermediate_dense(hidden_states)
|
| 566 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 567 |
+
hidden_states = self.intermediate_dropout(hidden_states)
|
| 568 |
+
|
| 569 |
+
hidden_states = self.output_dense(hidden_states)
|
| 570 |
+
hidden_states = self.output_dropout(hidden_states)
|
| 571 |
+
return hidden_states
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
class Wav2Vec2ConformerConvolutionModule(nn.Module):
|
| 575 |
+
"""Convolution block used in the conformer block"""
|
| 576 |
+
|
| 577 |
+
def __init__(self, config):
|
| 578 |
+
super().__init__()
|
| 579 |
+
if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
|
| 580 |
+
raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
|
| 581 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
| 582 |
+
self.pointwise_conv1 = torch.nn.Conv1d(
|
| 583 |
+
config.hidden_size,
|
| 584 |
+
2 * config.hidden_size,
|
| 585 |
+
kernel_size=1,
|
| 586 |
+
stride=1,
|
| 587 |
+
padding=0,
|
| 588 |
+
bias=False,
|
| 589 |
+
)
|
| 590 |
+
self.glu = torch.nn.GLU(dim=1)
|
| 591 |
+
self.depthwise_conv = torch.nn.Conv1d(
|
| 592 |
+
config.hidden_size,
|
| 593 |
+
config.hidden_size,
|
| 594 |
+
config.conv_depthwise_kernel_size,
|
| 595 |
+
stride=1,
|
| 596 |
+
padding=(config.conv_depthwise_kernel_size - 1) // 2,
|
| 597 |
+
groups=config.hidden_size,
|
| 598 |
+
bias=False,
|
| 599 |
+
)
|
| 600 |
+
self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
|
| 601 |
+
self.activation = ACT2FN[config.hidden_act]
|
| 602 |
+
self.pointwise_conv2 = torch.nn.Conv1d(
|
| 603 |
+
config.hidden_size,
|
| 604 |
+
config.hidden_size,
|
| 605 |
+
kernel_size=1,
|
| 606 |
+
stride=1,
|
| 607 |
+
padding=0,
|
| 608 |
+
bias=False,
|
| 609 |
+
)
|
| 610 |
+
self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
|
| 611 |
+
|
| 612 |
+
def forward(self, hidden_states):
|
| 613 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 614 |
+
# exchange the temporal dimension and the feature dimension
|
| 615 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 616 |
+
|
| 617 |
+
# GLU mechanism
|
| 618 |
+
# => (batch, 2*channel, dim)
|
| 619 |
+
hidden_states = self.pointwise_conv1(hidden_states)
|
| 620 |
+
# => (batch, channel, dim)
|
| 621 |
+
hidden_states = self.glu(hidden_states)
|
| 622 |
+
|
| 623 |
+
# 1D Depthwise Conv
|
| 624 |
+
hidden_states = self.depthwise_conv(hidden_states)
|
| 625 |
+
hidden_states = self.batch_norm(hidden_states)
|
| 626 |
+
hidden_states = self.activation(hidden_states)
|
| 627 |
+
|
| 628 |
+
hidden_states = self.pointwise_conv2(hidden_states)
|
| 629 |
+
hidden_states = self.dropout(hidden_states)
|
| 630 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 631 |
+
return hidden_states
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
class Wav2Vec2ConformerSelfAttention(nn.Module):
|
| 635 |
+
"""Construct an Wav2Vec2ConformerSelfAttention object.
|
| 636 |
+
Can be enhanced with rotary or relative position embeddings.
|
| 637 |
+
"""
|
| 638 |
+
|
| 639 |
+
def __init__(self, config):
|
| 640 |
+
super().__init__()
|
| 641 |
+
|
| 642 |
+
self.head_size = config.hidden_size // config.num_attention_heads
|
| 643 |
+
self.num_heads = config.num_attention_heads
|
| 644 |
+
self.position_embeddings_type = config.position_embeddings_type
|
| 645 |
+
|
| 646 |
+
self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
|
| 647 |
+
self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
|
| 648 |
+
self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
|
| 649 |
+
self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
|
| 650 |
+
|
| 651 |
+
self.dropout = nn.Dropout(p=config.attention_dropout)
|
| 652 |
+
self.dropout_p = config.attention_dropout
|
| 653 |
+
|
| 654 |
+
self.is_causal = config.is_causal
|
| 655 |
+
|
| 656 |
+
if self.position_embeddings_type == "relative":
|
| 657 |
+
# linear transformation for positional encoding
|
| 658 |
+
self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
| 659 |
+
# these two learnable bias are used in matrix c and matrix d
|
| 660 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 661 |
+
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
|
| 662 |
+
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
|
| 663 |
+
|
| 664 |
+
def forward(
|
| 665 |
+
self,
|
| 666 |
+
hidden_states: torch.Tensor,
|
| 667 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 668 |
+
relative_position_embeddings: Optional[torch.Tensor] = None,
|
| 669 |
+
output_attentions: bool = False,
|
| 670 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 671 |
+
# self-attention mechanism
|
| 672 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 673 |
+
|
| 674 |
+
# make sure query/key states can be != value states
|
| 675 |
+
query_key_states = hidden_states
|
| 676 |
+
value_states = hidden_states
|
| 677 |
+
|
| 678 |
+
if self.position_embeddings_type == "rotary":
|
| 679 |
+
if relative_position_embeddings is None:
|
| 680 |
+
raise ValueError(
|
| 681 |
+
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
|
| 682 |
+
)
|
| 683 |
+
query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
|
| 684 |
+
|
| 685 |
+
# project query_key_states and value_states
|
| 686 |
+
query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
|
| 687 |
+
key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
|
| 688 |
+
value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
|
| 689 |
+
|
| 690 |
+
# => (batch, head, time1, d_k)
|
| 691 |
+
query = query.transpose(1, 2)
|
| 692 |
+
key = key.transpose(1, 2)
|
| 693 |
+
value = value.transpose(1, 2)
|
| 694 |
+
|
| 695 |
+
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
|
| 696 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
|
| 697 |
+
probs = None
|
| 698 |
+
|
| 699 |
+
# # apply attention_mask if necessary
|
| 700 |
+
# if attention_mask is not None:
|
| 701 |
+
# scores = scores + attention_mask
|
| 702 |
+
|
| 703 |
+
# # => (batch, head, time1, time2)
|
| 704 |
+
# probs = torch.softmax(scores, dim=-1)
|
| 705 |
+
# probs = self.dropout(probs)
|
| 706 |
+
|
| 707 |
+
# # => (batch, head, time1, d_k)
|
| 708 |
+
# hidden_states = torch.matmul(probs, value)
|
| 709 |
+
|
| 710 |
+
# => (batch, time1, hidden_size)
|
| 711 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
|
| 712 |
+
hidden_states = self.linear_out(hidden_states)
|
| 713 |
+
|
| 714 |
+
return hidden_states, probs
|
| 715 |
+
|
| 716 |
+
def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
|
| 717 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 718 |
+
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
|
| 719 |
+
|
| 720 |
+
cos = relative_position_embeddings[0, :sequence_length, ...]
|
| 721 |
+
sin = relative_position_embeddings[1, :sequence_length, ...]
|
| 722 |
+
|
| 723 |
+
# rotate hidden_states with rotary embeddings
|
| 724 |
+
hidden_states = hidden_states.transpose(0, 1)
|
| 725 |
+
rotated_states_begin = hidden_states[..., : self.head_size // 2]
|
| 726 |
+
rotated_states_end = hidden_states[..., self.head_size // 2 :]
|
| 727 |
+
rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
|
| 728 |
+
hidden_states = (hidden_states * cos) + (rotated_states * sin)
|
| 729 |
+
hidden_states = hidden_states.transpose(0, 1)
|
| 730 |
+
|
| 731 |
+
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
|
| 732 |
+
|
| 733 |
+
return hidden_states
|
| 734 |
+
|
| 735 |
+
def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
|
| 736 |
+
# 1. project positional embeddings
|
| 737 |
+
# => (batch, head, 2*time1-1, d_k)
|
| 738 |
+
proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
|
| 739 |
+
proj_relative_position_embeddings = proj_relative_position_embeddings.view(
|
| 740 |
+
relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
|
| 741 |
+
)
|
| 742 |
+
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
|
| 743 |
+
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
|
| 744 |
+
|
| 745 |
+
# 2. Add bias to query
|
| 746 |
+
# => (batch, head, time1, d_k)
|
| 747 |
+
query = query.transpose(1, 2)
|
| 748 |
+
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
|
| 749 |
+
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
|
| 750 |
+
|
| 751 |
+
# 3. attention score: first compute matrix a and matrix c
|
| 752 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 753 |
+
# => (batch, head, time1, time2)
|
| 754 |
+
scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
|
| 755 |
+
|
| 756 |
+
# 4. then compute matrix b and matrix d
|
| 757 |
+
# => (batch, head, time1, 2*time1-1)
|
| 758 |
+
scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
|
| 759 |
+
|
| 760 |
+
# 5. shift matrix b and matrix d
|
| 761 |
+
zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
|
| 762 |
+
scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
|
| 763 |
+
scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
|
| 764 |
+
scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
|
| 765 |
+
scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
|
| 766 |
+
scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
|
| 767 |
+
|
| 768 |
+
# 6. sum matrices
|
| 769 |
+
# => (batch, head, time1, time2)
|
| 770 |
+
scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
|
| 771 |
+
|
| 772 |
+
return scores
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
class Wav2Vec2ConformerEncoderLayer(nn.Module):
|
| 776 |
+
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
|
| 777 |
+
|
| 778 |
+
def __init__(self, config):
|
| 779 |
+
super().__init__()
|
| 780 |
+
embed_dim = config.hidden_size
|
| 781 |
+
dropout = config.attention_dropout
|
| 782 |
+
|
| 783 |
+
# Feed-forward 1
|
| 784 |
+
self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
|
| 785 |
+
self.ffn1 = Wav2Vec2ConformerFeedForward(config)
|
| 786 |
+
|
| 787 |
+
# Self-Attention
|
| 788 |
+
self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
|
| 789 |
+
self.self_attn_dropout = torch.nn.Dropout(dropout)
|
| 790 |
+
self.self_attn = Wav2Vec2ConformerSelfAttention(config)
|
| 791 |
+
|
| 792 |
+
# Conformer Convolution
|
| 793 |
+
self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
|
| 794 |
+
|
| 795 |
+
# Feed-forward 2
|
| 796 |
+
self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
|
| 797 |
+
self.ffn2 = Wav2Vec2ConformerFeedForward(config)
|
| 798 |
+
self.final_layer_norm = nn.LayerNorm(embed_dim)
|
| 799 |
+
|
| 800 |
+
def forward(
|
| 801 |
+
self,
|
| 802 |
+
hidden_states,
|
| 803 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 804 |
+
relative_position_embeddings: Optional[torch.Tensor] = None,
|
| 805 |
+
output_attentions: bool = False,
|
| 806 |
+
):
|
| 807 |
+
hidden_states = hidden_states
|
| 808 |
+
|
| 809 |
+
# 1. Feed-Forward 1 layer
|
| 810 |
+
residual = hidden_states
|
| 811 |
+
hidden_states = self.ffn1_layer_norm(hidden_states)
|
| 812 |
+
hidden_states = self.ffn1(hidden_states)
|
| 813 |
+
hidden_states = hidden_states * 0.5 + residual
|
| 814 |
+
residual = hidden_states
|
| 815 |
+
|
| 816 |
+
# 2. Self-Attention layer
|
| 817 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 818 |
+
hidden_states, attn_weigts = self.self_attn(
|
| 819 |
+
hidden_states=hidden_states,
|
| 820 |
+
attention_mask=attention_mask,
|
| 821 |
+
relative_position_embeddings=relative_position_embeddings,
|
| 822 |
+
output_attentions=output_attentions,
|
| 823 |
+
)
|
| 824 |
+
hidden_states = self.self_attn_dropout(hidden_states)
|
| 825 |
+
hidden_states = hidden_states + residual
|
| 826 |
+
|
| 827 |
+
# 3. Convolutional Layer
|
| 828 |
+
residual = hidden_states
|
| 829 |
+
hidden_states = self.conv_module(hidden_states)
|
| 830 |
+
hidden_states = residual + hidden_states
|
| 831 |
+
|
| 832 |
+
# 4. Feed-Forward 2 Layer
|
| 833 |
+
residual = hidden_states
|
| 834 |
+
hidden_states = self.ffn2_layer_norm(hidden_states)
|
| 835 |
+
hidden_states = self.ffn2(hidden_states)
|
| 836 |
+
hidden_states = hidden_states * 0.5 + residual
|
| 837 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 838 |
+
|
| 839 |
+
return hidden_states, attn_weigts
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
class Wav2Vec2ConformerEncoder(nn.Module):
|
| 843 |
+
def __init__(self, config, is_causal=False):
|
| 844 |
+
super().__init__()
|
| 845 |
+
config.is_causal = is_causal
|
| 846 |
+
self.config = config
|
| 847 |
+
|
| 848 |
+
if config.position_embeddings_type == "relative":
|
| 849 |
+
self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
|
| 850 |
+
elif config.position_embeddings_type == "rotary":
|
| 851 |
+
self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
|
| 852 |
+
else:
|
| 853 |
+
self.embed_positions = None
|
| 854 |
+
|
| 855 |
+
self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
|
| 856 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 857 |
+
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 858 |
+
self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 859 |
+
self.gradient_checkpointing = False
|
| 860 |
+
|
| 861 |
+
def forward(
|
| 862 |
+
self,
|
| 863 |
+
hidden_states,
|
| 864 |
+
attention_mask=None,
|
| 865 |
+
output_attentions=False,
|
| 866 |
+
output_hidden_states=False,
|
| 867 |
+
return_dict=True,
|
| 868 |
+
):
|
| 869 |
+
all_hidden_states = () if output_hidden_states else None
|
| 870 |
+
all_self_attentions = () if output_attentions else None
|
| 871 |
+
|
| 872 |
+
if attention_mask is not None:
|
| 873 |
+
# make sure padded tokens output 0
|
| 874 |
+
hidden_states[~attention_mask] = 0.0
|
| 875 |
+
|
| 876 |
+
# extend attention_mask
|
| 877 |
+
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
|
| 878 |
+
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
|
| 879 |
+
attention_mask = attention_mask.expand(
|
| 880 |
+
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
hidden_states = self.dropout(hidden_states)
|
| 884 |
+
|
| 885 |
+
if self.embed_positions is not None:
|
| 886 |
+
relative_position_embeddings = self.embed_positions(hidden_states)
|
| 887 |
+
else:
|
| 888 |
+
relative_position_embeddings = None
|
| 889 |
+
|
| 890 |
+
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
| 891 |
+
|
| 892 |
+
for i, layer in enumerate(self.layers):
|
| 893 |
+
if output_hidden_states:
|
| 894 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 895 |
+
|
| 896 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
| 897 |
+
dropout_probability = np.random.uniform(0, 1)
|
| 898 |
+
|
| 899 |
+
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
| 900 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
| 901 |
+
# under deepspeed zero3 all gpus must run in sync
|
| 902 |
+
if self.gradient_checkpointing and self.training:
|
| 903 |
+
# create gradient checkpointing function
|
| 904 |
+
def create_custom_forward(module):
|
| 905 |
+
def custom_forward(*inputs):
|
| 906 |
+
return module(*inputs, output_attentions)
|
| 907 |
+
|
| 908 |
+
return custom_forward
|
| 909 |
+
|
| 910 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 911 |
+
create_custom_forward(layer),
|
| 912 |
+
hidden_states,
|
| 913 |
+
attention_mask,
|
| 914 |
+
relative_position_embeddings,
|
| 915 |
+
)
|
| 916 |
+
else:
|
| 917 |
+
layer_outputs = layer(
|
| 918 |
+
hidden_states,
|
| 919 |
+
attention_mask=attention_mask,
|
| 920 |
+
relative_position_embeddings=relative_position_embeddings,
|
| 921 |
+
output_attentions=output_attentions,
|
| 922 |
+
)
|
| 923 |
+
hidden_states = layer_outputs[0]
|
| 924 |
+
|
| 925 |
+
if skip_the_layer:
|
| 926 |
+
layer_outputs = (None, None)
|
| 927 |
+
|
| 928 |
+
if output_attentions:
|
| 929 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 930 |
+
|
| 931 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 932 |
+
if output_hidden_states:
|
| 933 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 934 |
+
|
| 935 |
+
if not return_dict:
|
| 936 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 937 |
+
return BaseModelOutput(
|
| 938 |
+
last_hidden_state=hidden_states,
|
| 939 |
+
hidden_states=all_hidden_states,
|
| 940 |
+
attentions=all_self_attentions,
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
|
| 945 |
+
class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
|
| 946 |
+
"""
|
| 947 |
+
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
|
| 948 |
+
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
|
| 949 |
+
"""
|
| 950 |
+
|
| 951 |
+
def __init__(self, config):
|
| 952 |
+
super().__init__()
|
| 953 |
+
self.num_groups = config.num_codevector_groups
|
| 954 |
+
self.num_vars = config.num_codevectors_per_group
|
| 955 |
+
|
| 956 |
+
if config.codevector_dim % self.num_groups != 0:
|
| 957 |
+
raise ValueError(
|
| 958 |
+
f"`config.codevector_dim {config.codevector_dim} must be divisible "
|
| 959 |
+
f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
|
| 960 |
+
)
|
| 961 |
+
|
| 962 |
+
# storage for codebook variables (codewords)
|
| 963 |
+
self.codevectors = nn.Parameter(
|
| 964 |
+
torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
|
| 965 |
+
)
|
| 966 |
+
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
|
| 967 |
+
|
| 968 |
+
# can be decayed for training
|
| 969 |
+
self.temperature = 2
|
| 970 |
+
|
| 971 |
+
@staticmethod
|
| 972 |
+
def _compute_perplexity(probs, mask=None):
|
| 973 |
+
if mask is not None:
|
| 974 |
+
mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
|
| 975 |
+
probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
|
| 976 |
+
marginal_probs = probs.sum(dim=0) / mask.sum()
|
| 977 |
+
else:
|
| 978 |
+
marginal_probs = probs.mean(dim=0)
|
| 979 |
+
|
| 980 |
+
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
|
| 981 |
+
return perplexity
|
| 982 |
+
|
| 983 |
+
def forward(self, hidden_states, mask_time_indices=None):
|
| 984 |
+
batch_size, sequence_length, hidden_size = hidden_states.shape
|
| 985 |
+
|
| 986 |
+
# project to codevector dim
|
| 987 |
+
hidden_states = self.weight_proj(hidden_states)
|
| 988 |
+
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
| 989 |
+
|
| 990 |
+
if self.training:
|
| 991 |
+
# sample code vector probs via gumbel in differentiateable way
|
| 992 |
+
codevector_probs = nn.functional.gumbel_softmax(
|
| 993 |
+
hidden_states.float(), tau=self.temperature, hard=True
|
| 994 |
+
).type_as(hidden_states)
|
| 995 |
+
|
| 996 |
+
# compute perplexity
|
| 997 |
+
codevector_soft_dist = torch.softmax(
|
| 998 |
+
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
|
| 999 |
+
)
|
| 1000 |
+
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
| 1001 |
+
else:
|
| 1002 |
+
# take argmax in non-differentiable way
|
| 1003 |
+
# comptute hard codevector distribution (one hot)
|
| 1004 |
+
codevector_idx = hidden_states.argmax(dim=-1)
|
| 1005 |
+
codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
|
| 1006 |
+
-1, codevector_idx.view(-1, 1), 1.0
|
| 1007 |
+
)
|
| 1008 |
+
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
| 1009 |
+
|
| 1010 |
+
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
|
| 1011 |
+
|
| 1012 |
+
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
|
| 1013 |
+
# use probs to retrieve codevectors
|
| 1014 |
+
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
|
| 1015 |
+
codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
|
| 1016 |
+
codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
|
| 1017 |
+
|
| 1018 |
+
return codevectors, perplexity
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
|
| 1022 |
+
class Wav2Vec2ConformerAdapter(nn.Module):
|
| 1023 |
+
def __init__(self, config):
|
| 1024 |
+
super().__init__()
|
| 1025 |
+
|
| 1026 |
+
# feature dim might need to be down-projected
|
| 1027 |
+
if config.output_hidden_size != config.hidden_size:
|
| 1028 |
+
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
|
| 1029 |
+
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
|
| 1030 |
+
else:
|
| 1031 |
+
self.proj = self.proj_layer_norm = None
|
| 1032 |
+
|
| 1033 |
+
self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
|
| 1034 |
+
self.layerdrop = config.layerdrop
|
| 1035 |
+
|
| 1036 |
+
def forward(self, hidden_states):
|
| 1037 |
+
# down project hidden_states if necessary
|
| 1038 |
+
if self.proj is not None and self.proj_layer_norm is not None:
|
| 1039 |
+
hidden_states = self.proj(hidden_states)
|
| 1040 |
+
hidden_states = self.proj_layer_norm(hidden_states)
|
| 1041 |
+
|
| 1042 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1043 |
+
|
| 1044 |
+
for layer in self.layers:
|
| 1045 |
+
layerdrop_prob = np.random.random()
|
| 1046 |
+
if not self.training or (layerdrop_prob > self.layerdrop):
|
| 1047 |
+
hidden_states = layer(hidden_states)
|
| 1048 |
+
|
| 1049 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1050 |
+
return hidden_states
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 1054 |
+
class Wav2Vec2ConformerAdapterLayer(nn.Module):
|
| 1055 |
+
def __init__(self, config):
|
| 1056 |
+
super().__init__()
|
| 1057 |
+
self.conv = nn.Conv1d(
|
| 1058 |
+
config.output_hidden_size,
|
| 1059 |
+
2 * config.output_hidden_size,
|
| 1060 |
+
config.adapter_kernel_size,
|
| 1061 |
+
stride=config.adapter_stride,
|
| 1062 |
+
padding=1,
|
| 1063 |
+
)
|
| 1064 |
+
|
| 1065 |
+
def forward(self, hidden_states):
|
| 1066 |
+
hidden_states = self.conv(hidden_states)
|
| 1067 |
+
hidden_states = nn.functional.glu(hidden_states, dim=1)
|
| 1068 |
+
|
| 1069 |
+
return hidden_states
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
|
| 1073 |
+
"""
|
| 1074 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 1075 |
+
models.
|
| 1076 |
+
"""
|
| 1077 |
+
|
| 1078 |
+
config_class = Wav2Vec2ConformerConfig
|
| 1079 |
+
base_model_prefix = "wav2vec2_conformer"
|
| 1080 |
+
main_input_name = "input_values"
|
| 1081 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
| 1082 |
+
supports_gradient_checkpointing = True
|
| 1083 |
+
|
| 1084 |
+
def _init_weights(self, module):
|
| 1085 |
+
"""Initialize the weights"""
|
| 1086 |
+
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
|
| 1087 |
+
if isinstance(module, Wav2Vec2ConformerForPreTraining):
|
| 1088 |
+
module.project_hid.reset_parameters()
|
| 1089 |
+
module.project_q.reset_parameters()
|
| 1090 |
+
module.project_hid._is_hf_initialized = True
|
| 1091 |
+
module.project_q._is_hf_initialized = True
|
| 1092 |
+
# gumbel softmax requires special init
|
| 1093 |
+
elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
|
| 1094 |
+
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
|
| 1095 |
+
module.weight_proj.bias.data.zero_()
|
| 1096 |
+
nn.init.uniform_(module.codevectors)
|
| 1097 |
+
elif isinstance(module, Wav2Vec2ConformerSelfAttention):
|
| 1098 |
+
if hasattr(module, "pos_bias_u"):
|
| 1099 |
+
nn.init.xavier_uniform_(module.pos_bias_u)
|
| 1100 |
+
if hasattr(module, "pos_bias_v"):
|
| 1101 |
+
nn.init.xavier_uniform_(module.pos_bias_v)
|
| 1102 |
+
elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
|
| 1103 |
+
nn.init.normal_(
|
| 1104 |
+
module.conv.weight,
|
| 1105 |
+
mean=0,
|
| 1106 |
+
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
|
| 1107 |
+
)
|
| 1108 |
+
nn.init.constant_(module.conv.bias, 0)
|
| 1109 |
+
elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
|
| 1110 |
+
k = math.sqrt(1 / module.projection.in_features)
|
| 1111 |
+
nn.init.uniform_(module.projection.weight, a=-k, b=k)
|
| 1112 |
+
nn.init.uniform_(module.projection.bias, a=-k, b=k)
|
| 1113 |
+
elif isinstance(module, nn.Linear):
|
| 1114 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 1115 |
+
|
| 1116 |
+
if module.bias is not None:
|
| 1117 |
+
module.bias.data.zero_()
|
| 1118 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
| 1119 |
+
module.bias.data.zero_()
|
| 1120 |
+
module.weight.data.fill_(1.0)
|
| 1121 |
+
elif isinstance(module, nn.Conv1d):
|
| 1122 |
+
nn.init.kaiming_normal_(module.weight)
|
| 1123 |
+
|
| 1124 |
+
if module.bias is not None:
|
| 1125 |
+
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
|
| 1126 |
+
nn.init.uniform_(module.bias, a=-k, b=k)
|
| 1127 |
+
|
| 1128 |
+
def _get_feat_extract_output_lengths(
|
| 1129 |
+
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
|
| 1130 |
+
):
|
| 1131 |
+
"""
|
| 1132 |
+
Computes the output length of the convolutional layers
|
| 1133 |
+
"""
|
| 1134 |
+
|
| 1135 |
+
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
|
| 1136 |
+
|
| 1137 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
| 1138 |
+
# 1D convolutional layer output length formula taken
|
| 1139 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 1140 |
+
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
|
| 1141 |
+
|
| 1142 |
+
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
| 1143 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
| 1144 |
+
|
| 1145 |
+
if add_adapter:
|
| 1146 |
+
for _ in range(self.config.num_adapter_layers):
|
| 1147 |
+
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
|
| 1148 |
+
|
| 1149 |
+
return input_lengths
|
| 1150 |
+
|
| 1151 |
+
def _get_feature_vector_attention_mask(
|
| 1152 |
+
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
|
| 1153 |
+
):
|
| 1154 |
+
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
| 1155 |
+
# on inference mode.
|
| 1156 |
+
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
|
| 1157 |
+
|
| 1158 |
+
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
|
| 1159 |
+
output_lengths = output_lengths.to(torch.long)
|
| 1160 |
+
|
| 1161 |
+
batch_size = attention_mask.shape[0]
|
| 1162 |
+
|
| 1163 |
+
attention_mask = torch.zeros(
|
| 1164 |
+
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
| 1165 |
+
)
|
| 1166 |
+
# these two operations makes sure that all values before the output lengths idxs are attended to
|
| 1167 |
+
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
| 1168 |
+
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
| 1169 |
+
return attention_mask
|
| 1170 |
+
|
| 1171 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 1172 |
+
if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
|
| 1173 |
+
module.gradient_checkpointing = value
|
| 1174 |
+
|
| 1175 |
+
|
| 1176 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
|
| 1177 |
+
Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
|
| 1178 |
+
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
|
| 1179 |
+
Auli.
|
| 1180 |
+
|
| 1181 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1182 |
+
library implements for all its model (such as downloading or saving etc.).
|
| 1183 |
+
|
| 1184 |
+
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
|
| 1185 |
+
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
|
| 1186 |
+
|
| 1187 |
+
Parameters:
|
| 1188 |
+
config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
|
| 1189 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 1190 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1191 |
+
"""
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
|
| 1195 |
+
Args:
|
| 1196 |
+
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 1197 |
+
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
|
| 1198 |
+
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
|
| 1199 |
+
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
|
| 1200 |
+
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
|
| 1201 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1202 |
+
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
| 1203 |
+
1]`:
|
| 1204 |
+
|
| 1205 |
+
- 1 for tokens that are **not masked**,
|
| 1206 |
+
- 0 for tokens that are **masked**.
|
| 1207 |
+
|
| 1208 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 1209 |
+
|
| 1210 |
+
<Tip warning={true}>
|
| 1211 |
+
|
| 1212 |
+
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
|
| 1213 |
+
True`. For all models whose processor has `config.return_attention_mask == False`, such as
|
| 1214 |
+
[wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
|
| 1215 |
+
`attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
|
| 1216 |
+
such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
|
| 1217 |
+
that these models also yield slightly different results depending on whether `input_values` is padded or
|
| 1218 |
+
not.
|
| 1219 |
+
|
| 1220 |
+
</Tip>
|
| 1221 |
+
|
| 1222 |
+
output_attentions (`bool`, *optional*):
|
| 1223 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1224 |
+
tensors for more detail.
|
| 1225 |
+
output_hidden_states (`bool`, *optional*):
|
| 1226 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1227 |
+
more detail.
|
| 1228 |
+
return_dict (`bool`, *optional*):
|
| 1229 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1230 |
+
"""
|
| 1231 |
+
|
| 1232 |
+
|
| 1233 |
+
@add_start_docstrings(
|
| 1234 |
+
"The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1235 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1236 |
+
)
|
| 1237 |
+
class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
|
| 1238 |
+
def __init__(self, config: Wav2Vec2ConformerConfig):
|
| 1239 |
+
super().__init__(config)
|
| 1240 |
+
self.config = config
|
| 1241 |
+
self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
|
| 1242 |
+
self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
|
| 1243 |
+
|
| 1244 |
+
# model only needs masking vector if mask prob is > 0.0
|
| 1245 |
+
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
| 1246 |
+
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
| 1247 |
+
|
| 1248 |
+
self.encoder = Wav2Vec2ConformerEncoder(config)
|
| 1249 |
+
|
| 1250 |
+
self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
|
| 1251 |
+
|
| 1252 |
+
# Initialize weights and apply final processing
|
| 1253 |
+
self.post_init()
|
| 1254 |
+
|
| 1255 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
|
| 1256 |
+
def freeze_feature_encoder(self):
|
| 1257 |
+
"""
|
| 1258 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1259 |
+
not be updated during training.
|
| 1260 |
+
"""
|
| 1261 |
+
self.feature_extractor._freeze_parameters()
|
| 1262 |
+
|
| 1263 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
| 1264 |
+
def _mask_hidden_states(
|
| 1265 |
+
self,
|
| 1266 |
+
hidden_states: torch.FloatTensor,
|
| 1267 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1268 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1269 |
+
):
|
| 1270 |
+
"""
|
| 1271 |
+
Masks extracted features along time axis and/or along feature axis according to
|
| 1272 |
+
[SpecAugment](https://arxiv.org/abs/1904.08779).
|
| 1273 |
+
"""
|
| 1274 |
+
|
| 1275 |
+
# `config.apply_spec_augment` can set masking to False
|
| 1276 |
+
if not getattr(self.config, "apply_spec_augment", True):
|
| 1277 |
+
return hidden_states
|
| 1278 |
+
|
| 1279 |
+
# generate indices & apply SpecAugment along time axis
|
| 1280 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 1281 |
+
|
| 1282 |
+
if mask_time_indices is not None:
|
| 1283 |
+
# apply SpecAugment along time axis with given mask_time_indices
|
| 1284 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1285 |
+
elif self.config.mask_time_prob > 0 and self.training:
|
| 1286 |
+
mask_time_indices = _compute_mask_indices(
|
| 1287 |
+
(batch_size, sequence_length),
|
| 1288 |
+
mask_prob=self.config.mask_time_prob,
|
| 1289 |
+
mask_length=self.config.mask_time_length,
|
| 1290 |
+
attention_mask=attention_mask,
|
| 1291 |
+
min_masks=self.config.mask_time_min_masks,
|
| 1292 |
+
)
|
| 1293 |
+
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1294 |
+
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1295 |
+
|
| 1296 |
+
if self.config.mask_feature_prob > 0 and self.training:
|
| 1297 |
+
# generate indices & apply SpecAugment along feature axis
|
| 1298 |
+
mask_feature_indices = _compute_mask_indices(
|
| 1299 |
+
(batch_size, hidden_size),
|
| 1300 |
+
mask_prob=self.config.mask_feature_prob,
|
| 1301 |
+
mask_length=self.config.mask_feature_length,
|
| 1302 |
+
min_masks=self.config.mask_feature_min_masks,
|
| 1303 |
+
)
|
| 1304 |
+
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1305 |
+
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
|
| 1306 |
+
hidden_states[mask_feature_indices] = 0
|
| 1307 |
+
|
| 1308 |
+
return hidden_states
|
| 1309 |
+
|
| 1310 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1311 |
+
@add_code_sample_docstrings(
|
| 1312 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1313 |
+
output_type=Wav2Vec2BaseModelOutput,
|
| 1314 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1315 |
+
modality="audio",
|
| 1316 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 1317 |
+
)
|
| 1318 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
|
| 1319 |
+
def forward(
|
| 1320 |
+
self,
|
| 1321 |
+
input_values: Optional[torch.Tensor],
|
| 1322 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1323 |
+
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1324 |
+
output_attentions: Optional[bool] = None,
|
| 1325 |
+
output_hidden_states: Optional[bool] = None,
|
| 1326 |
+
return_dict: Optional[bool] = None,
|
| 1327 |
+
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
|
| 1328 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1329 |
+
output_hidden_states = (
|
| 1330 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1331 |
+
)
|
| 1332 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1333 |
+
|
| 1334 |
+
extract_features = self.feature_extractor(input_values)
|
| 1335 |
+
extract_features = extract_features.transpose(1, 2)
|
| 1336 |
+
|
| 1337 |
+
if attention_mask is not None:
|
| 1338 |
+
# compute reduced attention_mask corresponding to feature vectors
|
| 1339 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
| 1340 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
| 1341 |
+
)
|
| 1342 |
+
|
| 1343 |
+
hidden_states, extract_features = self.feature_projection(extract_features)
|
| 1344 |
+
hidden_states = self._mask_hidden_states(
|
| 1345 |
+
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
| 1346 |
+
)
|
| 1347 |
+
|
| 1348 |
+
encoder_outputs = self.encoder(
|
| 1349 |
+
hidden_states,
|
| 1350 |
+
attention_mask=attention_mask,
|
| 1351 |
+
output_attentions=output_attentions,
|
| 1352 |
+
output_hidden_states=output_hidden_states,
|
| 1353 |
+
return_dict=return_dict,
|
| 1354 |
+
)
|
| 1355 |
+
|
| 1356 |
+
hidden_states = encoder_outputs[0]
|
| 1357 |
+
|
| 1358 |
+
if self.adapter is not None:
|
| 1359 |
+
hidden_states = self.adapter(hidden_states)
|
| 1360 |
+
|
| 1361 |
+
if not return_dict:
|
| 1362 |
+
return (hidden_states, extract_features) + encoder_outputs[1:]
|
| 1363 |
+
|
| 1364 |
+
return Wav2Vec2BaseModelOutput(
|
| 1365 |
+
last_hidden_state=hidden_states,
|
| 1366 |
+
extract_features=extract_features,
|
| 1367 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1368 |
+
attentions=encoder_outputs.attentions,
|
| 1369 |
+
)
|
| 1370 |
+
|
| 1371 |
+
|
| 1372 |
+
@add_start_docstrings(
|
| 1373 |
+
"""Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
|
| 1374 |
+
)
|
| 1375 |
+
class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
| 1376 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1377 |
+
def __init__(self, config: Wav2Vec2ConformerConfig):
|
| 1378 |
+
super().__init__(config)
|
| 1379 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1380 |
+
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
|
| 1381 |
+
|
| 1382 |
+
self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
|
| 1383 |
+
|
| 1384 |
+
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
|
| 1385 |
+
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
|
| 1386 |
+
|
| 1387 |
+
# Initialize weights and apply final processing
|
| 1388 |
+
self.post_init()
|
| 1389 |
+
|
| 1390 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
|
| 1391 |
+
def set_gumbel_temperature(self, temperature: int):
|
| 1392 |
+
"""
|
| 1393 |
+
Set the Gumbel softmax temperature to a given value. Only necessary for training
|
| 1394 |
+
"""
|
| 1395 |
+
self.quantizer.temperature = temperature
|
| 1396 |
+
|
| 1397 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1398 |
+
def freeze_feature_encoder(self):
|
| 1399 |
+
"""
|
| 1400 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1401 |
+
not be updated during training.
|
| 1402 |
+
"""
|
| 1403 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1404 |
+
|
| 1405 |
+
@staticmethod
|
| 1406 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
|
| 1407 |
+
def compute_contrastive_logits(
|
| 1408 |
+
target_features: torch.FloatTensor,
|
| 1409 |
+
negative_features: torch.FloatTensor,
|
| 1410 |
+
predicted_features: torch.FloatTensor,
|
| 1411 |
+
temperature: int = 0.1,
|
| 1412 |
+
):
|
| 1413 |
+
"""
|
| 1414 |
+
Compute logits for contrastive loss based using cosine similarity as the distance measure between
|
| 1415 |
+
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
|
| 1416 |
+
"""
|
| 1417 |
+
target_features = torch.cat([target_features, negative_features], dim=0)
|
| 1418 |
+
|
| 1419 |
+
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
|
| 1420 |
+
target_features
|
| 1421 |
+
)
|
| 1422 |
+
|
| 1423 |
+
# apply temperature
|
| 1424 |
+
logits = logits / temperature
|
| 1425 |
+
return logits
|
| 1426 |
+
|
| 1427 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1428 |
+
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
| 1429 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
|
| 1430 |
+
def forward(
|
| 1431 |
+
self,
|
| 1432 |
+
input_values: Optional[torch.Tensor],
|
| 1433 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1434 |
+
mask_time_indices: Optional[torch.BoolTensor] = None,
|
| 1435 |
+
sampled_negative_indices: Optional[torch.BoolTensor] = None,
|
| 1436 |
+
output_attentions: Optional[bool] = None,
|
| 1437 |
+
output_hidden_states: Optional[bool] = None,
|
| 1438 |
+
return_dict: Optional[bool] = None,
|
| 1439 |
+
) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
|
| 1440 |
+
r"""
|
| 1441 |
+
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1442 |
+
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
|
| 1443 |
+
masked extracted features in *config.proj_codevector_dim* space.
|
| 1444 |
+
sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
|
| 1445 |
+
Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
|
| 1446 |
+
Required input for pre-training.
|
| 1447 |
+
|
| 1448 |
+
Returns:
|
| 1449 |
+
|
| 1450 |
+
Example:
|
| 1451 |
+
|
| 1452 |
+
```python
|
| 1453 |
+
>>> import torch
|
| 1454 |
+
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
|
| 1455 |
+
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
|
| 1456 |
+
... _compute_mask_indices,
|
| 1457 |
+
... _sample_negative_indices,
|
| 1458 |
+
... )
|
| 1459 |
+
>>> from datasets import load_dataset
|
| 1460 |
+
|
| 1461 |
+
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
| 1462 |
+
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
| 1463 |
+
|
| 1464 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
| 1465 |
+
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
|
| 1466 |
+
|
| 1467 |
+
>>> # compute masked indices
|
| 1468 |
+
>>> batch_size, raw_sequence_length = input_values.shape
|
| 1469 |
+
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
|
| 1470 |
+
>>> mask_time_indices = _compute_mask_indices(
|
| 1471 |
+
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
|
| 1472 |
+
... )
|
| 1473 |
+
>>> sampled_negative_indices = _sample_negative_indices(
|
| 1474 |
+
... features_shape=(batch_size, sequence_length),
|
| 1475 |
+
... num_negatives=model.config.num_negatives,
|
| 1476 |
+
... mask_time_indices=mask_time_indices,
|
| 1477 |
+
... )
|
| 1478 |
+
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
|
| 1479 |
+
>>> sampled_negative_indices = torch.tensor(
|
| 1480 |
+
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
|
| 1481 |
+
... )
|
| 1482 |
+
|
| 1483 |
+
>>> with torch.no_grad():
|
| 1484 |
+
... outputs = model(input_values, mask_time_indices=mask_time_indices)
|
| 1485 |
+
|
| 1486 |
+
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
| 1487 |
+
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
| 1488 |
+
|
| 1489 |
+
>>> # show that cosine similarity is much higher than random
|
| 1490 |
+
>>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
|
| 1491 |
+
tensor(True)
|
| 1492 |
+
|
| 1493 |
+
>>> # for contrastive loss training model should be put into train mode
|
| 1494 |
+
>>> model = model.train()
|
| 1495 |
+
>>> loss = model(
|
| 1496 |
+
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
|
| 1497 |
+
... ).loss
|
| 1498 |
+
```"""
|
| 1499 |
+
|
| 1500 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1501 |
+
|
| 1502 |
+
if mask_time_indices is not None:
|
| 1503 |
+
mask_time_indices = mask_time_indices.to(torch.bool)
|
| 1504 |
+
|
| 1505 |
+
outputs = self.wav2vec2_conformer(
|
| 1506 |
+
input_values,
|
| 1507 |
+
attention_mask=attention_mask,
|
| 1508 |
+
output_attentions=output_attentions,
|
| 1509 |
+
output_hidden_states=output_hidden_states,
|
| 1510 |
+
mask_time_indices=mask_time_indices,
|
| 1511 |
+
return_dict=return_dict,
|
| 1512 |
+
)
|
| 1513 |
+
|
| 1514 |
+
# 1. project all transformed features (including masked) to final vq dim
|
| 1515 |
+
transformer_features = self.project_hid(outputs[0])
|
| 1516 |
+
|
| 1517 |
+
# 2. quantize all (unmasked) extracted features and project to final vq dim
|
| 1518 |
+
extract_features = self.dropout_features(outputs[1])
|
| 1519 |
+
|
| 1520 |
+
if attention_mask is not None:
|
| 1521 |
+
# compute reduced attention_mask correponding to feature vectors
|
| 1522 |
+
attention_mask = self._get_feature_vector_attention_mask(
|
| 1523 |
+
extract_features.shape[1], attention_mask, add_adapter=False
|
| 1524 |
+
)
|
| 1525 |
+
|
| 1526 |
+
quantized_features, codevector_perplexity = self.quantizer(
|
| 1527 |
+
extract_features, mask_time_indices=mask_time_indices
|
| 1528 |
+
)
|
| 1529 |
+
quantized_features = self.project_q(quantized_features)
|
| 1530 |
+
|
| 1531 |
+
loss = contrastive_loss = diversity_loss = None
|
| 1532 |
+
if sampled_negative_indices is not None:
|
| 1533 |
+
batch_size, sequence_length, hidden_size = quantized_features.shape
|
| 1534 |
+
|
| 1535 |
+
# for training, we sample negatives
|
| 1536 |
+
# 3. sample K negatives (distractors) quantized states for contrastive loss
|
| 1537 |
+
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
|
| 1538 |
+
# sample negative quantized vectors BTC => (BxT)C
|
| 1539 |
+
negative_quantized_features = quantized_features.view(-1, hidden_size)[
|
| 1540 |
+
sampled_negative_indices.long().view(-1)
|
| 1541 |
+
]
|
| 1542 |
+
negative_quantized_features = negative_quantized_features.view(
|
| 1543 |
+
batch_size, sequence_length, -1, hidden_size
|
| 1544 |
+
).permute(2, 0, 1, 3)
|
| 1545 |
+
|
| 1546 |
+
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
|
| 1547 |
+
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
|
| 1548 |
+
logits = self.compute_contrastive_logits(
|
| 1549 |
+
quantized_features[None, :],
|
| 1550 |
+
negative_quantized_features,
|
| 1551 |
+
transformer_features,
|
| 1552 |
+
self.config.contrastive_logits_temperature,
|
| 1553 |
+
)
|
| 1554 |
+
|
| 1555 |
+
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
|
| 1556 |
+
# its cosine similarity will be masked
|
| 1557 |
+
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
|
| 1558 |
+
|
| 1559 |
+
if neg_is_pos.any():
|
| 1560 |
+
logits[1:][neg_is_pos] = float("-inf")
|
| 1561 |
+
|
| 1562 |
+
# 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
|
| 1563 |
+
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
|
| 1564 |
+
logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
|
| 1565 |
+
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
|
| 1566 |
+
|
| 1567 |
+
contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
|
| 1568 |
+
# 7. compute diversity loss: \mathbf{L}_d
|
| 1569 |
+
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
|
| 1570 |
+
diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
|
| 1571 |
+
|
| 1572 |
+
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
|
| 1573 |
+
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
|
| 1574 |
+
|
| 1575 |
+
if not return_dict:
|
| 1576 |
+
if loss is not None:
|
| 1577 |
+
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
| 1578 |
+
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
| 1579 |
+
|
| 1580 |
+
return Wav2Vec2ConformerForPreTrainingOutput(
|
| 1581 |
+
loss=loss,
|
| 1582 |
+
projected_states=transformer_features,
|
| 1583 |
+
projected_quantized_states=quantized_features,
|
| 1584 |
+
codevector_perplexity=codevector_perplexity,
|
| 1585 |
+
hidden_states=outputs.hidden_states,
|
| 1586 |
+
attentions=outputs.attentions,
|
| 1587 |
+
contrastive_loss=contrastive_loss,
|
| 1588 |
+
diversity_loss=diversity_loss,
|
| 1589 |
+
)
|
| 1590 |
+
|
| 1591 |
+
|
| 1592 |
+
@add_start_docstrings(
|
| 1593 |
+
"""Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
|
| 1594 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1595 |
+
)
|
| 1596 |
+
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
|
| 1597 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1598 |
+
def __init__(self, config):
|
| 1599 |
+
super().__init__(config)
|
| 1600 |
+
|
| 1601 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1602 |
+
self.dropout = nn.Dropout(config.final_dropout)
|
| 1603 |
+
|
| 1604 |
+
if config.vocab_size is None:
|
| 1605 |
+
raise ValueError(
|
| 1606 |
+
f"You are trying to instantiate {self.__class__} with a configuration that "
|
| 1607 |
+
"does not define the vocabulary size of the language model head. Please "
|
| 1608 |
+
"instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
|
| 1609 |
+
"or define `vocab_size` of your model's configuration."
|
| 1610 |
+
)
|
| 1611 |
+
output_hidden_size = (
|
| 1612 |
+
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
|
| 1613 |
+
)
|
| 1614 |
+
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
| 1615 |
+
|
| 1616 |
+
# Initialize weights and apply final processing
|
| 1617 |
+
self.post_init()
|
| 1618 |
+
|
| 1619 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1620 |
+
def freeze_feature_encoder(self):
|
| 1621 |
+
"""
|
| 1622 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1623 |
+
not be updated during training.
|
| 1624 |
+
"""
|
| 1625 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1626 |
+
|
| 1627 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1628 |
+
@add_code_sample_docstrings(
|
| 1629 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1630 |
+
output_type=CausalLMOutput,
|
| 1631 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1632 |
+
expected_output=_CTC_EXPECTED_OUTPUT,
|
| 1633 |
+
expected_loss=_CTC_EXPECTED_LOSS,
|
| 1634 |
+
)
|
| 1635 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1636 |
+
def forward(
|
| 1637 |
+
self,
|
| 1638 |
+
input_values: Optional[torch.Tensor],
|
| 1639 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1640 |
+
output_attentions: Optional[bool] = None,
|
| 1641 |
+
output_hidden_states: Optional[bool] = None,
|
| 1642 |
+
return_dict: Optional[bool] = None,
|
| 1643 |
+
labels: Optional[torch.Tensor] = None,
|
| 1644 |
+
) -> Union[Tuple, CausalLMOutput]:
|
| 1645 |
+
r"""
|
| 1646 |
+
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
| 1647 |
+
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
| 1648 |
+
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
| 1649 |
+
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
| 1650 |
+
config.vocab_size - 1]`.
|
| 1651 |
+
"""
|
| 1652 |
+
|
| 1653 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1654 |
+
|
| 1655 |
+
outputs = self.wav2vec2_conformer(
|
| 1656 |
+
input_values,
|
| 1657 |
+
attention_mask=attention_mask,
|
| 1658 |
+
output_attentions=output_attentions,
|
| 1659 |
+
output_hidden_states=output_hidden_states,
|
| 1660 |
+
return_dict=return_dict,
|
| 1661 |
+
)
|
| 1662 |
+
|
| 1663 |
+
hidden_states = outputs[0]
|
| 1664 |
+
hidden_states = self.dropout(hidden_states)
|
| 1665 |
+
|
| 1666 |
+
logits = self.lm_head(hidden_states)
|
| 1667 |
+
|
| 1668 |
+
loss = None
|
| 1669 |
+
if labels is not None:
|
| 1670 |
+
if labels.max() >= self.config.vocab_size:
|
| 1671 |
+
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
| 1672 |
+
|
| 1673 |
+
# retrieve loss input_lengths from attention_mask
|
| 1674 |
+
attention_mask = (
|
| 1675 |
+
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
| 1676 |
+
)
|
| 1677 |
+
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 1678 |
+
|
| 1679 |
+
# assuming that padded tokens are filled with -100
|
| 1680 |
+
# when not being attended to
|
| 1681 |
+
labels_mask = labels >= 0
|
| 1682 |
+
target_lengths = labels_mask.sum(-1)
|
| 1683 |
+
flattened_targets = labels.masked_select(labels_mask)
|
| 1684 |
+
|
| 1685 |
+
# ctc_loss doesn't support fp16
|
| 1686 |
+
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 1687 |
+
|
| 1688 |
+
with torch.backends.cudnn.flags(enabled=False):
|
| 1689 |
+
loss = nn.functional.ctc_loss(
|
| 1690 |
+
log_probs,
|
| 1691 |
+
flattened_targets,
|
| 1692 |
+
input_lengths,
|
| 1693 |
+
target_lengths,
|
| 1694 |
+
blank=self.config.pad_token_id,
|
| 1695 |
+
reduction=self.config.ctc_loss_reduction,
|
| 1696 |
+
zero_infinity=self.config.ctc_zero_infinity,
|
| 1697 |
+
)
|
| 1698 |
+
|
| 1699 |
+
if not return_dict:
|
| 1700 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1701 |
+
return ((loss,) + output) if loss is not None else output
|
| 1702 |
+
|
| 1703 |
+
return CausalLMOutput(
|
| 1704 |
+
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1705 |
+
)
|
| 1706 |
+
|
| 1707 |
+
|
| 1708 |
+
@add_start_docstrings(
|
| 1709 |
+
"""
|
| 1710 |
+
Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
|
| 1711 |
+
tasks like SUPERB Keyword Spotting.
|
| 1712 |
+
""",
|
| 1713 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1714 |
+
)
|
| 1715 |
+
class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
|
| 1716 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1717 |
+
def __init__(self, config):
|
| 1718 |
+
super().__init__(config)
|
| 1719 |
+
|
| 1720 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1721 |
+
raise ValueError(
|
| 1722 |
+
"Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
|
| 1723 |
+
)
|
| 1724 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1725 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1726 |
+
if config.use_weighted_layer_sum:
|
| 1727 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1728 |
+
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
| 1729 |
+
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
| 1730 |
+
|
| 1731 |
+
# Initialize weights and apply final processing
|
| 1732 |
+
self.post_init()
|
| 1733 |
+
|
| 1734 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1735 |
+
def freeze_feature_encoder(self):
|
| 1736 |
+
"""
|
| 1737 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1738 |
+
not be updated during training.
|
| 1739 |
+
"""
|
| 1740 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1741 |
+
|
| 1742 |
+
def freeze_base_model(self):
|
| 1743 |
+
"""
|
| 1744 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1745 |
+
be updated during training. Only the classification head will be updated.
|
| 1746 |
+
"""
|
| 1747 |
+
for param in self.wav2vec2_conformer.parameters():
|
| 1748 |
+
param.requires_grad = False
|
| 1749 |
+
|
| 1750 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1751 |
+
@add_code_sample_docstrings(
|
| 1752 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1753 |
+
output_type=SequenceClassifierOutput,
|
| 1754 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1755 |
+
modality="audio",
|
| 1756 |
+
)
|
| 1757 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 1758 |
+
def forward(
|
| 1759 |
+
self,
|
| 1760 |
+
input_values: Optional[torch.Tensor],
|
| 1761 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1762 |
+
output_attentions: Optional[bool] = None,
|
| 1763 |
+
output_hidden_states: Optional[bool] = None,
|
| 1764 |
+
return_dict: Optional[bool] = None,
|
| 1765 |
+
labels: Optional[torch.Tensor] = None,
|
| 1766 |
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1767 |
+
r"""
|
| 1768 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1769 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1770 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1771 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1772 |
+
"""
|
| 1773 |
+
|
| 1774 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1775 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1776 |
+
|
| 1777 |
+
outputs = self.wav2vec2_conformer(
|
| 1778 |
+
input_values,
|
| 1779 |
+
attention_mask=attention_mask,
|
| 1780 |
+
output_attentions=output_attentions,
|
| 1781 |
+
output_hidden_states=output_hidden_states,
|
| 1782 |
+
return_dict=return_dict,
|
| 1783 |
+
)
|
| 1784 |
+
|
| 1785 |
+
if self.config.use_weighted_layer_sum:
|
| 1786 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1787 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1788 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1789 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1790 |
+
else:
|
| 1791 |
+
hidden_states = outputs[0]
|
| 1792 |
+
|
| 1793 |
+
hidden_states = self.projector(hidden_states)
|
| 1794 |
+
if attention_mask is None:
|
| 1795 |
+
pooled_output = hidden_states.mean(dim=1)
|
| 1796 |
+
else:
|
| 1797 |
+
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
| 1798 |
+
hidden_states[~padding_mask] = 0.0
|
| 1799 |
+
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 1800 |
+
|
| 1801 |
+
logits = self.classifier(pooled_output)
|
| 1802 |
+
|
| 1803 |
+
loss = None
|
| 1804 |
+
if labels is not None:
|
| 1805 |
+
loss_fct = CrossEntropyLoss()
|
| 1806 |
+
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 1807 |
+
|
| 1808 |
+
if not return_dict:
|
| 1809 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1810 |
+
return ((loss,) + output) if loss is not None else output
|
| 1811 |
+
|
| 1812 |
+
return SequenceClassifierOutput(
|
| 1813 |
+
loss=loss,
|
| 1814 |
+
logits=logits,
|
| 1815 |
+
hidden_states=outputs.hidden_states,
|
| 1816 |
+
attentions=outputs.attentions,
|
| 1817 |
+
)
|
| 1818 |
+
|
| 1819 |
+
|
| 1820 |
+
@add_start_docstrings(
|
| 1821 |
+
"""
|
| 1822 |
+
Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
|
| 1823 |
+
""",
|
| 1824 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1825 |
+
)
|
| 1826 |
+
class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
|
| 1827 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 1828 |
+
def __init__(self, config):
|
| 1829 |
+
super().__init__(config)
|
| 1830 |
+
|
| 1831 |
+
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1832 |
+
raise ValueError(
|
| 1833 |
+
"Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
|
| 1834 |
+
)
|
| 1835 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1836 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1837 |
+
if config.use_weighted_layer_sum:
|
| 1838 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1839 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1840 |
+
self.num_labels = config.num_labels
|
| 1841 |
+
|
| 1842 |
+
self.init_weights()
|
| 1843 |
+
|
| 1844 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1845 |
+
def freeze_feature_encoder(self):
|
| 1846 |
+
"""
|
| 1847 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1848 |
+
not be updated during training.
|
| 1849 |
+
"""
|
| 1850 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1851 |
+
|
| 1852 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
|
| 1853 |
+
def freeze_base_model(self):
|
| 1854 |
+
"""
|
| 1855 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1856 |
+
be updated during training. Only the classification head will be updated.
|
| 1857 |
+
"""
|
| 1858 |
+
for param in self.wav2vec2_conformer.parameters():
|
| 1859 |
+
param.requires_grad = False
|
| 1860 |
+
|
| 1861 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1862 |
+
@add_code_sample_docstrings(
|
| 1863 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1864 |
+
output_type=TokenClassifierOutput,
|
| 1865 |
+
config_class=_CONFIG_FOR_DOC,
|
| 1866 |
+
modality="audio",
|
| 1867 |
+
)
|
| 1868 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
|
| 1869 |
+
def forward(
|
| 1870 |
+
self,
|
| 1871 |
+
input_values: Optional[torch.Tensor],
|
| 1872 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1873 |
+
labels: Optional[torch.Tensor] = None,
|
| 1874 |
+
output_attentions: Optional[bool] = None,
|
| 1875 |
+
output_hidden_states: Optional[bool] = None,
|
| 1876 |
+
return_dict: Optional[bool] = None,
|
| 1877 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
| 1878 |
+
r"""
|
| 1879 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1880 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1881 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1882 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1883 |
+
"""
|
| 1884 |
+
|
| 1885 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1886 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1887 |
+
|
| 1888 |
+
outputs = self.wav2vec2_conformer(
|
| 1889 |
+
input_values,
|
| 1890 |
+
attention_mask=attention_mask,
|
| 1891 |
+
output_attentions=output_attentions,
|
| 1892 |
+
output_hidden_states=output_hidden_states,
|
| 1893 |
+
return_dict=return_dict,
|
| 1894 |
+
)
|
| 1895 |
+
|
| 1896 |
+
if self.config.use_weighted_layer_sum:
|
| 1897 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1898 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1899 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1900 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1901 |
+
else:
|
| 1902 |
+
hidden_states = outputs[0]
|
| 1903 |
+
|
| 1904 |
+
logits = self.classifier(hidden_states)
|
| 1905 |
+
|
| 1906 |
+
loss = None
|
| 1907 |
+
if labels is not None:
|
| 1908 |
+
loss_fct = CrossEntropyLoss()
|
| 1909 |
+
loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
|
| 1910 |
+
|
| 1911 |
+
if not return_dict:
|
| 1912 |
+
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1913 |
+
return output
|
| 1914 |
+
|
| 1915 |
+
return TokenClassifierOutput(
|
| 1916 |
+
loss=loss,
|
| 1917 |
+
logits=logits,
|
| 1918 |
+
hidden_states=outputs.hidden_states,
|
| 1919 |
+
attentions=outputs.attentions,
|
| 1920 |
+
)
|
| 1921 |
+
|
| 1922 |
+
|
| 1923 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
|
| 1924 |
+
class AMSoftmaxLoss(nn.Module):
|
| 1925 |
+
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
|
| 1926 |
+
super(AMSoftmaxLoss, self).__init__()
|
| 1927 |
+
self.scale = scale
|
| 1928 |
+
self.margin = margin
|
| 1929 |
+
self.num_labels = num_labels
|
| 1930 |
+
self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
|
| 1931 |
+
self.loss = nn.CrossEntropyLoss()
|
| 1932 |
+
|
| 1933 |
+
def forward(self, hidden_states, labels):
|
| 1934 |
+
labels = labels.flatten()
|
| 1935 |
+
weight = nn.functional.normalize(self.weight, dim=0)
|
| 1936 |
+
hidden_states = nn.functional.normalize(hidden_states, dim=1)
|
| 1937 |
+
cos_theta = torch.mm(hidden_states, weight)
|
| 1938 |
+
psi = cos_theta - self.margin
|
| 1939 |
+
|
| 1940 |
+
onehot = nn.functional.one_hot(labels, self.num_labels)
|
| 1941 |
+
logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
|
| 1942 |
+
loss = self.loss(logits, labels)
|
| 1943 |
+
|
| 1944 |
+
return loss
|
| 1945 |
+
|
| 1946 |
+
|
| 1947 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
|
| 1948 |
+
class TDNNLayer(nn.Module):
|
| 1949 |
+
def __init__(self, config, layer_id=0):
|
| 1950 |
+
super().__init__()
|
| 1951 |
+
self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
|
| 1952 |
+
self.out_conv_dim = config.tdnn_dim[layer_id]
|
| 1953 |
+
self.kernel_size = config.tdnn_kernel[layer_id]
|
| 1954 |
+
self.dilation = config.tdnn_dilation[layer_id]
|
| 1955 |
+
|
| 1956 |
+
self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
|
| 1957 |
+
self.activation = nn.ReLU()
|
| 1958 |
+
|
| 1959 |
+
def forward(self, hidden_states):
|
| 1960 |
+
hidden_states = hidden_states.unsqueeze(1)
|
| 1961 |
+
hidden_states = nn.functional.unfold(
|
| 1962 |
+
hidden_states,
|
| 1963 |
+
(self.kernel_size, self.in_conv_dim),
|
| 1964 |
+
stride=(1, self.in_conv_dim),
|
| 1965 |
+
dilation=(self.dilation, 1),
|
| 1966 |
+
)
|
| 1967 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1968 |
+
hidden_states = self.kernel(hidden_states)
|
| 1969 |
+
|
| 1970 |
+
hidden_states = self.activation(hidden_states)
|
| 1971 |
+
return hidden_states
|
| 1972 |
+
|
| 1973 |
+
|
| 1974 |
+
@add_start_docstrings(
|
| 1975 |
+
"""
|
| 1976 |
+
Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
|
| 1977 |
+
""",
|
| 1978 |
+
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1979 |
+
)
|
| 1980 |
+
class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
|
| 1981 |
+
def __init__(self, config):
|
| 1982 |
+
super().__init__(config)
|
| 1983 |
+
|
| 1984 |
+
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1985 |
+
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1986 |
+
if config.use_weighted_layer_sum:
|
| 1987 |
+
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1988 |
+
self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
|
| 1989 |
+
|
| 1990 |
+
tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
|
| 1991 |
+
self.tdnn = nn.ModuleList(tdnn_layers)
|
| 1992 |
+
|
| 1993 |
+
self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
|
| 1994 |
+
self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
|
| 1995 |
+
|
| 1996 |
+
self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
|
| 1997 |
+
|
| 1998 |
+
self.init_weights()
|
| 1999 |
+
|
| 2000 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 2001 |
+
def freeze_feature_encoder(self):
|
| 2002 |
+
"""
|
| 2003 |
+
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 2004 |
+
not be updated during training.
|
| 2005 |
+
"""
|
| 2006 |
+
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 2007 |
+
|
| 2008 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
|
| 2009 |
+
def freeze_base_model(self):
|
| 2010 |
+
"""
|
| 2011 |
+
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 2012 |
+
be updated during training. Only the classification head will be updated.
|
| 2013 |
+
"""
|
| 2014 |
+
for param in self.wav2vec2_conformer.parameters():
|
| 2015 |
+
param.requires_grad = False
|
| 2016 |
+
|
| 2017 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
|
| 2018 |
+
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
|
| 2019 |
+
"""
|
| 2020 |
+
Computes the output length of the TDNN layers
|
| 2021 |
+
"""
|
| 2022 |
+
|
| 2023 |
+
def _conv_out_length(input_length, kernel_size, stride):
|
| 2024 |
+
# 1D convolutional layer output length formula taken
|
| 2025 |
+
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 2026 |
+
return (input_length - kernel_size) // stride + 1
|
| 2027 |
+
|
| 2028 |
+
for kernel_size in self.config.tdnn_kernel:
|
| 2029 |
+
input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
|
| 2030 |
+
|
| 2031 |
+
return input_lengths
|
| 2032 |
+
|
| 2033 |
+
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 2034 |
+
@add_code_sample_docstrings(
|
| 2035 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 2036 |
+
output_type=XVectorOutput,
|
| 2037 |
+
config_class=_CONFIG_FOR_DOC,
|
| 2038 |
+
modality="audio",
|
| 2039 |
+
)
|
| 2040 |
+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 2041 |
+
def forward(
|
| 2042 |
+
self,
|
| 2043 |
+
input_values: Optional[torch.Tensor],
|
| 2044 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 2045 |
+
output_attentions: Optional[bool] = None,
|
| 2046 |
+
output_hidden_states: Optional[bool] = None,
|
| 2047 |
+
return_dict: Optional[bool] = None,
|
| 2048 |
+
labels: Optional[torch.Tensor] = None,
|
| 2049 |
+
) -> Union[Tuple, XVectorOutput]:
|
| 2050 |
+
r"""
|
| 2051 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 2052 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 2053 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 2054 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 2055 |
+
"""
|
| 2056 |
+
|
| 2057 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 2058 |
+
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 2059 |
+
|
| 2060 |
+
outputs = self.wav2vec2_conformer(
|
| 2061 |
+
input_values,
|
| 2062 |
+
attention_mask=attention_mask,
|
| 2063 |
+
output_attentions=output_attentions,
|
| 2064 |
+
output_hidden_states=output_hidden_states,
|
| 2065 |
+
return_dict=return_dict,
|
| 2066 |
+
)
|
| 2067 |
+
|
| 2068 |
+
if self.config.use_weighted_layer_sum:
|
| 2069 |
+
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 2070 |
+
hidden_states = torch.stack(hidden_states, dim=1)
|
| 2071 |
+
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 2072 |
+
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 2073 |
+
else:
|
| 2074 |
+
hidden_states = outputs[0]
|
| 2075 |
+
|
| 2076 |
+
hidden_states = self.projector(hidden_states)
|
| 2077 |
+
|
| 2078 |
+
for tdnn_layer in self.tdnn:
|
| 2079 |
+
hidden_states = tdnn_layer(hidden_states)
|
| 2080 |
+
|
| 2081 |
+
# Statistic Pooling
|
| 2082 |
+
if attention_mask is None:
|
| 2083 |
+
mean_features = hidden_states.mean(dim=1)
|
| 2084 |
+
std_features = hidden_states.std(dim=1)
|
| 2085 |
+
else:
|
| 2086 |
+
feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
|
| 2087 |
+
tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
|
| 2088 |
+
mean_features = []
|
| 2089 |
+
std_features = []
|
| 2090 |
+
for i, length in enumerate(tdnn_output_lengths):
|
| 2091 |
+
mean_features.append(hidden_states[i, :length].mean(dim=0))
|
| 2092 |
+
std_features.append(hidden_states[i, :length].std(dim=0))
|
| 2093 |
+
mean_features = torch.stack(mean_features)
|
| 2094 |
+
std_features = torch.stack(std_features)
|
| 2095 |
+
statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
|
| 2096 |
+
|
| 2097 |
+
output_embeddings = self.feature_extractor(statistic_pooling)
|
| 2098 |
+
logits = self.classifier(output_embeddings)
|
| 2099 |
+
|
| 2100 |
+
loss = None
|
| 2101 |
+
if labels is not None:
|
| 2102 |
+
loss = self.objective(logits, labels)
|
| 2103 |
+
|
| 2104 |
+
if not return_dict:
|
| 2105 |
+
output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 2106 |
+
return ((loss,) + output) if loss is not None else output
|
| 2107 |
+
|
| 2108 |
+
return XVectorOutput(
|
| 2109 |
+
loss=loss,
|
| 2110 |
+
logits=logits,
|
| 2111 |
+
embeddings=output_embeddings,
|
| 2112 |
+
hidden_states=outputs.hidden_states,
|
| 2113 |
+
attentions=outputs.attentions,
|
| 2114 |
+
)
|
musicfm/modules/random_quantizer.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2023 ByteDance Inc.
|
| 4 |
+
#
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
|
| 6 |
+
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
| 7 |
+
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 8 |
+
#
|
| 9 |
+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 10 |
+
#
|
| 11 |
+
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 12 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 13 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
| 14 |
+
# IN THE SOFTWARE.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn, einsum
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class RandomProjectionQuantizer(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
Random projection and codebook lookup module
|
| 24 |
+
|
| 25 |
+
Some code is borrowed from:
|
| 26 |
+
https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
|
| 27 |
+
But I did normalization using pre-computed global mean & variance instead of using layer norm.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
input_dim,
|
| 33 |
+
codebook_dim,
|
| 34 |
+
codebook_size,
|
| 35 |
+
seed=142,
|
| 36 |
+
):
|
| 37 |
+
super().__init__()
|
| 38 |
+
|
| 39 |
+
# random seed
|
| 40 |
+
torch.manual_seed(seed)
|
| 41 |
+
|
| 42 |
+
# randomly initialized projection
|
| 43 |
+
random_projection = torch.empty(input_dim, codebook_dim)
|
| 44 |
+
nn.init.xavier_normal_(random_projection)
|
| 45 |
+
self.register_buffer("random_projection", random_projection)
|
| 46 |
+
|
| 47 |
+
# randomly initialized codebook
|
| 48 |
+
codebook = torch.empty(codebook_size, codebook_dim)
|
| 49 |
+
nn.init.normal_(codebook)
|
| 50 |
+
self.register_buffer("codebook", codebook)
|
| 51 |
+
|
| 52 |
+
def codebook_lookup(self, x):
|
| 53 |
+
# reshape
|
| 54 |
+
b = x.shape[0]
|
| 55 |
+
x = rearrange(x, "b n e -> (b n) e")
|
| 56 |
+
|
| 57 |
+
# L2 normalization
|
| 58 |
+
normalized_x = nn.functional.normalize(x, dim=1, p=2)
|
| 59 |
+
normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
|
| 60 |
+
|
| 61 |
+
# compute distances
|
| 62 |
+
distances = torch.cdist(normalized_codebook, normalized_x)
|
| 63 |
+
|
| 64 |
+
# get nearest
|
| 65 |
+
nearest_indices = torch.argmin(distances, dim=0)
|
| 66 |
+
|
| 67 |
+
# reshape
|
| 68 |
+
xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
|
| 69 |
+
|
| 70 |
+
return xq
|
| 71 |
+
|
| 72 |
+
@torch.no_grad()
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
# always eval
|
| 75 |
+
self.eval()
|
| 76 |
+
|
| 77 |
+
# random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
|
| 78 |
+
x = einsum("b n d, d e -> b n e", x, self.random_projection)
|
| 79 |
+
|
| 80 |
+
# codebook lookup
|
| 81 |
+
xq = self.codebook_lookup(x)
|
| 82 |
+
|
| 83 |
+
return xq
|
postprocessing/functional.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file contains code adapted from the following sources:
|
| 2 |
+
# [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/functional.py
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from .helpers import (
|
| 7 |
+
local_maxima,
|
| 8 |
+
peak_picking,
|
| 9 |
+
# event_frames_to_time,
|
| 10 |
+
)
|
| 11 |
+
from dataset.label2id import LABEL_TO_ID, ID_TO_LABEL
|
| 12 |
+
from dataset.custom_types import MsaInfo
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def event_frames_to_time(frame_rates, boundary: np.array):
|
| 16 |
+
boundary = np.array(boundary)
|
| 17 |
+
boundary_times = boundary / frame_rates
|
| 18 |
+
return boundary_times
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def postprocess_functional_structure(
|
| 22 |
+
logits,
|
| 23 |
+
config,
|
| 24 |
+
):
|
| 25 |
+
# pdb.set_trace()
|
| 26 |
+
boundary_logits = logits["boundary_logits"]
|
| 27 |
+
function_logits = logits["function_logits"]
|
| 28 |
+
|
| 29 |
+
assert boundary_logits.shape[0] == 1 and function_logits.shape[0] == 1, (
|
| 30 |
+
"Only batch size 1 is supported"
|
| 31 |
+
)
|
| 32 |
+
raw_prob_sections = torch.sigmoid(boundary_logits[0])
|
| 33 |
+
raw_prob_functions = torch.softmax(function_logits[0].transpose(0, 1), dim=0)
|
| 34 |
+
|
| 35 |
+
# filter_size=4 * cfg.min_hops_per_beat + 1
|
| 36 |
+
prob_sections, _ = local_maxima(
|
| 37 |
+
raw_prob_sections, filter_size=config.local_maxima_filter_size
|
| 38 |
+
)
|
| 39 |
+
prob_sections = prob_sections.cpu().numpy()
|
| 40 |
+
|
| 41 |
+
prob_functions = raw_prob_functions.cpu().numpy()
|
| 42 |
+
|
| 43 |
+
boundary_candidates = peak_picking(
|
| 44 |
+
boundary_activation=prob_sections,
|
| 45 |
+
window_past=int(12 * config.frame_rates), # 原来是fps
|
| 46 |
+
window_future=int(12 * config.frame_rates),
|
| 47 |
+
)
|
| 48 |
+
boundary = boundary_candidates > 0.0
|
| 49 |
+
|
| 50 |
+
duration = len(prob_sections) / config.frame_rates
|
| 51 |
+
pred_boundary_times = event_frames_to_time(
|
| 52 |
+
frame_rates=config.frame_rates, boundary=np.flatnonzero(boundary)
|
| 53 |
+
)
|
| 54 |
+
if pred_boundary_times[0] != 0:
|
| 55 |
+
pred_boundary_times = np.insert(pred_boundary_times, 0, 0)
|
| 56 |
+
if pred_boundary_times[-1] != duration:
|
| 57 |
+
pred_boundary_times = np.append(pred_boundary_times, duration)
|
| 58 |
+
pred_boundaries = np.stack([pred_boundary_times[:-1], pred_boundary_times[1:]]).T
|
| 59 |
+
|
| 60 |
+
pred_boundary_indices = np.flatnonzero(boundary)
|
| 61 |
+
pred_boundary_indices = pred_boundary_indices[pred_boundary_indices > 0]
|
| 62 |
+
prob_segment_function = np.split(prob_functions, pred_boundary_indices, axis=1)
|
| 63 |
+
pred_labels = [p.mean(axis=1).argmax().item() for p in prob_segment_function]
|
| 64 |
+
|
| 65 |
+
segments: MsaInfo = []
|
| 66 |
+
for (start, end), label in zip(pred_boundaries, pred_labels):
|
| 67 |
+
segment = (float(start), str(ID_TO_LABEL[label]))
|
| 68 |
+
segments.append(segment)
|
| 69 |
+
|
| 70 |
+
segments.append((float(pred_boundary_times[-1]), "end"))
|
| 71 |
+
return segments
|
postprocessing/helpers.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file contains code adapted from the following sources:
|
| 2 |
+
# [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/helpers.py
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch
|
| 7 |
+
import librosa
|
| 8 |
+
from typing import Union
|
| 9 |
+
from scipy.signal import argrelextrema
|
| 10 |
+
from scipy.interpolate import interp1d
|
| 11 |
+
from numpy.lib.stride_tricks import sliding_window_view
|
| 12 |
+
from numpy.typing import NDArray
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def local_maxima(tensor, filter_size=41):
|
| 16 |
+
assert len(tensor.shape) in (1, 2), "Input tensor should have 1 or 2 dimensions"
|
| 17 |
+
assert filter_size % 2 == 1, "Filter size should be an odd number"
|
| 18 |
+
|
| 19 |
+
original_shape = tensor.shape
|
| 20 |
+
if len(original_shape) == 1:
|
| 21 |
+
tensor = tensor.unsqueeze(0)
|
| 22 |
+
|
| 23 |
+
# Pad the input array with the minimum value
|
| 24 |
+
padding = filter_size // 2
|
| 25 |
+
padded_arr = F.pad(tensor, (padding, padding), mode="constant", value=-torch.inf)
|
| 26 |
+
|
| 27 |
+
# Create a rolling window view of the padded array
|
| 28 |
+
rolling_view = padded_arr.unfold(1, filter_size, 1)
|
| 29 |
+
|
| 30 |
+
# Find the indices of the local maxima
|
| 31 |
+
center = filter_size // 2
|
| 32 |
+
local_maxima_mask = torch.eq(
|
| 33 |
+
rolling_view[:, :, center], torch.max(rolling_view, dim=-1).values
|
| 34 |
+
)
|
| 35 |
+
local_maxima_indices = local_maxima_mask.nonzero()
|
| 36 |
+
|
| 37 |
+
# Initialize a new PyTorch tensor with zeros and the same shape as the input tensor
|
| 38 |
+
output_arr = torch.zeros_like(tensor)
|
| 39 |
+
|
| 40 |
+
# Set the local maxima values in the output tensor
|
| 41 |
+
output_arr[local_maxima_mask] = tensor[local_maxima_mask]
|
| 42 |
+
|
| 43 |
+
output_arr = output_arr.reshape(original_shape)
|
| 44 |
+
|
| 45 |
+
return output_arr, local_maxima_indices
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def local_maxima_numpy(arr, order=20):
|
| 49 |
+
is_batch = len(arr.shape) == 2
|
| 50 |
+
if is_batch:
|
| 51 |
+
return np.stack([local_maxima_numpy(x, order) for x in arr])
|
| 52 |
+
|
| 53 |
+
# Define a comparison function for argrelextrema to find local maxima
|
| 54 |
+
compare_func = np.greater
|
| 55 |
+
|
| 56 |
+
# Find the indices of the local maxima
|
| 57 |
+
local_maxima_indices = argrelextrema(arr, compare_func, order=order)
|
| 58 |
+
|
| 59 |
+
# Initialize a new numpy array with zeros and the same shape as the input array
|
| 60 |
+
output_arr = np.zeros_like(arr)
|
| 61 |
+
|
| 62 |
+
# Set the local maxima values in the output array
|
| 63 |
+
output_arr[local_maxima_indices] = arr[local_maxima_indices]
|
| 64 |
+
|
| 65 |
+
return output_arr
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def peak_picking(boundary_activation, window_past=12, window_future=6):
|
| 69 |
+
# Find local maxima using a sliding window
|
| 70 |
+
window_size = window_past + window_future
|
| 71 |
+
assert window_size % 2 == 0, "window_past + window_future must be even"
|
| 72 |
+
window_size += 1
|
| 73 |
+
|
| 74 |
+
# Pad boundary_activation
|
| 75 |
+
boundary_activation_padded = np.pad(
|
| 76 |
+
boundary_activation, (window_past, window_future), mode="constant"
|
| 77 |
+
)
|
| 78 |
+
max_filter = sliding_window_view(boundary_activation_padded, window_size)
|
| 79 |
+
local_maxima = (boundary_activation == np.max(max_filter, axis=-1)) & (
|
| 80 |
+
boundary_activation > 0
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Compute strength values by subtracting the mean of the past and future windows
|
| 84 |
+
past_window_filter = sliding_window_view(
|
| 85 |
+
boundary_activation_padded[: -(window_future + 1)], window_past
|
| 86 |
+
)
|
| 87 |
+
future_window_filter = sliding_window_view(
|
| 88 |
+
boundary_activation_padded[window_past + 1 :], window_future
|
| 89 |
+
)
|
| 90 |
+
past_mean = np.mean(past_window_filter, axis=-1)
|
| 91 |
+
future_mean = np.mean(future_window_filter, axis=-1)
|
| 92 |
+
strength_values = boundary_activation - ((past_mean + future_mean) / 2)
|
| 93 |
+
|
| 94 |
+
# Get boundary candidates and their corresponding strength values
|
| 95 |
+
boundary_candidates = np.flatnonzero(local_maxima)
|
| 96 |
+
strength_values = strength_values[boundary_candidates]
|
| 97 |
+
|
| 98 |
+
strength_activations = np.zeros_like(boundary_activation)
|
| 99 |
+
strength_activations[boundary_candidates] = strength_values
|
| 100 |
+
|
| 101 |
+
return strength_activations
|