saurabhati committed on
Commit 4e65175 · verified · 1 Parent(s): d28474f

Upload DASSForAudioClassification

Files changed (5)
  1. README.md +199 -0
  2. config.json +1089 -0
  3. configuration_dass.py +91 -0
  4. model.safetensors +3 -0
  5. modeling_dass.py +1196 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
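+
+ The snippet below is a minimal loading sketch. It assumes the checkpoint is resolved through the `auto_map` entries in `config.json` (which require `trust_remote_code=True`), uses a placeholder repository id, and assumes the model consumes spectrogram tensors of shape `(batch, max_length, num_mel_bins)`; adjust the repository id and the feature-extraction pipeline to the actual setup.
+
+ ```python
+ import torch
+ from transformers import AutoConfig, AutoModelForAudioClassification
+
+ repo_id = "saurabhati/DASS-small"  # placeholder: replace with the actual Hub repository id
+
+ # trust_remote_code is required because config.json maps AutoModelForAudioClassification
+ # to the custom modeling_dass.DASSForAudioClassification class shipped with this repository.
+ config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
+ model = AutoModelForAudioClassification.from_pretrained(repo_id, trust_remote_code=True)
+ model.eval()
+
+ # Assumed input format: a batch of log-Mel spectrograms with max_length frames and
+ # num_mel_bins frequency bins, i.e. (1, 1024, 128) for this configuration.
+ spectrogram = torch.randn(1, config.max_length, config.num_mel_bins)
+
+ with torch.no_grad():
+     logits = model(spectrogram).logits  # (1, num_classes) = (1, 527)
+
+ predicted_id = logits.argmax(-1).item()
+ print(config.id2label[predicted_id])
+ ```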
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,1089 @@
1
+ {
2
+ "architectures": [
3
+ "DASSForAudioClassification"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_dass.DASSConfig",
7
+ "AutoModelForAudioClassification": "modeling_dass.DASSForAudioClassification"
8
+ },
9
+ "depths": [
10
+ 2,
11
+ 2,
12
+ 8,
13
+ 2
14
+ ],
15
+ "dims": [
16
+ 96,
17
+ 192,
18
+ 384,
19
+ 768
20
+ ],
21
+ "drop_path_rate": 0.2,
22
+ "embed_dim": 96,
23
+ "id2label": {
24
+ "0": "Speech",
25
+ "1": "Male speech, man speaking",
26
+ "2": "Female speech, woman speaking",
27
+ "3": "Child speech, kid speaking",
28
+ "4": "Conversation",
29
+ "5": "Narration, monologue",
30
+ "6": "Babbling",
31
+ "7": "Speech synthesizer",
32
+ "8": "Shout",
33
+ "9": "Bellow",
34
+ "10": "Whoop",
35
+ "11": "Yell",
36
+ "12": "Battle cry",
37
+ "13": "Children shouting",
38
+ "14": "Screaming",
39
+ "15": "Whispering",
40
+ "16": "Laughter",
41
+ "17": "Baby laughter",
42
+ "18": "Giggle",
43
+ "19": "Snicker",
44
+ "20": "Belly laugh",
45
+ "21": "Chuckle, chortle",
46
+ "22": "Crying, sobbing",
47
+ "23": "Baby cry, infant cry",
48
+ "24": "Whimper",
49
+ "25": "Wail, moan",
50
+ "26": "Sigh",
51
+ "27": "Singing",
52
+ "28": "Choir",
53
+ "29": "Yodeling",
54
+ "30": "Chant",
55
+ "31": "Mantra",
56
+ "32": "Male singing",
57
+ "33": "Female singing",
58
+ "34": "Child singing",
59
+ "35": "Synthetic singing",
60
+ "36": "Rapping",
61
+ "37": "Humming",
62
+ "38": "Groan",
63
+ "39": "Grunt",
64
+ "40": "Whistling",
65
+ "41": "Breathing",
66
+ "42": "Wheeze",
67
+ "43": "Snoring",
68
+ "44": "Gasp",
69
+ "45": "Pant",
70
+ "46": "Snort",
71
+ "47": "Cough",
72
+ "48": "Throat clearing",
73
+ "49": "Sneeze",
74
+ "50": "Sniff",
75
+ "51": "Run",
76
+ "52": "Shuffle",
77
+ "53": "Walk, footsteps",
78
+ "54": "Chewing, mastication",
79
+ "55": "Biting",
80
+ "56": "Gargling",
81
+ "57": "Stomach rumble",
82
+ "58": "Burping, eructation",
83
+ "59": "Hiccup",
84
+ "60": "Fart",
85
+ "61": "Hands",
86
+ "62": "Finger snapping",
87
+ "63": "Clapping",
88
+ "64": "Heart sounds, heartbeat",
89
+ "65": "Heart murmur",
90
+ "66": "Cheering",
91
+ "67": "Applause",
92
+ "68": "Chatter",
93
+ "69": "Crowd",
94
+ "70": "Hubbub, speech noise, speech babble",
95
+ "71": "Children playing",
96
+ "72": "Animal",
97
+ "73": "Domestic animals, pets",
98
+ "74": "Dog",
99
+ "75": "Bark",
100
+ "76": "Yip",
101
+ "77": "Howl",
102
+ "78": "Bow-wow",
103
+ "79": "Growling",
104
+ "80": "Whimper (dog)",
105
+ "81": "Cat",
106
+ "82": "Purr",
107
+ "83": "Meow",
108
+ "84": "Hiss",
109
+ "85": "Caterwaul",
110
+ "86": "Livestock, farm animals, working animals",
111
+ "87": "Horse",
112
+ "88": "Clip-clop",
113
+ "89": "Neigh, whinny",
114
+ "90": "Cattle, bovinae",
115
+ "91": "Moo",
116
+ "92": "Cowbell",
117
+ "93": "Pig",
118
+ "94": "Oink",
119
+ "95": "Goat",
120
+ "96": "Bleat",
121
+ "97": "Sheep",
122
+ "98": "Fowl",
123
+ "99": "Chicken, rooster",
124
+ "100": "Cluck",
125
+ "101": "Crowing, cock-a-doodle-doo",
126
+ "102": "Turkey",
127
+ "103": "Gobble",
128
+ "104": "Duck",
129
+ "105": "Quack",
130
+ "106": "Goose",
131
+ "107": "Honk",
132
+ "108": "Wild animals",
133
+ "109": "Roaring cats (lions, tigers)",
134
+ "110": "Roar",
135
+ "111": "Bird",
136
+ "112": "Bird vocalization, bird call, bird song",
137
+ "113": "Chirp, tweet",
138
+ "114": "Squawk",
139
+ "115": "Pigeon, dove",
140
+ "116": "Coo",
141
+ "117": "Crow",
142
+ "118": "Caw",
143
+ "119": "Owl",
144
+ "120": "Hoot",
145
+ "121": "Bird flight, flapping wings",
146
+ "122": "Canidae, dogs, wolves",
147
+ "123": "Rodents, rats, mice",
148
+ "124": "Mouse",
149
+ "125": "Patter",
150
+ "126": "Insect",
151
+ "127": "Cricket",
152
+ "128": "Mosquito",
153
+ "129": "Fly, housefly",
154
+ "130": "Buzz",
155
+ "131": "Bee, wasp, etc.",
156
+ "132": "Frog",
157
+ "133": "Croak",
158
+ "134": "Snake",
159
+ "135": "Rattle",
160
+ "136": "Whale vocalization",
161
+ "137": "Music",
162
+ "138": "Musical instrument",
163
+ "139": "Plucked string instrument",
164
+ "140": "Guitar",
165
+ "141": "Electric guitar",
166
+ "142": "Bass guitar",
167
+ "143": "Acoustic guitar",
168
+ "144": "Steel guitar, slide guitar",
169
+ "145": "Tapping (guitar technique)",
170
+ "146": "Strum",
171
+ "147": "Banjo",
172
+ "148": "Sitar",
173
+ "149": "Mandolin",
174
+ "150": "Zither",
175
+ "151": "Ukulele",
176
+ "152": "Keyboard (musical)",
177
+ "153": "Piano",
178
+ "154": "Electric piano",
179
+ "155": "Organ",
180
+ "156": "Electronic organ",
181
+ "157": "Hammond organ",
182
+ "158": "Synthesizer",
183
+ "159": "Sampler",
184
+ "160": "Harpsichord",
185
+ "161": "Percussion",
186
+ "162": "Drum kit",
187
+ "163": "Drum machine",
188
+ "164": "Drum",
189
+ "165": "Snare drum",
190
+ "166": "Rimshot",
191
+ "167": "Drum roll",
192
+ "168": "Bass drum",
193
+ "169": "Timpani",
194
+ "170": "Tabla",
195
+ "171": "Cymbal",
196
+ "172": "Hi-hat",
197
+ "173": "Wood block",
198
+ "174": "Tambourine",
199
+ "175": "Rattle (instrument)",
200
+ "176": "Maraca",
201
+ "177": "Gong",
202
+ "178": "Tubular bells",
203
+ "179": "Mallet percussion",
204
+ "180": "Marimba, xylophone",
205
+ "181": "Glockenspiel",
206
+ "182": "Vibraphone",
207
+ "183": "Steelpan",
208
+ "184": "Orchestra",
209
+ "185": "Brass instrument",
210
+ "186": "French horn",
211
+ "187": "Trumpet",
212
+ "188": "Trombone",
213
+ "189": "Bowed string instrument",
214
+ "190": "String section",
215
+ "191": "Violin, fiddle",
216
+ "192": "Pizzicato",
217
+ "193": "Cello",
218
+ "194": "Double bass",
219
+ "195": "Wind instrument, woodwind instrument",
220
+ "196": "Flute",
221
+ "197": "Saxophone",
222
+ "198": "Clarinet",
223
+ "199": "Harp",
224
+ "200": "Bell",
225
+ "201": "Church bell",
226
+ "202": "Jingle bell",
227
+ "203": "Bicycle bell",
228
+ "204": "Tuning fork",
229
+ "205": "Chime",
230
+ "206": "Wind chime",
231
+ "207": "Change ringing (campanology)",
232
+ "208": "Harmonica",
233
+ "209": "Accordion",
234
+ "210": "Bagpipes",
235
+ "211": "Didgeridoo",
236
+ "212": "Shofar",
237
+ "213": "Theremin",
238
+ "214": "Singing bowl",
239
+ "215": "Scratching (performance technique)",
240
+ "216": "Pop music",
241
+ "217": "Hip hop music",
242
+ "218": "Beatboxing",
243
+ "219": "Rock music",
244
+ "220": "Heavy metal",
245
+ "221": "Punk rock",
246
+ "222": "Grunge",
247
+ "223": "Progressive rock",
248
+ "224": "Rock and roll",
249
+ "225": "Psychedelic rock",
250
+ "226": "Rhythm and blues",
251
+ "227": "Soul music",
252
+ "228": "Reggae",
253
+ "229": "Country",
254
+ "230": "Swing music",
255
+ "231": "Bluegrass",
256
+ "232": "Funk",
257
+ "233": "Folk music",
258
+ "234": "Middle Eastern music",
259
+ "235": "Jazz",
260
+ "236": "Disco",
261
+ "237": "Classical music",
262
+ "238": "Opera",
263
+ "239": "Electronic music",
264
+ "240": "House music",
265
+ "241": "Techno",
266
+ "242": "Dubstep",
267
+ "243": "Drum and bass",
268
+ "244": "Electronica",
269
+ "245": "Electronic dance music",
270
+ "246": "Ambient music",
271
+ "247": "Trance music",
272
+ "248": "Music of Latin America",
273
+ "249": "Salsa music",
274
+ "250": "Flamenco",
275
+ "251": "Blues",
276
+ "252": "Music for children",
277
+ "253": "New-age music",
278
+ "254": "Vocal music",
279
+ "255": "A capella",
280
+ "256": "Music of Africa",
281
+ "257": "Afrobeat",
282
+ "258": "Christian music",
283
+ "259": "Gospel music",
284
+ "260": "Music of Asia",
285
+ "261": "Carnatic music",
286
+ "262": "Music of Bollywood",
287
+ "263": "Ska",
288
+ "264": "Traditional music",
289
+ "265": "Independent music",
290
+ "266": "Song",
291
+ "267": "Background music",
292
+ "268": "Theme music",
293
+ "269": "Jingle (music)",
294
+ "270": "Soundtrack music",
295
+ "271": "Lullaby",
296
+ "272": "Video game music",
297
+ "273": "Christmas music",
298
+ "274": "Dance music",
299
+ "275": "Wedding music",
300
+ "276": "Happy music",
301
+ "277": "Funny music",
302
+ "278": "Sad music",
303
+ "279": "Tender music",
304
+ "280": "Exciting music",
305
+ "281": "Angry music",
306
+ "282": "Scary music",
307
+ "283": "Wind",
308
+ "284": "Rustling leaves",
309
+ "285": "Wind noise (microphone)",
310
+ "286": "Thunderstorm",
311
+ "287": "Thunder",
312
+ "288": "Water",
313
+ "289": "Rain",
314
+ "290": "Raindrop",
315
+ "291": "Rain on surface",
316
+ "292": "Stream",
317
+ "293": "Waterfall",
318
+ "294": "Ocean",
319
+ "295": "Waves, surf",
320
+ "296": "Steam",
321
+ "297": "Gurgling",
322
+ "298": "Fire",
323
+ "299": "Crackle",
324
+ "300": "Vehicle",
325
+ "301": "Boat, Water vehicle",
326
+ "302": "Sailboat, sailing ship",
327
+ "303": "Rowboat, canoe, kayak",
328
+ "304": "Motorboat, speedboat",
329
+ "305": "Ship",
330
+ "306": "Motor vehicle (road)",
331
+ "307": "Car",
332
+ "308": "Vehicle horn, car horn, honking",
333
+ "309": "Toot",
334
+ "310": "Car alarm",
335
+ "311": "Power windows, electric windows",
336
+ "312": "Skidding",
337
+ "313": "Tire squeal",
338
+ "314": "Car passing by",
339
+ "315": "Race car, auto racing",
340
+ "316": "Truck",
341
+ "317": "Air brake",
342
+ "318": "Air horn, truck horn",
343
+ "319": "Reversing beeps",
344
+ "320": "Ice cream truck, ice cream van",
345
+ "321": "Bus",
346
+ "322": "Emergency vehicle",
347
+ "323": "Police car (siren)",
348
+ "324": "Ambulance (siren)",
349
+ "325": "Fire engine, fire truck (siren)",
350
+ "326": "Motorcycle",
351
+ "327": "Traffic noise, roadway noise",
352
+ "328": "Rail transport",
353
+ "329": "Train",
354
+ "330": "Train whistle",
355
+ "331": "Train horn",
356
+ "332": "Railroad car, train wagon",
357
+ "333": "Train wheels squealing",
358
+ "334": "Subway, metro, underground",
359
+ "335": "Aircraft",
360
+ "336": "Aircraft engine",
361
+ "337": "Jet engine",
362
+ "338": "Propeller, airscrew",
363
+ "339": "Helicopter",
364
+ "340": "Fixed-wing aircraft, airplane",
365
+ "341": "Bicycle",
366
+ "342": "Skateboard",
367
+ "343": "Engine",
368
+ "344": "Light engine (high frequency)",
369
+ "345": "Dental drill, dentist's drill",
370
+ "346": "Lawn mower",
371
+ "347": "Chainsaw",
372
+ "348": "Medium engine (mid frequency)",
373
+ "349": "Heavy engine (low frequency)",
374
+ "350": "Engine knocking",
375
+ "351": "Engine starting",
376
+ "352": "Idling",
377
+ "353": "Accelerating, revving, vroom",
378
+ "354": "Door",
379
+ "355": "Doorbell",
380
+ "356": "Ding-dong",
381
+ "357": "Sliding door",
382
+ "358": "Slam",
383
+ "359": "Knock",
384
+ "360": "Tap",
385
+ "361": "Squeak",
386
+ "362": "Cupboard open or close",
387
+ "363": "Drawer open or close",
388
+ "364": "Dishes, pots, and pans",
389
+ "365": "Cutlery, silverware",
390
+ "366": "Chopping (food)",
391
+ "367": "Frying (food)",
392
+ "368": "Microwave oven",
393
+ "369": "Blender",
394
+ "370": "Water tap, faucet",
395
+ "371": "Sink (filling or washing)",
396
+ "372": "Bathtub (filling or washing)",
397
+ "373": "Hair dryer",
398
+ "374": "Toilet flush",
399
+ "375": "Toothbrush",
400
+ "376": "Electric toothbrush",
401
+ "377": "Vacuum cleaner",
402
+ "378": "Zipper (clothing)",
403
+ "379": "Keys jangling",
404
+ "380": "Coin (dropping)",
405
+ "381": "Scissors",
406
+ "382": "Electric shaver, electric razor",
407
+ "383": "Shuffling cards",
408
+ "384": "Typing",
409
+ "385": "Typewriter",
410
+ "386": "Computer keyboard",
411
+ "387": "Writing",
412
+ "388": "Alarm",
413
+ "389": "Telephone",
414
+ "390": "Telephone bell ringing",
415
+ "391": "Ringtone",
416
+ "392": "Telephone dialing, DTMF",
417
+ "393": "Dial tone",
418
+ "394": "Busy signal",
419
+ "395": "Alarm clock",
420
+ "396": "Siren",
421
+ "397": "Civil defense siren",
422
+ "398": "Buzzer",
423
+ "399": "Smoke detector, smoke alarm",
424
+ "400": "Fire alarm",
425
+ "401": "Foghorn",
426
+ "402": "Whistle",
427
+ "403": "Steam whistle",
428
+ "404": "Mechanisms",
429
+ "405": "Ratchet, pawl",
430
+ "406": "Clock",
431
+ "407": "Tick",
432
+ "408": "Tick-tock",
433
+ "409": "Gears",
434
+ "410": "Pulleys",
435
+ "411": "Sewing machine",
436
+ "412": "Mechanical fan",
437
+ "413": "Air conditioning",
438
+ "414": "Cash register",
439
+ "415": "Printer",
440
+ "416": "Camera",
441
+ "417": "Single-lens reflex camera",
442
+ "418": "Tools",
443
+ "419": "Hammer",
444
+ "420": "Jackhammer",
445
+ "421": "Sawing",
446
+ "422": "Filing (rasp)",
447
+ "423": "Sanding",
448
+ "424": "Power tool",
449
+ "425": "Drill",
450
+ "426": "Explosion",
451
+ "427": "Gunshot, gunfire",
452
+ "428": "Machine gun",
453
+ "429": "Fusillade",
454
+ "430": "Artillery fire",
455
+ "431": "Cap gun",
456
+ "432": "Fireworks",
457
+ "433": "Firecracker",
458
+ "434": "Burst, pop",
459
+ "435": "Eruption",
460
+ "436": "Boom",
461
+ "437": "Wood",
462
+ "438": "Chop",
463
+ "439": "Splinter",
464
+ "440": "Crack",
465
+ "441": "Glass",
466
+ "442": "Chink, clink",
467
+ "443": "Shatter",
468
+ "444": "Liquid",
469
+ "445": "Splash, splatter",
470
+ "446": "Slosh",
471
+ "447": "Squish",
472
+ "448": "Drip",
473
+ "449": "Pour",
474
+ "450": "Trickle, dribble",
475
+ "451": "Gush",
476
+ "452": "Fill (with liquid)",
477
+ "453": "Spray",
478
+ "454": "Pump (liquid)",
479
+ "455": "Stir",
480
+ "456": "Boiling",
481
+ "457": "Sonar",
482
+ "458": "Arrow",
483
+ "459": "Whoosh, swoosh, swish",
484
+ "460": "Thump, thud",
485
+ "461": "Thunk",
486
+ "462": "Electronic tuner",
487
+ "463": "Effects unit",
488
+ "464": "Chorus effect",
489
+ "465": "Basketball bounce",
490
+ "466": "Bang",
491
+ "467": "Slap, smack",
492
+ "468": "Whack, thwack",
493
+ "469": "Smash, crash",
494
+ "470": "Breaking",
495
+ "471": "Bouncing",
496
+ "472": "Whip",
497
+ "473": "Flap",
498
+ "474": "Scratch",
499
+ "475": "Scrape",
500
+ "476": "Rub",
501
+ "477": "Roll",
502
+ "478": "Crushing",
503
+ "479": "Crumpling, crinkling",
504
+ "480": "Tearing",
505
+ "481": "Beep, bleep",
506
+ "482": "Ping",
507
+ "483": "Ding",
508
+ "484": "Clang",
509
+ "485": "Squeal",
510
+ "486": "Creak",
511
+ "487": "Rustle",
512
+ "488": "Whir",
513
+ "489": "Clatter",
514
+ "490": "Sizzle",
515
+ "491": "Clicking",
516
+ "492": "Clickety-clack",
517
+ "493": "Rumble",
518
+ "494": "Plop",
519
+ "495": "Jingle, tinkle",
520
+ "496": "Hum",
521
+ "497": "Zing",
522
+ "498": "Boing",
523
+ "499": "Crunch",
524
+ "500": "Silence",
525
+ "501": "Sine wave",
526
+ "502": "Harmonic",
527
+ "503": "Chirp tone",
528
+ "504": "Sound effect",
529
+ "505": "Pulse",
530
+ "506": "Inside, small room",
531
+ "507": "Inside, large room or hall",
532
+ "508": "Inside, public space",
533
+ "509": "Outside, urban or manmade",
534
+ "510": "Outside, rural or natural",
535
+ "511": "Reverberation",
536
+ "512": "Echo",
537
+ "513": "Noise",
538
+ "514": "Environmental noise",
539
+ "515": "Static",
540
+ "516": "Mains hum",
541
+ "517": "Distortion",
542
+ "518": "Sidetone",
543
+ "519": "Cacophony",
544
+ "520": "White noise",
545
+ "521": "Pink noise",
546
+ "522": "Throbbing",
547
+ "523": "Vibration",
548
+ "524": "Television",
549
+ "525": "Radio",
550
+ "526": "Field recording"
551
+ },
552
+ "label2id": {
553
+ "A capella": 255,
554
+ "Accelerating, revving, vroom": 353,
555
+ "Accordion": 209,
556
+ "Acoustic guitar": 143,
557
+ "Afrobeat": 257,
558
+ "Air brake": 317,
559
+ "Air conditioning": 413,
560
+ "Air horn, truck horn": 318,
561
+ "Aircraft": 335,
562
+ "Aircraft engine": 336,
563
+ "Alarm": 388,
564
+ "Alarm clock": 395,
565
+ "Ambient music": 246,
566
+ "Ambulance (siren)": 324,
567
+ "Angry music": 281,
568
+ "Animal": 72,
569
+ "Applause": 67,
570
+ "Arrow": 458,
571
+ "Artillery fire": 430,
572
+ "Babbling": 6,
573
+ "Baby cry, infant cry": 23,
574
+ "Baby laughter": 17,
575
+ "Background music": 267,
576
+ "Bagpipes": 210,
577
+ "Bang": 466,
578
+ "Banjo": 147,
579
+ "Bark": 75,
580
+ "Basketball bounce": 465,
581
+ "Bass drum": 168,
582
+ "Bass guitar": 142,
583
+ "Bathtub (filling or washing)": 372,
584
+ "Battle cry": 12,
585
+ "Beatboxing": 218,
586
+ "Bee, wasp, etc.": 131,
587
+ "Beep, bleep": 481,
588
+ "Bell": 200,
589
+ "Bellow": 9,
590
+ "Belly laugh": 20,
591
+ "Bicycle": 341,
592
+ "Bicycle bell": 203,
593
+ "Bird": 111,
594
+ "Bird flight, flapping wings": 121,
595
+ "Bird vocalization, bird call, bird song": 112,
596
+ "Biting": 55,
597
+ "Bleat": 96,
598
+ "Blender": 369,
599
+ "Bluegrass": 231,
600
+ "Blues": 251,
601
+ "Boat, Water vehicle": 301,
602
+ "Boiling": 456,
603
+ "Boing": 498,
604
+ "Boom": 436,
605
+ "Bouncing": 471,
606
+ "Bow-wow": 78,
607
+ "Bowed string instrument": 189,
608
+ "Brass instrument": 185,
609
+ "Breaking": 470,
610
+ "Breathing": 41,
611
+ "Burping, eructation": 58,
612
+ "Burst, pop": 434,
613
+ "Bus": 321,
614
+ "Busy signal": 394,
615
+ "Buzz": 130,
616
+ "Buzzer": 398,
617
+ "Cacophony": 519,
618
+ "Camera": 416,
619
+ "Canidae, dogs, wolves": 122,
620
+ "Cap gun": 431,
621
+ "Car": 307,
622
+ "Car alarm": 310,
623
+ "Car passing by": 314,
624
+ "Carnatic music": 261,
625
+ "Cash register": 414,
626
+ "Cat": 81,
627
+ "Caterwaul": 85,
628
+ "Cattle, bovinae": 90,
629
+ "Caw": 118,
630
+ "Cello": 193,
631
+ "Chainsaw": 347,
632
+ "Change ringing (campanology)": 207,
633
+ "Chant": 30,
634
+ "Chatter": 68,
635
+ "Cheering": 66,
636
+ "Chewing, mastication": 54,
637
+ "Chicken, rooster": 99,
638
+ "Child singing": 34,
639
+ "Child speech, kid speaking": 3,
640
+ "Children playing": 71,
641
+ "Children shouting": 13,
642
+ "Chime": 205,
643
+ "Chink, clink": 442,
644
+ "Chirp tone": 503,
645
+ "Chirp, tweet": 113,
646
+ "Choir": 28,
647
+ "Chop": 438,
648
+ "Chopping (food)": 366,
649
+ "Chorus effect": 464,
650
+ "Christian music": 258,
651
+ "Christmas music": 273,
652
+ "Chuckle, chortle": 21,
653
+ "Church bell": 201,
654
+ "Civil defense siren": 397,
655
+ "Clang": 484,
656
+ "Clapping": 63,
657
+ "Clarinet": 198,
658
+ "Classical music": 237,
659
+ "Clatter": 489,
660
+ "Clickety-clack": 492,
661
+ "Clicking": 491,
662
+ "Clip-clop": 88,
663
+ "Clock": 406,
664
+ "Cluck": 100,
665
+ "Coin (dropping)": 380,
666
+ "Computer keyboard": 386,
667
+ "Conversation": 4,
668
+ "Coo": 116,
669
+ "Cough": 47,
670
+ "Country": 229,
671
+ "Cowbell": 92,
672
+ "Crack": 440,
673
+ "Crackle": 299,
674
+ "Creak": 486,
675
+ "Cricket": 127,
676
+ "Croak": 133,
677
+ "Crow": 117,
678
+ "Crowd": 69,
679
+ "Crowing, cock-a-doodle-doo": 101,
680
+ "Crumpling, crinkling": 479,
681
+ "Crunch": 499,
682
+ "Crushing": 478,
683
+ "Crying, sobbing": 22,
684
+ "Cupboard open or close": 362,
685
+ "Cutlery, silverware": 365,
686
+ "Cymbal": 171,
687
+ "Dance music": 274,
688
+ "Dental drill, dentist's drill": 345,
689
+ "Dial tone": 393,
690
+ "Didgeridoo": 211,
691
+ "Ding": 483,
692
+ "Ding-dong": 356,
693
+ "Disco": 236,
694
+ "Dishes, pots, and pans": 364,
695
+ "Distortion": 517,
696
+ "Dog": 74,
697
+ "Domestic animals, pets": 73,
698
+ "Door": 354,
699
+ "Doorbell": 355,
700
+ "Double bass": 194,
701
+ "Drawer open or close": 363,
702
+ "Drill": 425,
703
+ "Drip": 448,
704
+ "Drum": 164,
705
+ "Drum and bass": 243,
706
+ "Drum kit": 162,
707
+ "Drum machine": 163,
708
+ "Drum roll": 167,
709
+ "Dubstep": 242,
710
+ "Duck": 104,
711
+ "Echo": 512,
712
+ "Effects unit": 463,
713
+ "Electric guitar": 141,
714
+ "Electric piano": 154,
715
+ "Electric shaver, electric razor": 382,
716
+ "Electric toothbrush": 376,
717
+ "Electronic dance music": 245,
718
+ "Electronic music": 239,
719
+ "Electronic organ": 156,
720
+ "Electronic tuner": 462,
721
+ "Electronica": 244,
722
+ "Emergency vehicle": 322,
723
+ "Engine": 343,
724
+ "Engine knocking": 350,
725
+ "Engine starting": 351,
726
+ "Environmental noise": 514,
727
+ "Eruption": 435,
728
+ "Exciting music": 280,
729
+ "Explosion": 426,
730
+ "Fart": 60,
731
+ "Female singing": 33,
732
+ "Female speech, woman speaking": 2,
733
+ "Field recording": 526,
734
+ "Filing (rasp)": 422,
735
+ "Fill (with liquid)": 452,
736
+ "Finger snapping": 62,
737
+ "Fire": 298,
738
+ "Fire alarm": 400,
739
+ "Fire engine, fire truck (siren)": 325,
740
+ "Firecracker": 433,
741
+ "Fireworks": 432,
742
+ "Fixed-wing aircraft, airplane": 340,
743
+ "Flamenco": 250,
744
+ "Flap": 473,
745
+ "Flute": 196,
746
+ "Fly, housefly": 129,
747
+ "Foghorn": 401,
748
+ "Folk music": 233,
749
+ "Fowl": 98,
750
+ "French horn": 186,
751
+ "Frog": 132,
752
+ "Frying (food)": 367,
753
+ "Funk": 232,
754
+ "Funny music": 277,
755
+ "Fusillade": 429,
756
+ "Gargling": 56,
757
+ "Gasp": 44,
758
+ "Gears": 409,
759
+ "Giggle": 18,
760
+ "Glass": 441,
761
+ "Glockenspiel": 181,
762
+ "Goat": 95,
763
+ "Gobble": 103,
764
+ "Gong": 177,
765
+ "Goose": 106,
766
+ "Gospel music": 259,
767
+ "Groan": 38,
768
+ "Growling": 79,
769
+ "Grunge": 222,
770
+ "Grunt": 39,
771
+ "Guitar": 140,
772
+ "Gunshot, gunfire": 427,
773
+ "Gurgling": 297,
774
+ "Gush": 451,
775
+ "Hair dryer": 373,
776
+ "Hammer": 419,
777
+ "Hammond organ": 157,
778
+ "Hands": 61,
779
+ "Happy music": 276,
780
+ "Harmonic": 502,
781
+ "Harmonica": 208,
782
+ "Harp": 199,
783
+ "Harpsichord": 160,
784
+ "Heart murmur": 65,
785
+ "Heart sounds, heartbeat": 64,
786
+ "Heavy engine (low frequency)": 349,
787
+ "Heavy metal": 220,
788
+ "Helicopter": 339,
789
+ "Hi-hat": 172,
790
+ "Hiccup": 59,
791
+ "Hip hop music": 217,
792
+ "Hiss": 84,
793
+ "Honk": 107,
794
+ "Hoot": 120,
795
+ "Horse": 87,
796
+ "House music": 240,
797
+ "Howl": 77,
798
+ "Hubbub, speech noise, speech babble": 70,
799
+ "Hum": 496,
800
+ "Humming": 37,
801
+ "Ice cream truck, ice cream van": 320,
802
+ "Idling": 352,
803
+ "Independent music": 265,
804
+ "Insect": 126,
805
+ "Inside, large room or hall": 507,
806
+ "Inside, public space": 508,
807
+ "Inside, small room": 506,
808
+ "Jackhammer": 420,
809
+ "Jazz": 235,
810
+ "Jet engine": 337,
811
+ "Jingle (music)": 269,
812
+ "Jingle bell": 202,
813
+ "Jingle, tinkle": 495,
814
+ "Keyboard (musical)": 152,
815
+ "Keys jangling": 379,
816
+ "Knock": 359,
817
+ "Laughter": 16,
818
+ "Lawn mower": 346,
819
+ "Light engine (high frequency)": 344,
820
+ "Liquid": 444,
821
+ "Livestock, farm animals, working animals": 86,
822
+ "Lullaby": 271,
823
+ "Machine gun": 428,
824
+ "Mains hum": 516,
825
+ "Male singing": 32,
826
+ "Male speech, man speaking": 1,
827
+ "Mallet percussion": 179,
828
+ "Mandolin": 149,
829
+ "Mantra": 31,
830
+ "Maraca": 176,
831
+ "Marimba, xylophone": 180,
832
+ "Mechanical fan": 412,
833
+ "Mechanisms": 404,
834
+ "Medium engine (mid frequency)": 348,
835
+ "Meow": 83,
836
+ "Microwave oven": 368,
837
+ "Middle Eastern music": 234,
838
+ "Moo": 91,
839
+ "Mosquito": 128,
840
+ "Motor vehicle (road)": 306,
841
+ "Motorboat, speedboat": 304,
842
+ "Motorcycle": 326,
843
+ "Mouse": 124,
844
+ "Music": 137,
845
+ "Music for children": 252,
846
+ "Music of Africa": 256,
847
+ "Music of Asia": 260,
848
+ "Music of Bollywood": 262,
849
+ "Music of Latin America": 248,
850
+ "Musical instrument": 138,
851
+ "Narration, monologue": 5,
852
+ "Neigh, whinny": 89,
853
+ "New-age music": 253,
854
+ "Noise": 513,
855
+ "Ocean": 294,
856
+ "Oink": 94,
857
+ "Opera": 238,
858
+ "Orchestra": 184,
859
+ "Organ": 155,
860
+ "Outside, rural or natural": 510,
861
+ "Outside, urban or manmade": 509,
862
+ "Owl": 119,
863
+ "Pant": 45,
864
+ "Patter": 125,
865
+ "Percussion": 161,
866
+ "Piano": 153,
867
+ "Pig": 93,
868
+ "Pigeon, dove": 115,
869
+ "Ping": 482,
870
+ "Pink noise": 521,
871
+ "Pizzicato": 192,
872
+ "Plop": 494,
873
+ "Plucked string instrument": 139,
874
+ "Police car (siren)": 323,
875
+ "Pop music": 216,
876
+ "Pour": 449,
877
+ "Power tool": 424,
878
+ "Power windows, electric windows": 311,
879
+ "Printer": 415,
880
+ "Progressive rock": 223,
881
+ "Propeller, airscrew": 338,
882
+ "Psychedelic rock": 225,
883
+ "Pulleys": 410,
884
+ "Pulse": 505,
885
+ "Pump (liquid)": 454,
886
+ "Punk rock": 221,
887
+ "Purr": 82,
888
+ "Quack": 105,
889
+ "Race car, auto racing": 315,
890
+ "Radio": 525,
891
+ "Rail transport": 328,
892
+ "Railroad car, train wagon": 332,
893
+ "Rain": 289,
894
+ "Rain on surface": 291,
895
+ "Raindrop": 290,
896
+ "Rapping": 36,
897
+ "Ratchet, pawl": 405,
898
+ "Rattle": 135,
899
+ "Rattle (instrument)": 175,
900
+ "Reggae": 228,
901
+ "Reverberation": 511,
902
+ "Reversing beeps": 319,
903
+ "Rhythm and blues": 226,
904
+ "Rimshot": 166,
905
+ "Ringtone": 391,
906
+ "Roar": 110,
907
+ "Roaring cats (lions, tigers)": 109,
908
+ "Rock and roll": 224,
909
+ "Rock music": 219,
910
+ "Rodents, rats, mice": 123,
911
+ "Roll": 477,
912
+ "Rowboat, canoe, kayak": 303,
913
+ "Rub": 476,
914
+ "Rumble": 493,
915
+ "Run": 51,
916
+ "Rustle": 487,
917
+ "Rustling leaves": 284,
918
+ "Sad music": 278,
919
+ "Sailboat, sailing ship": 302,
920
+ "Salsa music": 249,
921
+ "Sampler": 159,
922
+ "Sanding": 423,
923
+ "Sawing": 421,
924
+ "Saxophone": 197,
925
+ "Scary music": 282,
926
+ "Scissors": 381,
927
+ "Scrape": 475,
928
+ "Scratch": 474,
929
+ "Scratching (performance technique)": 215,
930
+ "Screaming": 14,
931
+ "Sewing machine": 411,
932
+ "Shatter": 443,
933
+ "Sheep": 97,
934
+ "Ship": 305,
935
+ "Shofar": 212,
936
+ "Shout": 8,
937
+ "Shuffle": 52,
938
+ "Shuffling cards": 383,
939
+ "Sidetone": 518,
940
+ "Sigh": 26,
941
+ "Silence": 500,
942
+ "Sine wave": 501,
943
+ "Singing": 27,
944
+ "Singing bowl": 214,
945
+ "Single-lens reflex camera": 417,
946
+ "Sink (filling or washing)": 371,
947
+ "Siren": 396,
948
+ "Sitar": 148,
949
+ "Sizzle": 490,
950
+ "Ska": 263,
951
+ "Skateboard": 342,
952
+ "Skidding": 312,
953
+ "Slam": 358,
954
+ "Slap, smack": 467,
955
+ "Sliding door": 357,
956
+ "Slosh": 446,
957
+ "Smash, crash": 469,
958
+ "Smoke detector, smoke alarm": 399,
959
+ "Snake": 134,
960
+ "Snare drum": 165,
961
+ "Sneeze": 49,
962
+ "Snicker": 19,
963
+ "Sniff": 50,
964
+ "Snoring": 43,
965
+ "Snort": 46,
966
+ "Sonar": 457,
967
+ "Song": 266,
968
+ "Soul music": 227,
969
+ "Sound effect": 504,
970
+ "Soundtrack music": 270,
971
+ "Speech": 0,
972
+ "Speech synthesizer": 7,
973
+ "Splash, splatter": 445,
974
+ "Splinter": 439,
975
+ "Spray": 453,
976
+ "Squawk": 114,
977
+ "Squeak": 361,
978
+ "Squeal": 485,
979
+ "Squish": 447,
980
+ "Static": 515,
981
+ "Steam": 296,
982
+ "Steam whistle": 403,
983
+ "Steel guitar, slide guitar": 144,
984
+ "Steelpan": 183,
985
+ "Stir": 455,
986
+ "Stomach rumble": 57,
987
+ "Stream": 292,
988
+ "String section": 190,
989
+ "Strum": 146,
990
+ "Subway, metro, underground": 334,
991
+ "Swing music": 230,
992
+ "Synthesizer": 158,
993
+ "Synthetic singing": 35,
994
+ "Tabla": 170,
995
+ "Tambourine": 174,
996
+ "Tap": 360,
997
+ "Tapping (guitar technique)": 145,
998
+ "Tearing": 480,
999
+ "Techno": 241,
1000
+ "Telephone": 389,
1001
+ "Telephone bell ringing": 390,
1002
+ "Telephone dialing, DTMF": 392,
1003
+ "Television": 524,
1004
+ "Tender music": 279,
1005
+ "Theme music": 268,
1006
+ "Theremin": 213,
1007
+ "Throat clearing": 48,
1008
+ "Throbbing": 522,
1009
+ "Thump, thud": 460,
1010
+ "Thunder": 287,
1011
+ "Thunderstorm": 286,
1012
+ "Thunk": 461,
1013
+ "Tick": 407,
1014
+ "Tick-tock": 408,
1015
+ "Timpani": 169,
1016
+ "Tire squeal": 313,
1017
+ "Toilet flush": 374,
1018
+ "Tools": 418,
1019
+ "Toot": 309,
1020
+ "Toothbrush": 375,
1021
+ "Traditional music": 264,
1022
+ "Traffic noise, roadway noise": 327,
1023
+ "Train": 329,
1024
+ "Train horn": 331,
1025
+ "Train wheels squealing": 333,
1026
+ "Train whistle": 330,
1027
+ "Trance music": 247,
1028
+ "Trickle, dribble": 450,
1029
+ "Trombone": 188,
1030
+ "Truck": 316,
1031
+ "Trumpet": 187,
1032
+ "Tubular bells": 178,
1033
+ "Tuning fork": 204,
1034
+ "Turkey": 102,
1035
+ "Typewriter": 385,
1036
+ "Typing": 384,
1037
+ "Ukulele": 151,
1038
+ "Vacuum cleaner": 377,
1039
+ "Vehicle": 300,
1040
+ "Vehicle horn, car horn, honking": 308,
1041
+ "Vibraphone": 182,
1042
+ "Vibration": 523,
1043
+ "Video game music": 272,
1044
+ "Violin, fiddle": 191,
1045
+ "Vocal music": 254,
1046
+ "Wail, moan": 25,
1047
+ "Walk, footsteps": 53,
1048
+ "Water": 288,
1049
+ "Water tap, faucet": 370,
1050
+ "Waterfall": 293,
1051
+ "Waves, surf": 295,
1052
+ "Wedding music": 275,
1053
+ "Whack, thwack": 468,
1054
+ "Whale vocalization": 136,
1055
+ "Wheeze": 42,
1056
+ "Whimper": 24,
1057
+ "Whimper (dog)": 80,
1058
+ "Whip": 472,
1059
+ "Whir": 488,
1060
+ "Whispering": 15,
1061
+ "Whistle": 402,
1062
+ "Whistling": 40,
1063
+ "White noise": 520,
1064
+ "Whoop": 10,
1065
+ "Whoosh, swoosh, swish": 459,
1066
+ "Wild animals": 108,
1067
+ "Wind": 283,
1068
+ "Wind chime": 206,
1069
+ "Wind instrument, woodwind instrument": 195,
1070
+ "Wind noise (microphone)": 285,
1071
+ "Wood": 437,
1072
+ "Wood block": 173,
1073
+ "Writing": 387,
1074
+ "Yell": 11,
1075
+ "Yip": 76,
1076
+ "Yodeling": 29,
1077
+ "Zing": 497,
1078
+ "Zipper (clothing)": 378,
1079
+ "Zither": 150
1080
+ },
1081
+ "max_length": 1024,
1082
+ "model_type": "dass",
1083
+ "num_classes": 527,
1084
+ "num_mel_bins": 128,
1085
+ "patch_size": 4,
1086
+ "torch_dtype": "float32",
1087
+ "transformers_version": "4.50.0.dev0",
1088
+ "use_checkpoint": false
1089
+ }
configuration_dass.py ADDED
@@ -0,0 +1,91 @@
1
+ # coding=utf-8
2
+ """Distilled Audio State-Space Model (DASS) configuration"""
3
+
4
+ from typing import Any, Dict
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.utils import logging
8
+
9
+
10
+ logger = logging.get_logger(__name__)
11
+
12
+ class DASSConfig(PretrainedConfig):
13
+ r"""
14
+ This is the configuration class to store the configuration of a [`DASSModel`]. It is used to instantiate a DASS
15
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
16
+ defaults will yield a similar configuration to that of the
17
+ [DASS-small](https://github.com/Saurabhbhati/DASS/) architecture.
18
+
19
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
20
+ documentation from [`PretrainedConfig`] for more information.
21
+
22
+ Args:
23
+ patch_size (`int`, *optional*, defaults to 4):
24
+ The size (resolution) of each patch.
25
+ embed_dim (`int`, *optional*, defaults to 96):
26
+ Dimensionality of patch embedding.
27
+ depths (`list(int)`, *optional*, defaults to `[2, 2, 8, 2]`):
28
+ Depth of each layer in the DASS encoder.
29
+ dims (`list(int)`, *optional*, defaults to `[96, 192, 384, 768]`):
30
+ Dimensionality of each layer in the DASS encoder.
31
+ drop_path_rate (`float`, *optional*, defaults to 0.2):
32
+ Stochastic depth rate.
33
+ num_classes (`int`, *optional*, defaults to 527):
34
+ Number of classes for classification.
35
+ max_length (`int`, *optional*, defaults to 1024):
36
+ Temporal dimension of the spectrograms.
37
+ num_mel_bins (`int`, *optional*, defaults to 128):
38
+ Frequency dimension of the spectrograms (number of Mel-frequency bins).
39
+ use_checkpoint (`bool`, *optional*, defaults to `False`):
40
+ Whether to use checkpointing to save memory.
41
+
42
+ Example:
43
+
44
+ ```python
45
+ >>> from transformers import DASSConfig, DASSModel
46
+
47
+ >>> # Initializing a DASS small style configuration
48
+ >>> configuration = DASSConfig()
49
+
50
+ >>> # Initializing a model (with random weights) from the DASS small style configuration
51
+ >>> model = DASSModel(configuration)
52
+
53
+ >>> # Accessing the model configuration
54
+ >>> configuration = model.config
55
+ ```"""
56
+
57
+ model_type = "dass"
58
+
59
+ def __init__(
60
+ self,
61
+ patch_size: int = 4,
62
+ embed_dim: int = 96,
63
+ depths: list = [2, 2, 8, 2],
64
+ dims: list = [96, 192, 384, 768],
65
+ drop_path_rate: float = 0.2,
66
+ num_classes: int = 527,
67
+ max_length: int = 1024,
68
+ num_mel_bins: int = 128,
69
+ use_checkpoint: bool = False,
70
+ **kwargs,
71
+ ):
72
+ super().__init__(**kwargs)
73
+
74
+ self.patch_size = patch_size
75
+ self.embed_dim = embed_dim
76
+ self.depths = depths
77
+ self.dims = dims
78
+ self.drop_path_rate = drop_path_rate
79
+ self.num_classes = num_classes
80
+ self.max_length = max_length
81
+ self.num_mel_bins = num_mel_bins
82
+ self.use_checkpoint = use_checkpoint
83
+
84
+ # Overwritten from the parent class: DASS is not compatible with `generate`, but has a config parameter sharing the
85
+ # same name (`max_length`). Sharing the same name triggers checks regarding the config -> generation_config
86
+ # generative parameters deprecation cycle, overwriting this function prevents this from happening.
87
+ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
88
+ return {}
89
+
90
+
91
+ __all__ = ["DASSConfig"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d77e5315cc7d3df349ed18f7c7715c557e70080c88ca0fb5ed109b50c496f7d
3
+ size 119566972
modeling_dass.py ADDED
@@ -0,0 +1,1196 @@
1
+ # coding=utf-8
2
+ # VMamba backbone is from https://github.com/MzeroMiko/VMamba/blob/main/vmamba.py
3
+ # DASSLayer, DASSModel, and DASSForAudioClassification are implemented based on VMamba and AST
4
+ #
5
+ """Distilled Audio State-Space Model (DASS) model"""
6
+
7
+ import math
8
+ import torch
9
+ import warnings
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint as checkpoint
13
+ from timm.models.layers import DropPath, trunc_normal_
14
+ from functools import partial
15
+ from typing import Optional, Callable, Any, Union
16
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
17
+ from transformers.modeling_outputs import SequenceClassifierOutput
18
+
19
+ from transformers.utils import logging
20
+ from transformers.modeling_utils import PreTrainedModel
21
+
22
+ from .configuration_dass import DASSConfig
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ # General docstring
27
+ _CONFIG_FOR_DOC = "DASSConfig"
28
+
29
+ WITH_TRITON = True
30
+ # WITH_TRITON = False
31
+ try:
32
+ import triton
33
+ import triton.language as tl
34
+ except ImportError:
35
+ WITH_TRITON = False
36
+ warnings.warn("Triton is not installed; falling back to the PyTorch implementation.")
37
+
38
+ # to make sure cached_property can be loaded for triton
39
+ if WITH_TRITON:
40
+ try:
41
+ from functools import cached_property
42
+ except ImportError:
43
+ warnings.warn("If you are using Python 3.7, add this line to functools.py: "
44
+ "cached_property = lambda func: property(lru_cache()(func))")
45
+
46
+ # torch implementation ========================================
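+ # cross_scan_fwd flattens a 2-D feature map into four 1-D scan sequences. With scans=0
+ # (cross-scan) the routes are row-major, column-major (transpose), and the reverses of
+ # both; scans=1 repeats the row-major route four times; scans=2 uses the row-major route
+ # and its reverse, each duplicated; scans=3 uses the four 90-degree rotations of the map.
+ # cross_merge_fwd below is the matching inverse that sums the four routes back into one map.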
47
+ def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
48
+ if in_channel_first:
49
+ B, C, H, W = x.shape
50
+ if scans == 0:
51
+ y = x.new_empty((B, 4, C, H * W))
52
+ y[:, 0, :, :] = x.flatten(2, 3)
53
+ y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
54
+ y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
55
+ elif scans == 1:
56
+ y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
57
+ elif scans == 2:
58
+ y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
59
+ y = torch.cat([y, y.flip(dims=[-1])], dim=1)
60
+ elif scans == 3:
61
+ y = x.new_empty((B, 4, C, H * W))
62
+ y[:, 0, :, :] = x.flatten(2, 3)
63
+ y[:, 1, :, :] = torch.rot90(x, 1, dims=(2, 3)).flatten(2, 3)
64
+ y[:, 2, :, :] = torch.rot90(x, 2, dims=(2, 3)).flatten(2, 3)
65
+ y[:, 3, :, :] = torch.rot90(x, 3, dims=(2, 3)).flatten(2, 3)
66
+ else:
67
+ B, H, W, C = x.shape
68
+ if scans == 0:
69
+ y = x.new_empty((B, H * W, 4, C))
70
+ y[:, :, 0, :] = x.flatten(1, 2)
71
+ y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
72
+ y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
73
+ elif scans == 1:
74
+ y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
75
+ elif scans == 2:
76
+ y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
77
+ y = torch.cat([y, y.flip(dims=[1])], dim=2)
78
+ elif scans == 3:
79
+ y = x.new_empty((B, H * W, 4, C))
80
+ y[:, :, 0, :] = x.flatten(1, 2)
81
+ y[:, :, 1, :] = torch.rot90(x, 1, dims=(1, 2)).flatten(1, 2)
82
+ y[:, :, 2, :] = torch.rot90(x, 2, dims=(1, 2)).flatten(1, 2)
83
+ y[:, :, 3, :] = torch.rot90(x, 3, dims=(1, 2)).flatten(1, 2)
84
+
85
+ if in_channel_first and (not out_channel_first):
86
+ y = y.permute(0, 3, 1, 2).contiguous()
87
+ elif (not in_channel_first) and out_channel_first:
88
+ y = y.permute(0, 2, 3, 1).contiguous()
89
+
90
+ return y
91
+
92
+
93
+ def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
94
+ if out_channel_first:
95
+ B, K, D, H, W = y.shape
96
+ y = y.view(B, K, D, -1)
97
+ if scans == 0:
98
+ y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
99
+ y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
100
+ elif scans == 1:
101
+ y = y.sum(1)
102
+ elif scans == 2:
103
+ y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
104
+ y = y.sum(1)
105
+ elif scans == 3:
106
+ oy = y[:, 0, :, :].contiguous().view(B, D, -1)
107
+ oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3)
108
+ oy = oy + torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3)
109
+ oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3)
110
+ y = oy
111
+ else:
112
+ B, H, W, K, D = y.shape
113
+ y = y.view(B, -1, K, D)
114
+ if scans == 0:
115
+ y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
116
+ y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
117
+ elif scans == 1:
118
+ y = y.sum(2)
119
+ elif scans == 2:
120
+ y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
121
+ y = y.sum(2)
122
+ elif scans == 3:
123
+ oy = y[:, :, 0, :].contiguous().view(B, -1, D)
124
+ oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2)
125
+ oy = oy + torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2)
126
+ oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2)
127
+ y = oy
128
+
129
+ if in_channel_first and (not out_channel_first):
130
+ y = y.permute(0, 2, 1).contiguous()
131
+ elif (not in_channel_first) and out_channel_first:
132
+ y = y.permute(0, 2, 1).contiguous()
133
+
134
+ return y
135
+
136
+
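+ # "1b1" (one-by-one) variants: the input already carries one feature map per scan direction,
+ # shaped (B, 4, C, H, W) or (B, H, W, 4, C), so each branch is scanned/merged from its own
+ # map instead of broadcasting a single map to all four directions.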
137
+ def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
138
+ if in_channel_first:
139
+ B, _, C, H, W = x.shape
140
+ if scans == 0:
141
+ y = torch.stack([
142
+ x[:, 0].flatten(2, 3),
143
+ x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
144
+ torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
145
+ torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
146
+ ], dim=1)
147
+ elif scans == 1:
148
+ y = x.flatten(2, 3)
149
+ elif scans == 2:
150
+ y = torch.stack([
151
+ x[:, 0].flatten(2, 3),
152
+ x[:, 1].flatten(2, 3),
153
+ torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
154
+ torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
155
+ ], dim=1)
156
+ elif scans == 3:
157
+ y = torch.stack([
158
+ x[:, 0, :, :, :].flatten(2, 3),
159
+ torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3),
160
+ torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3),
161
+ torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3),
162
+ ], dim=1)
163
+
164
+ else:
165
+ B, H, W, _, C = x.shape
166
+ if scans == 0:
167
+ y = torch.stack([
168
+ x[:, :, :, 0].flatten(1, 2),
169
+ x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
170
+ torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
171
+ torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
172
+ ], dim=2)
173
+ elif scans == 1:
174
+ y = x.flatten(1, 2)
175
+ elif scans == 2:
176
+ y = torch.stack([
177
+ x[:, 0].flatten(1, 2),
178
+ x[:, 1].flatten(1, 2),
179
+ torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
180
+ torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
181
+ ], dim=2)
182
+ elif scans == 3:
183
+ y = torch.stack([
184
+ x[:, :, :, 0, :].flatten(1, 2),
185
+ torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2),
186
+ torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2),
187
+ torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2),
188
+ ], dim=1)
189
+
190
+ if in_channel_first and (not out_channel_first):
191
+ y = y.permute(0, 3, 1, 2).contiguous()
192
+ elif (not in_channel_first) and out_channel_first:
193
+ y = y.permute(0, 2, 3, 1).contiguous()
194
+
195
+ return y
196
+
197
+
198
+ def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
199
+ if out_channel_first:
200
+ B, K, D, H, W = y.shape
201
+ y = y.view(B, K, D, -1)
202
+ if scans == 0:
203
+ y = torch.stack([
204
+ y[:, 0],
205
+ y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
206
+ torch.flip(y[:, 2], dims=[-1]),
207
+ torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
208
+ ], dim=1)
209
+ elif scans == 1:
210
+ y = y
211
+ elif scans == 2:
212
+ y = torch.stack([
213
+ y[:, 0],
214
+ y[:, 1],
215
+ torch.flip(y[:, 2], dims=[-1]),
216
+ torch.flip(y[:, 3], dims=[-1]),
217
+ ], dim=1)
218
+ elif scans == 3:
219
+ y = torch.stack([
220
+ y[:, 0, :, :].contiguous().view(B, D, -1),
221
+ torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3),
222
+ torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3),
223
+ torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3),
224
+ ], dim=1)
225
+ else:
226
+ B, H, W, K, D = y.shape
227
+ y = y.view(B, -1, K, D)
228
+ if scans == 0:
229
+ y = torch.stack([
230
+ y[:, :, 0],
231
+ y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
232
+ torch.flip(y[:, :, 2], dims=[1]),
233
+ torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
234
+ ], dim=2)
235
+ elif scans == 1:
236
+ y = y
237
+ elif scans == 2:
238
+ y = torch.stack([
239
+ y[:, :, 0],
240
+ y[:, :, 1],
241
+ torch.flip(y[:, :, 2], dims=[1]),
242
+ torch.flip(y[:, :, 3], dims=[1]),
243
+ ], dim=2)
244
+ elif scans == 3:
245
+ y = torch.stack([
246
+ y[:, :, 0, :].contiguous().view(B, -1, D),
247
+ torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2),
248
+ torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2),
249
+ torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2),
250
+ ], dim=2)
251
+
252
+ if out_channel_first and (not in_channel_first):
253
+ y = y.permute(0, 3, 1, 2).contiguous()
254
+ elif (not out_channel_first) and in_channel_first:
255
+ y = y.permute(0, 2, 3, 1).contiguous()
256
+
257
+ return y
258
+
259
+
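+ # CrossScanF and CrossMergeF wrap the scan/merge routines as autograd Functions: the
+ # backward pass of a cross-scan is the corresponding cross-merge (and vice versa), so
+ # gradients from the four scan directions are accumulated back onto the feature map.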
260
+ class CrossScanF(torch.autograd.Function):
261
+ @staticmethod
262
+ def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
263
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
264
+ # y: (B, 4, C, H * W) | (B, H * W, 4, C)
265
+ ctx.in_channel_first = in_channel_first
266
+ ctx.out_channel_first = out_channel_first
267
+ ctx.one_by_one = one_by_one
268
+ ctx.scans = scans
269
+
270
+ if one_by_one:
271
+ B, K, C, H, W = x.shape
272
+ if not in_channel_first:
273
+ B, H, W, K, C = x.shape
274
+ else:
275
+ B, C, H, W = x.shape
276
+ if not in_channel_first:
277
+ B, H, W, C = x.shape
278
+ ctx.shape = (B, C, H, W)
279
+
280
+ _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
281
+ y = _fn(x, in_channel_first, out_channel_first, scans)
282
+
283
+ return y
284
+
285
+ @staticmethod
286
+ def backward(ctx, ys: torch.Tensor):
287
+ # out: (b, k, d, l)
288
+ in_channel_first = ctx.in_channel_first
289
+ out_channel_first = ctx.out_channel_first
290
+ one_by_one = ctx.one_by_one
291
+ scans = ctx.scans
292
+ B, C, H, W = ctx.shape
293
+
294
+ ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
295
+ _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
296
+ y = _fn(ys, in_channel_first, out_channel_first, scans)
297
+
298
+ if one_by_one:
299
+ y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1)
300
+ else:
301
+ y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)
302
+
303
+ return y, None, None, None, None
304
+
305
+
306
+ class CrossMergeF(torch.autograd.Function):
307
+ @staticmethod
308
+ def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
309
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
310
+ # y: (B, 4, C, H * W) | (B, H * W, 4, C)
311
+ ctx.in_channel_first = in_channel_first
312
+ ctx.out_channel_first = out_channel_first
313
+ ctx.one_by_one = one_by_one
314
+ ctx.scans = scans
315
+
316
+ B, K, C, H, W = ys.shape
317
+ if not out_channel_first:
318
+ B, H, W, K, C = ys.shape
319
+ ctx.shape = (B, C, H, W)
320
+
321
+ _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
322
+ y = _fn(ys, in_channel_first, out_channel_first, scans)
323
+
324
+ return y
325
+
326
+ @staticmethod
327
+ def backward(ctx, x: torch.Tensor):
328
+ # B, D, L = x.shape
329
+ # out: (b, k, d, h, w)
330
+ in_channel_first = ctx.in_channel_first
331
+ out_channel_first = ctx.out_channel_first
332
+ one_by_one = ctx.one_by_one
333
+ scans = ctx.scans
334
+ B, C, H, W = ctx.shape
335
+
336
+ if not one_by_one:
337
+ if in_channel_first:
338
+ x = x.view(B, C, H, W)
339
+ else:
340
+ x = x.view(B, H, W, C)
341
+ else:
342
+ if in_channel_first:
343
+ x = x.view(B, 4, C, H, W)
344
+ else:
345
+ x = x.view(B, H, W, 4, C)
346
+
347
+ _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
348
+ x = _fn(x, in_channel_first, out_channel_first, scans)
349
+ x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C)
350
+
351
+ return x, None, None, None, None
352
+
353
+
354
+ # triton implements ========================================
355
+
356
+ @triton.jit
357
+ def triton_cross_scan_flex(
358
+ x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
359
+ y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C)
360
+ x_layout: tl.constexpr,
361
+ y_layout: tl.constexpr,
362
+ operation: tl.constexpr,
363
+ onebyone: tl.constexpr,
364
+ scans: tl.constexpr,
365
+ BC: tl.constexpr,
366
+ BH: tl.constexpr,
367
+ BW: tl.constexpr,
368
+ DC: tl.constexpr,
369
+ DH: tl.constexpr,
370
+ DW: tl.constexpr,
371
+ NH: tl.constexpr,
372
+ NW: tl.constexpr,
373
+ ):
374
+ # x_layout = 0
375
+ # y_layout = 1 # 0 BCHW, 1 BHWC
376
+ # operation = 0 # 0 scan, 1 merge
377
+ # onebyone = 0 # 0 false, 1 true
378
+ # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional, 3 rotate (rot90/180/270)
379
+
380
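+ # Grid layout: program id axis 0 indexes a (BH x BW) spatial tile, axis 1 a block of BC
+ # channels, axis 2 the batch element. The HWRoute* offsets below give, for each of the
+ # four scan directions, the flattened position that each (h, w) element of this tile
+ # maps to.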
+ i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
381
+ i_h, i_w = (i_hw // NW), (i_hw % NW)
382
+ _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
383
+ _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
384
+ _mask_hw = _mask_h[:, None] & _mask_w[None, :]
385
+ _for_C = min(DC - i_c * BC, BC)
386
+
387
+ pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
388
+ pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
389
+ neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
390
+ neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
391
+ if scans == 0:
392
+ # none; trans; flip; trans + flip;
393
+ HWRoute0 = pos_h * DW + pos_w
394
+ HWRoute1 = pos_w * DH + pos_h # trans
395
+ HWRoute2 = neg_h * DW + neg_w # flip
396
+ HWRoute3 = neg_w * DH + neg_h # trans + flip
397
+ elif scans == 1:
398
+ # none; none; none; none;
399
+ HWRoute0 = pos_h * DW + pos_w
400
+ HWRoute1 = HWRoute0
401
+ HWRoute2 = HWRoute0
402
+ HWRoute3 = HWRoute0
403
+ elif scans == 2:
404
+ # none; none; flip; flip;
405
+ HWRoute0 = pos_h * DW + pos_w
406
+ HWRoute1 = HWRoute0
407
+ HWRoute2 = neg_h * DW + neg_w # flip
408
+ HWRoute3 = HWRoute2
409
+ elif scans == 3:
410
+ # none; rot90; rot180==flip; rot270;
411
+ HWRoute0 = pos_h * DW + pos_w
412
+ HWRoute1 = neg_w * DH + pos_h
413
+ HWRoute2 = neg_h * DW + neg_w
414
+ HWRoute3 = pos_w * DH + neg_h
415
+
416
+ _tmp1 = DC * DH * DW
417
+
418
+ y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
419
+ if y_layout == 0:
420
+ p_y1 = y_ptr_base + HWRoute0
421
+ p_y2 = y_ptr_base + _tmp1 + HWRoute1
422
+ p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
423
+ p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
424
+ else:
425
+ p_y1 = y_ptr_base + HWRoute0 * 4 * DC
426
+ p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
427
+ p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
428
+ p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC
429
+
430
+ if onebyone == 0:
431
+ x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
432
+ if x_layout == 0:
433
+ p_x = x_ptr_base + HWRoute0
434
+ else:
435
+ p_x = x_ptr_base + HWRoute0 * DC
436
+
437
+ if operation == 0:
438
+ for idxc in range(_for_C):
439
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
440
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
441
+ _x = tl.load(p_x + _idx_x, mask=_mask_hw)
442
+ tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
443
+ tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
444
+ tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
445
+ tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
446
+ elif operation == 1:
447
+ for idxc in range(_for_C):
448
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
449
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
450
+ _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
451
+ _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
452
+ _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
453
+ _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
454
+ tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
455
+
456
+ else:
457
+ x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
458
+ if x_layout == 0:
459
+ p_x1 = x_ptr_base + HWRoute0
460
+ p_x2 = p_x1 + _tmp1
461
+ p_x3 = p_x2 + _tmp1
462
+ p_x4 = p_x3 + _tmp1
463
+ else:
464
+ p_x1 = x_ptr_base + HWRoute0 * 4 * DC
465
+ p_x2 = p_x1 + DC
466
+ p_x3 = p_x2 + DC
467
+ p_x4 = p_x3 + DC
468
+
469
+ if operation == 0:
470
+ for idxc in range(_for_C):
471
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
472
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
473
+ tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
474
+ tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
475
+ tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
476
+ tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
477
+ else:
478
+ for idxc in range(_for_C):
479
+ _idx_x = idxc * DH * DW if x_layout == 0 else idxc
480
+ _idx_y = idxc * DH * DW if y_layout == 0 else idxc
481
+ tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
482
+ tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
483
+ tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
484
+ tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)
485
+
486
+
487
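+ # CrossScanTritonF: Triton-backed cross-scan. Forward launches triton_cross_scan_flex
+ # with operation=0 (scatter x into the four scan orders of y); backward reuses the same
+ # kernel with operation=1, which sums the four directional gradients back into x.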
+ class CrossScanTritonF(torch.autograd.Function):
488
+ @staticmethod
489
+ def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
490
+ if one_by_one:
491
+ if in_channel_first:
492
+ B, _, C, H, W = x.shape
493
+ else:
494
+ B, H, W, _, C = x.shape
495
+ else:
496
+ if in_channel_first:
497
+ B, C, H, W = x.shape
498
+ else:
499
+ B, H, W, C = x.shape
500
+ B, C, H, W = int(B), int(C), int(H), int(W)
501
+ BC, BH, BW = 1, 32, 32
502
+ NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
503
+
504
+ ctx.in_channel_first = in_channel_first
505
+ ctx.out_channel_first = out_channel_first
506
+ ctx.one_by_one = one_by_one
507
+ ctx.scans = scans
508
+ ctx.shape = (B, C, H, W)
509
+ ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
510
+
511
+ y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
512
+ triton_cross_scan_flex[(NH * NW, NC, B)](
513
+ x.contiguous(), y,
514
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
515
+ BC, BH, BW, C, H, W, NH, NW
516
+ )
517
+ return y
518
+
519
+ @staticmethod
520
+ def backward(ctx, y: torch.Tensor):
521
+ in_channel_first = ctx.in_channel_first
522
+ out_channel_first = ctx.out_channel_first
523
+ one_by_one = ctx.one_by_one
524
+ scans = ctx.scans
525
+ B, C, H, W = ctx.shape
526
+ BC, BH, BW, NC, NH, NW = ctx.triton_shape
527
+ if one_by_one:
528
+ x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
529
+ else:
530
+ x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))
531
+
532
+ triton_cross_scan_flex[(NH * NW, NC, B)](
533
+ x, y.contiguous(),
534
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
535
+ BC, BH, BW, C, H, W, NH, NW
536
+ )
537
+ return x, None, None, None, None
538
+
539
+
540
+ class CrossMergeTritonF(torch.autograd.Function):
541
+ @staticmethod
542
+ def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
543
+ if out_channel_first:
544
+ B, _, C, H, W = y.shape
545
+ else:
546
+ B, H, W, _, C = y.shape
547
+ B, C, H, W = int(B), int(C), int(H), int(W)
548
+ BC, BH, BW = 1, 32, 32
549
+ NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
550
+ ctx.in_channel_first = in_channel_first
551
+ ctx.out_channel_first = out_channel_first
552
+ ctx.one_by_one = one_by_one
553
+ ctx.scans = scans
554
+ ctx.shape = (B, C, H, W)
555
+ ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
556
+ if one_by_one:
557
+ x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
558
+ else:
559
+ x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
560
+ triton_cross_scan_flex[(NH * NW, NC, B)](
561
+ x, y.contiguous(),
562
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
563
+ BC, BH, BW, C, H, W, NH, NW
564
+ )
565
+ return x
566
+
567
+ @staticmethod
568
+ def backward(ctx, x: torch.Tensor):
569
+ in_channel_first = ctx.in_channel_first
570
+ out_channel_first = ctx.out_channel_first
571
+ one_by_one = ctx.one_by_one
572
+ scans = ctx.scans
573
+ B, C, H, W = ctx.shape
574
+ BC, BH, BW, NC, NH, NW = ctx.triton_shape
575
+ y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
576
+ triton_cross_scan_flex[(NH * NW, NC, B)](
577
+ x.contiguous(), y,
578
+ (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
579
+ BC, BH, BW, C, H, W, NH, NW
580
+ )
581
+ return y, None, None, None, None, None
582
+
583
+
584
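+ # cross_scan_fn / cross_merge_fn dispatch to the Triton kernels when Triton is available
+ # and the tensor lives on CUDA (unless force_torch=True); otherwise they fall back to the
+ # autograd Functions above. A minimal shape sketch, assuming channel-first in/out layouts
+ # and the default 4-direction scan (right-hand comments are resulting shapes, not APIs):
+ #   xs = cross_scan_fn(x)            # x: (B, C, H, W) -> xs: (B, 4, C, H*W)
+ #   ys = xs.view(B, 4, C, H, W)      # e.g. after a selective scan over L = H*W
+ #   y  = cross_merge_fn(ys)          # -> (B, C, H*W)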
+ # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
585
+ def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
586
+ # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
587
+ # y: (B, 4, C, L) | (B, L, 4, C)
588
+ # scans: 0 cross scan; 1 unidirectional; 2 bidirectional; 3 rotate (rot90/180/270)
589
+ CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
590
+ if x.is_cuda:
591
+ with torch.cuda.device(x.device):
592
+ return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
593
+ else:
594
+ return CrossScanF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
595
+
596
+
597
+ # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
598
+ def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
599
+ # y: (B, 4, C, L) | (B, L, 4, C)
600
+ # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
601
+ # scans: 0 cross scan; 1 unidirectional; 2 bidirectional; 3 rotate (rot90/180/270)
602
+ CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
603
+ if y.is_cuda:
604
+ with torch.cuda.device(y.device):
605
+ return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
606
+ else:
607
+ return CrossMergeF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
608
+
609
+
610
+ ##########################################################
611
+ # csms6s.py
612
+ ##########################################################
613
+
614
+ WITH_SELECTIVESCAN_MAMBA = True
615
+ try:
616
+ import selective_scan_cuda
617
+ except ImportError:
618
+ WITH_SELECTIVESCAN_MAMBA = False
619
+
620
+
621
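+ # selective_scan_torch: reference (non-fused) implementation of the selective scan.
+ # After the optional delta bias and softplus, the SSM is discretized as
+ #   deltaA_t   = exp(delta_t * A)
+ #   deltaB_u_t = delta_t * B_t * u_t
+ # and evaluated with the sequential recurrence
+ #   x_t = deltaA_t * x_{t-1} + deltaB_u_t,   y_t = <C_t, x_t>,   out = y + D * u.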
+ def selective_scan_torch(
622
+ u: torch.Tensor, # (B, K * C, L)
623
+ delta: torch.Tensor, # (B, K * C, L)
624
+ A: torch.Tensor, # (K * C, N)
625
+ B: torch.Tensor, # (B, K, N, L)
626
+ C: torch.Tensor, # (B, K, N, L)
627
+ D: torch.Tensor = None, # (K * C)
628
+ delta_bias: torch.Tensor = None, # (K * C)
629
+ delta_softplus=True,
630
+ oflex=True,
631
+ *args,
632
+ **kwargs
633
+ ):
634
+ dtype_in = u.dtype
635
+ Batch, K, N, L = B.shape
636
+ KCdim = u.shape[1]
637
+ Cdim = int(KCdim / K)
638
+ assert u.shape == (Batch, KCdim, L)
639
+ assert delta.shape == (Batch, KCdim, L)
640
+ assert A.shape == (KCdim, N)
641
+ assert C.shape == B.shape
642
+
643
+ if delta_bias is not None:
644
+ delta = delta + delta_bias[..., None]
645
+ if delta_softplus:
646
+ delta = torch.nn.functional.softplus(delta)
647
+
648
+ u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
649
+ B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
650
+ C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
651
+ deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
652
+ deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
653
+
654
+ if True:
655
+ x = A.new_zeros((Batch, KCdim, N))
656
+ ys = []
657
+ for i in range(L):
658
+ x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
659
+ y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
660
+ ys.append(y)
661
+ y = torch.stack(ys, dim=2) # (B, C, L)
662
+
663
+ out = y if D is None else y + u * D.unsqueeze(-1)
664
+ return out if oflex else out.to(dtype=dtype_in)
665
+
666
+
667
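+ # SelectiveScanCuda wraps the fused CUDA kernels (the "mamba" backend from
+ # selective_scan_cuda when it is installed); forward saves the intermediate states x
+ # for the backward kernel.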
+ class SelectiveScanCuda(torch.autograd.Function):
668
+ @staticmethod
669
+ @torch.cuda.amp.custom_fwd
670
+ def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
671
+ ctx.delta_softplus = delta_softplus
672
+ # backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend
673
+ # backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend
674
+ backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend
675
+ ctx.backend = backend
676
+ if backend == "oflex":
677
+ out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
678
+ elif backend == "mamba":
679
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
680
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
681
+ return out
682
+
683
+ @staticmethod
684
+ @torch.cuda.amp.custom_bwd
685
+ def backward(ctx, dout, *args):
686
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
687
+ backend = ctx.backend
688
+ if dout.stride(-1) != 1:
689
+ dout = dout.contiguous()
690
+ if backend == "oflex":
691
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
692
+ u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
693
+ )
694
+ elif backend == "mamba":
695
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
696
+ u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
697
+ False
698
+ )
699
+ return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
700
+
701
+
702
+ def selective_scan_fn(
703
+ u: torch.Tensor, # (B, K * C, L)
704
+ delta: torch.Tensor, # (B, K * C, L)
705
+ A: torch.Tensor, # (K * C, N)
706
+ B: torch.Tensor, # (B, K, N, L)
707
+ C: torch.Tensor, # (B, K, N, L)
708
+ D: torch.Tensor = None, # (K * C)
709
+ delta_bias: torch.Tensor = None, # (K * C)
710
+ delta_softplus=True,
711
+ oflex=True,
712
+ backend=None,
713
+ ):
714
+ fn = selective_scan_torch if backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA) else SelectiveScanCuda.apply
715
+ return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
716
+
717
+ ##########################################################
718
+ ############## HuggingFace modeling file #################
719
+ ##########################################################
720
+
721
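+ # DASSLinear2d: an nn.Linear whose forward is expressed as a 1x1 conv2d/conv1d, so it can
+ # be applied directly to channel-first (B, C, H, W) or (B, C, L) tensors. The optional
+ # `groups` lets a single layer hold separate projections for the K scan directions.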
+ class DASSLinear2d(nn.Linear):
722
+ def __init__(self, *args, groups=1, **kwargs):
723
+ nn.Linear.__init__(self, *args, **kwargs)
724
+ self.groups = groups
725
+
726
+ def forward(self, x: torch.Tensor):
727
+ if len(x.shape) == 4:
728
+ return F.conv2d(x, self.weight[:, :, None, None], self.bias, groups=self.groups)
729
+ elif len(x.shape) == 3:
730
+ return F.conv1d(x, self.weight[:, :, None], self.bias, groups=self.groups)
731
+
732
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
733
+ self_state_dict = self.state_dict()
734
+ load_state_dict_keys = list(state_dict.keys())
735
+ if prefix + "weight" in load_state_dict_keys:
736
+ state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view_as(self_state_dict["weight"])
737
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
738
+
739
+
740
+ class DASSLayerNorm2d(nn.LayerNorm):
741
+ def __init__(self, *args, **kwargs):
742
+ nn.LayerNorm.__init__(self, *args, **kwargs)
743
+
744
+ def forward(self, x: torch.Tensor):
745
+ x = x.permute(0, 2, 3, 1)
746
+ x = nn.LayerNorm.forward(self, x)
747
+ x = x.permute(0, 3, 1, 2)
748
+ return x
749
+
750
+
751
+ class DASSPatchEmbeddings(nn.Module):
752
+ """
753
+ This class turns `input_values` (a `(batch_size, frames, frequency_bins)` spectrogram) into the
+ initial channel-first patch embeddings of shape `(batch_size, hidden_size, height, width)` to be
+ consumed by the state-space backbone.
755
+ """
756
+
757
+ def __init__(self, patch_size=4,embed_dim=96):
758
+ super().__init__()
759
+
760
+ stride = patch_size // 2
761
+ kernel_size = stride + 1
762
+ padding = 1
763
+
764
+ self.projection = nn.Sequential(
765
+ nn.Conv2d(1, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
766
+ DASSLayerNorm2d(embed_dim // 2),
767
+ nn.GELU(),
768
+ nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
769
+ DASSLayerNorm2d(embed_dim),
770
+ )
771
+
772
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
773
+ x = x.unsqueeze(1)
774
+ x = x.transpose(2, 3)
775
+ x = self.projection(x)
776
+ return x
777
+
778
+
779
+ class DASSDowsample(nn.Module):
780
+ """
781
+ This class downsamples the input tensor using a convolutional layer followed by a layer normalization.
782
+ """
783
+ def __init__(self, dim, out_dim, use_norm=True):
784
+ super().__init__()
785
+ self.down = nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1)
786
+ self.norm = DASSLayerNorm2d(out_dim) if use_norm else nn.Identity()
787
+
788
+ def forward(self, x):
789
+ x = self.down(x)
790
+ x = self.norm(x)
791
+ return x
792
+
793
+
794
+ class DASSMlp(nn.Module):
795
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
796
+ super().__init__()
797
+ out_features = out_features or in_features
798
+ hidden_features = hidden_features or in_features
799
+ self.fc1 = DASSLinear2d(in_features, hidden_features)
800
+ self.act = act_layer()
801
+ self.fc2 = DASSLinear2d(hidden_features, out_features)
802
+ self.drop = nn.Dropout(drop)
803
+
804
+ def forward(self, x):
805
+ x = self.fc1(x)
806
+ x = self.act(x)
807
+ x = self.drop(x)
808
+ x = self.fc2(x)
809
+ x = self.drop(x)
810
+ return x
811
+
812
+
813
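+ # SS2D: the 2D selective-scan mixer. Roughly: in_proj -> depthwise conv + activation ->
+ # cross_scan into 4 directional sequences -> selective scan (Mamba SSM) -> cross_merge
+ # back onto the (H, W) grid -> LayerNorm2d -> out_proj (see forward / forward_corev2 below).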
+ class SS2D(nn.Module):
814
+ def __init__(
815
+ self,
816
+ # basic dims ===========
817
+ d_model=96,
818
+ d_state=16,
819
+ ssm_ratio=2.0,
820
+ dt_rank="auto",
821
+ act_layer=nn.SiLU,
822
+ # dwconv ===============
823
+ d_conv=3,
824
+ conv_bias=True,
825
+ # ======================
826
+ dropout=0.0,
827
+ bias=False,
828
+ # dt init ==============
829
+ dt_min=0.001,
830
+ dt_max=0.1,
831
+ dt_init="random",
832
+ dt_scale=1.0,
833
+ dt_init_floor=1e-4,
834
+ # forward_type="v05_noz" is always used
835
+ # ======================
836
+ **kwargs,
837
+ ):
838
+ super().__init__()
839
+ self.k_group = 4
840
+ self.d_model = int(d_model)
841
+ self.d_state = int(d_state)
842
+ self.d_inner = int(ssm_ratio * d_model)
843
+ self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
844
+ self.forward_core = partial(self.forward_corev2, force_fp32=False, no_einsum=True)
845
+ self.with_dconv = d_conv > 1
846
+
847
+ # In projection
848
+ self.in_proj = DASSLinear2d(self.d_model, self.d_inner, bias=bias)
849
+ self.act: nn.Module = act_layer()
850
+
851
+ # Convolution
852
+ if self.with_dconv:
853
+ self.conv2d = nn.Conv2d(
854
+ in_channels=self.d_inner,
855
+ out_channels=self.d_inner,
856
+ groups=self.d_inner,
857
+ bias=conv_bias,
858
+ kernel_size=d_conv,
859
+ padding=(d_conv - 1) // 2,
860
+ )
861
+
862
+ # x_proj and dt_proj
863
+ self.x_proj = DASSLinear2d(self.d_inner, self.k_group * (self.dt_rank + self.d_state * 2), groups=self.k_group, bias=False)
864
+ self.dt_projs = DASSLinear2d(self.dt_rank, self.k_group * self.d_inner, groups=self.k_group, bias=False)
865
+
866
+ # out projection
867
+ self.out_proj = DASSLinear2d(self.d_inner, self.d_model, bias=bias)
868
+ self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
869
+
870
+ # Initialization
871
+ self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = self.init_dt_A_D(
872
+ self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
873
+ )
874
+ self.dt_projs.weight.data = self.dt_projs_weight.data.view(self.dt_projs.weight.shape)
875
+ # self.dt_projs.bias.data = self.dt_projs_bias.data.view(self.dt_projs.bias.shape)
876
+ del self.dt_projs_weight
877
+ # del self.dt_projs_bias
878
+ # Define out_norm directly with "LN2D"
879
+ self.out_norm = DASSLayerNorm2d(self.d_inner)
880
+
881
+ @staticmethod
882
+ def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
883
+ dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
884
+
885
+ dt_init_std = dt_rank**-0.5 * dt_scale
886
+ if dt_init == "constant":
887
+ nn.init.constant_(dt_proj.weight, dt_init_std)
888
+ elif dt_init == "random":
889
+ nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
890
+ else:
891
+ raise NotImplementedError
892
+
893
+ dt = torch.exp(
894
+ torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
895
+ + math.log(dt_min)
896
+ ).clamp(min=dt_init_floor)
897
+
898
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
899
+ with torch.no_grad():
900
+ dt_proj.bias.copy_(inv_dt)
901
+
902
+ return dt_proj
903
+
904
+ @staticmethod
905
+ def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
906
+ A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous()
907
+ A_log = torch.log(A)
908
+ if copies > 0:
909
+ A_log = A_log[None].repeat(copies, 1, 1).contiguous()
910
+ if merge:
911
+ A_log = A_log.flatten(0, 1)
912
+ A_log = nn.Parameter(A_log)
913
+ A_log._no_weight_decay = True
914
+ return A_log
915
+
916
+ @staticmethod
917
+ def D_init(d_inner, copies=-1, device=None, merge=True):
918
+ D = torch.ones(d_inner, device=device)
919
+ if copies > 0:
920
+ D = D[None].repeat(copies, 1).contiguous()
921
+ if merge:
922
+ D = D.flatten(0, 1)
923
+ D = nn.Parameter(D)
924
+ D._no_weight_decay = True
925
+ return D
926
+
927
+ @classmethod
928
+ def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
929
+ dt_projs = [
930
+ cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
931
+ for _ in range(k_group)
932
+ ]
933
+ dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0))
934
+ dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0))
935
+ del dt_projs
936
+
937
+ A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True)
938
+ Ds = cls.D_init(d_inner, copies=k_group, merge=True)
939
+ return A_logs, Ds, dt_projs_weight, dt_projs_bias
940
+
941
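+ # forward_corev2 shape walk-through (K = 4 scan directions, N = d_state, D = d_inner, L = H*W):
+ #   xs:  (B, K, D, L)   cross-scanned input
+ #   dts: (B, K*D, L)    per-step delta after x_proj / dt_projs
+ #   Bs, Cs: (B, K, N, L)
+ #   ys:  (B, K, D, H, W) -> cross_merge -> y: (B, D, H, W)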
+ def forward_corev2(
942
+ self,
943
+ x: torch.Tensor,
944
+ force_fp32=False,
945
+ no_einsum=True,
946
+ ):
947
+ B, D, H, W = x.shape
948
+ N = self.d_state
949
+ L = H * W
950
+
951
+ xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True)
952
+ x_dbl = self.x_proj(xs.view(B, -1, L))
953
+ dts, Bs, Cs = torch.split(x_dbl.view(B, self.k_group, -1, L), [self.dt_rank, N, N], dim=2)
954
+ dts = dts.contiguous().view(B, -1, L)
955
+ dts = self.dt_projs(dts)
956
+
957
+ xs = xs.view(B, -1, L)
958
+ dts = dts.contiguous().view(B, -1, L)
959
+ As = -self.A_logs.to(torch.float32).exp()
960
+ Ds = self.Ds.to(torch.float32)
961
+ Bs = Bs.contiguous().view(B, self.k_group, N, L)
962
+ Cs = Cs.contiguous().view(B, self.k_group, N, L)
963
+ delta_bias = self.dt_projs_bias.view(-1).to(torch.float32)
964
+
965
+ ys = selective_scan_fn(
966
+ xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus=True, backend="mamba"
967
+ ).view(B, self.k_group, -1, H, W)
968
+
969
+ y = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True)
970
+ y = y.view(B, -1, H, W)
971
+ y = self.out_norm(y)
972
+ return y.to(x.dtype)
973
+
974
+ def forward(self, x: torch.Tensor):
975
+ x = self.in_proj(x)
976
+ x = self.conv2d(x)
977
+
978
+ x = self.act(x)
979
+ y = self.forward_core(x)
980
+
981
+ out = self.dropout(self.out_proj(y))
982
+ return out
983
+
984
+
985
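+ # VSSBlock: a residual block with an optional SS2D branch and an optional MLP branch,
+ # each wrapped in LayerNorm2d and DropPath (pre-norm shown; post_norm swaps the order):
+ #   x = x + DropPath(SS2D(Norm(x)));  x = x + DropPath(MLP(Norm2(x)))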
+ class VSSBlock(nn.Module):
986
+ def __init__(
987
+ self,
988
+ hidden_dim: int = 0,
989
+ drop_path: float = 0,
990
+ ssm_d_state: int = 1,
991
+ ssm_ratio=1.0,
992
+ ssm_dt_rank: Any = "auto",
993
+ ssm_act_layer=nn.SiLU,
994
+ ssm_conv: int = 3,
995
+ ssm_conv_bias=False,
996
+ ssm_drop_rate: float = 0,
997
+ mlp_ratio=4.0,
998
+ mlp_act_layer=nn.GELU,
999
+ mlp_drop_rate: float = 0.0,
1000
+ use_checkpoint: bool = False,
1001
+ post_norm: bool = False,
1002
+ **kwargs,
1003
+ ):
1004
+ super().__init__()
1005
+ self.ssm_branch = ssm_ratio > 0
1006
+ self.mlp_branch = mlp_ratio > 0
1007
+ self.use_checkpoint = use_checkpoint
1008
+ self.post_norm = post_norm
1009
+
1010
+ if self.ssm_branch:
1011
+ self.norm = DASSLayerNorm2d(hidden_dim)
1012
+ self.op = SS2D(
1013
+ d_model=hidden_dim,
1014
+ d_state=ssm_d_state,
1015
+ ssm_ratio=ssm_ratio,
1016
+ dt_rank=ssm_dt_rank,
1017
+ act_layer=ssm_act_layer,
1018
+ d_conv=ssm_conv,
1019
+ conv_bias=ssm_conv_bias,
1020
+ dropout=ssm_drop_rate,
1021
+ )
1022
+
1023
+ self.drop_path = DropPath(drop_path)
1024
+
1025
+ if self.mlp_branch:
1026
+ self.norm2 = DASSLayerNorm2d(hidden_dim)
1027
+ mlp_hidden_dim = int(hidden_dim * mlp_ratio)
1028
+ self.mlp = DASSMlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate)
1029
+
1030
+ def _forward(self, input: torch.Tensor):
1031
+ x = input
1032
+ if self.ssm_branch:
1033
+ if self.post_norm:
1034
+ x = x + self.drop_path(self.norm(self.op(x)))
1035
+ else:
1036
+ x = x + self.drop_path(self.op(self.norm(x)))
1037
+ if self.mlp_branch:
1038
+ if self.post_norm:
1039
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
1040
+ else:
1041
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
1042
+ return x
1043
+
1044
+ def forward(self, input: torch.Tensor):
1045
+ if self.use_checkpoint:
1046
+ return checkpoint.checkpoint(self._forward, input)
1047
+ else:
1048
+ return self._forward(input)
1049
+
1050
+ class DASSLayer(nn.Module):
1051
+
1052
+ def __init__(
1053
+ self,
1054
+ input_dim,
1055
+ depth,
1056
+ drop_path=0.0,
1057
+ norm_layer=DASSLayerNorm2d,
1058
+ downsample=nn.Identity(),
1059
+ use_checkpoint=False,
1060
+ **kwargs,
1061
+ ):
1062
+ super().__init__()
1063
+ self.input_dim = input_dim
1064
+ self.use_checkpoint = use_checkpoint
1065
+
1066
+ self.blocks = nn.ModuleList()
1067
+ for i in range(depth):
1068
+ self.blocks.append(
1069
+ VSSBlock(hidden_dim=input_dim,
1070
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
1071
+ norm_layer=norm_layer,use_checkpoint=use_checkpoint,**kwargs,
1072
+ )
1073
+ )
1074
+
1075
+ self.downsample = downsample
1076
+
1077
+ def forward(self, x):
1078
+ for block in self.blocks:
1079
+ x = block(x)
1080
+
1081
+ x = self.downsample(x)
1082
+ return x
1083
+
1084
+ class DASSPreTrainedModel(PreTrainedModel):
1085
+ """
1086
+ An abstract class to handle weights initialization and
1087
+ a simple interface for downloading and loading pretrained models.
1088
+ """
1089
+
1090
+ config_class = DASSConfig
1091
+ base_model_prefix = "dass"
1092
+ supports_gradient_checkpointing = False
1093
+
1094
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
1095
+ """Initialize the weights"""
1096
+ if isinstance(module, nn.Linear):
1097
+ trunc_normal_(module.weight, std=0.02)
1098
+ if isinstance(module, nn.Linear) and module.bias is not None:
1099
+ nn.init.constant_(module.bias, 0)
1100
+ elif isinstance(module, nn.LayerNorm):
1101
+ nn.init.constant_(module.bias, 0)
1102
+ nn.init.constant_(module.weight, 1.0)
1103
+
1104
+
1105
+ class DASSModel(DASSPreTrainedModel):
1106
+ def __init__(self, config):
1107
+ super().__init__(config)
1108
+ self.config = config
1109
+
1110
+ self.num_layers = len(config.depths)  # must be set before it is used to expand `dims`
+ 
+ dims = config.dims
+ if isinstance(dims, int):
+ dims = [int(dims * 2**i_layer) for i_layer in range(self.num_layers)]
+ 
+ self.dims = dims
+ self.patch_embeddings = DASSPatchEmbeddings(patch_size=config.patch_size,
+ embed_dim=dims[0])
+ 
1119
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
1120
+ self.num_features = dims[-1]
1121
+
1122
+ self.layers = nn.ModuleList()
1123
+ for i in range(self.num_layers):
1124
+ layer = DASSLayer(
1125
+ input_dim=self.dims[i],
1126
+ depth=config.depths[i],
1127
+ drop_path=dpr[sum(config.depths[:i]):sum(config.depths[:i+1])],
1128
+ downsample=DASSDowsample(self.dims[i], self.dims[i+1]) if i < self.num_layers - 1 else nn.Identity(),
1129
+ use_checkpoint=config.use_checkpoint,
1130
+ )
1131
+ self.layers.append(layer)
1132
+
1133
+ self.norm = DASSLayerNorm2d(self.num_features)
1134
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
1135
+
1136
+ def get_input_embeddings(self) -> DASSPatchEmbeddings:
1137
+ return self.patch_embeddings
1138
+
1139
+ def forward(self, input_values: torch.Tensor):
1140
+ x = self.patch_embeddings(input_values)
1141
+ for layer in self.layers:
1142
+ x = layer(x)
1143
+ x = self.norm(x)
1144
+ x = self.avgpool(x).flatten(1)
1145
+ return x
1146
+
1147
+
1148
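+ # DASSForAudioClassification: DASSModel backbone + a linear head over the pooled features.
+ # A minimal usage sketch (the checkpoint id and the exact spectrogram shape are assumptions,
+ # not part of this file; it presumes config.json maps this class for auto loading):
+ #   model = AutoModelForAudioClassification.from_pretrained("<repo_id>", trust_remote_code=True)
+ #   logits = model(input_values=fbank).logits   # fbank: (batch, frames, mel_bins) filterbank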
+ class DASSForAudioClassification(DASSPreTrainedModel):
1149
+ def __init__(self, config):
1150
+ super().__init__(config)
1151
+
1152
+ self.num_classes = config.num_classes
1153
+ self.dass = DASSModel(config)
1154
+ self.head = nn.Linear(self.dass.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
1155
+
1156
+ # Initialize weights and apply final processing
1157
+ self.post_init()
1158
+
1159
+ def forward(
1160
+ self,
1161
+ input_values: Optional[torch.Tensor] = None,
1162
+ labels: Optional[torch.Tensor] = None,
1163
+ return_dict: Optional[bool] = None,
1164
+ ):
1165
+
1166
+ outputs = self.dass(
1167
+ input_values,
1168
+ )
1169
+
1170
+ logits = self.head(outputs)
1171
+
1172
+ loss = None
1173
+ if labels is not None:
1174
+ labels = labels.to(logits.device)
1175
+ if self.config.loss_type == "ce":
1176
+ loss_fct = CrossEntropyLoss()
1177
+ loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
+ elif self.config.loss_type == "bce":
1179
+ loss_fct = BCEWithLogitsLoss()
1180
+ loss = loss_fct(logits, labels)
1181
+
1182
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if not return_dict:
1183
+ output = (logits,) + (outputs,)
1184
+ return ((loss,) + output) if loss is not None else output
1185
+
1186
+ return SequenceClassifierOutput(
1187
+ loss=loss,
1188
+ logits=logits,
1189
+ hidden_states=outputs,
1190
+ )
1191
+
1192
+ __all__ = [
1193
+ "DASSModel",
1194
+ "DASSPreTrainedModel",
1195
+ "DASSForAudioClassification",
1196
+ ]