Xalphinions commited on
Commit
088f2ca
·
verified ·
1 Parent(s): dd995d1

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. moe_evaluation_results.json +705 -705
  2. requirements.txt +7 -7
  3. test_moe_model.py +9 -4
moe_evaluation_results.json CHANGED
@@ -1,801 +1,801 @@
1
  {
2
- "moe_test_mae": 0.19680618420243262,
3
- "moe_test_mse": 0.05606407420709729,
4
  "true_labels": [
5
- 10.5,
6
- 9.399999618530273,
7
- 11.600000381469727,
8
- 8.699999809265137,
9
- 10.399999618530273,
10
- 10.800000190734863,
11
- 11.600000381469727,
12
- 10.5,
13
- 11.600000381469727,
14
  11.100000381469727,
15
- 10.399999618530273,
16
- 10.5,
17
- 11.0,
18
- 10.5,
19
- 10.899999618530273,
20
- 10.5,
21
- 11.100000381469727,
22
- 9.600000381469727,
23
- 12.699999809265137,
24
- 10.0,
25
- 10.300000190734863,
26
- 10.399999618530273,
27
  9.399999618530273,
28
- 10.800000190734863,
29
- 10.0,
30
- 11.600000381469727,
31
- 10.0,
32
- 10.399999618530273,
33
  9.399999618530273,
34
- 10.399999618530273,
35
  10.300000190734863,
 
 
 
36
  9.399999618530273,
37
- 10.899999618530273,
38
- 9.0,
39
  10.300000190734863,
 
40
  10.899999618530273,
 
41
  11.0,
42
  12.699999809265137,
43
- 10.399999618530273,
44
- 9.600000381469727,
45
  8.699999809265137,
46
- 10.199999809265137,
47
- 10.300000190734863,
48
  11.600000381469727,
49
- 9.0,
50
- 9.0,
51
  11.0,
52
- 8.699999809265137,
53
- 9.699999809265137,
54
- 10.399999618530273,
55
- 10.0,
56
- 11.600000381469727,
57
- 9.399999618530273,
58
- 9.0,
59
  10.300000190734863,
 
 
 
60
  10.5,
 
 
 
61
  10.399999618530273,
62
- 11.0,
63
- 10.899999618530273,
64
- 9.399999618530273,
65
- 8.699999809265137,
66
  10.300000190734863,
67
  9.699999809265137,
68
- 10.300000190734863,
 
 
 
 
69
  9.399999618530273,
70
  10.300000190734863,
71
  9.399999618530273,
72
- 10.0,
73
- 10.399999618530273,
74
- 10.199999809265137,
75
  11.0,
76
  12.699999809265137,
77
- 12.699999809265137,
78
- 10.0,
79
  11.0,
80
- 9.0,
81
- 10.0,
 
 
 
 
 
 
 
 
82
  10.5,
83
  11.600000381469727,
 
 
 
84
  9.399999618530273,
 
85
  10.0,
86
- 11.0,
87
- 11.100000381469727,
88
- 10.899999618530273,
89
  9.399999618530273,
 
90
  10.300000190734863,
91
- 9.399999618530273,
92
  8.699999809265137,
93
- 10.0,
94
  12.699999809265137,
95
- 12.699999809265137,
96
- 9.699999809265137,
97
  9.399999618530273,
98
- 11.0,
 
99
  9.399999618530273,
 
 
 
100
  9.0,
 
 
 
 
 
 
101
  11.100000381469727,
102
- 10.300000190734863,
103
- 10.300000190734863,
104
- 10.300000190734863,
105
- 10.0,
106
  9.399999618530273,
107
  9.399999618530273,
108
  10.899999618530273,
 
 
 
109
  11.0,
110
  9.699999809265137,
111
- 12.699999809265137,
112
- 10.5,
 
 
 
 
113
  11.0,
114
- 10.899999618530273,
115
  12.699999809265137,
116
- 10.899999618530273,
117
  11.0,
118
- 10.300000190734863,
119
- 11.0,
120
- 9.699999809265137,
121
  10.300000190734863,
122
  10.300000190734863,
123
- 10.199999809265137,
124
- 10.199999809265137,
125
- 10.899999618530273,
126
- 10.5,
127
  11.0,
128
- 8.699999809265137,
129
  9.699999809265137,
 
130
  12.699999809265137,
131
- 11.600000381469727,
132
- 10.899999618530273,
133
  11.0,
134
  9.399999618530273,
135
- 10.300000190734863,
136
- 12.699999809265137,
137
- 10.199999809265137,
138
- 10.199999809265137,
139
- 10.800000190734863,
 
 
 
140
  8.699999809265137,
141
- 9.0,
142
  11.0,
143
- 9.399999618530273,
 
144
  10.800000190734863,
145
- 11.100000381469727,
146
- 11.100000381469727,
147
- 10.199999809265137,
148
  9.399999618530273,
149
- 10.199999809265137,
150
- 10.199999809265137,
151
  9.399999618530273,
152
  10.899999618530273,
153
  10.199999809265137,
 
 
 
 
154
  11.100000381469727,
155
- 11.600000381469727,
156
  8.699999809265137,
 
157
  11.600000381469727,
158
- 10.199999809265137,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  9.399999618530273,
160
  9.699999809265137,
161
- 9.399999618530273
 
162
  ],
163
  "moe_predictions": [
164
- 10.906482696533203,
165
- 9.413387298583984,
166
- 11.58445930480957,
167
- 8.627098083496094,
168
- 10.55517578125,
169
- 10.969362258911133,
170
- 11.596641540527344,
171
- 10.598587036132812,
172
- 11.712945938110352,
173
- 11.415390968322754,
174
- 10.500967979431152,
175
- 10.939116477966309,
176
- 11.23089599609375,
177
- 10.928877830505371,
178
- 11.180931091308594,
179
- 10.805574417114258,
180
- 11.44560432434082,
181
- 9.797750473022461,
182
- 12.00424575805664,
183
- 9.924805641174316,
184
- 10.419149398803711,
185
- 10.459878921508789,
186
- 9.774242401123047,
187
- 10.985288619995117,
188
- 10.047812461853027,
189
- 11.745304107666016,
190
- 10.191004753112793,
191
- 10.527164459228516,
192
- 9.581968307495117,
193
- 10.483012199401855,
194
- 10.368606567382812,
195
- 9.450727462768555,
196
- 11.197010040283203,
197
- 9.173027038574219,
198
- 10.50676441192627,
199
- 11.195816040039062,
200
- 11.227279663085938,
201
- 13.106525421142578,
202
- 10.4664945602417,
203
- 9.891031265258789,
204
- 8.75540542602539,
205
- 10.572815895080566,
206
- 10.214585304260254,
207
- 12.000329971313477,
208
- 8.887301445007324,
209
- 8.929031372070312,
210
- 11.054266929626465,
211
- 8.85447883605957,
212
- 9.515145301818848,
213
- 10.480228424072266,
214
- 10.193933486938477,
215
- 11.7305908203125,
216
- 9.437666893005371,
217
- 9.13387680053711,
218
- 10.629348754882812,
219
- 10.703892707824707,
220
- 10.539461135864258,
221
  11.135326385498047,
222
- 11.19705867767334,
223
- 9.558942794799805,
224
- 8.898516654968262,
225
- 10.628425598144531,
226
- 9.657480239868164,
227
- 10.513351440429688,
228
- 9.459192276000977,
229
- 10.358184814453125,
230
- 9.432706832885742,
231
- 10.078161239624023,
232
- 10.572355270385742,
233
- 10.58112907409668,
234
- 10.910698890686035,
235
- 13.053973197937012,
236
- 12.972726821899414,
237
- 10.170805931091309,
238
- 11.225208282470703,
239
- 8.872610092163086,
240
- 10.091118812561035,
241
- 10.724177360534668,
242
- 11.729219436645508,
243
- 9.66834545135498,
244
- 10.027229309082031,
245
- 11.232885360717773,
246
- 11.518696784973145,
247
- 11.261479377746582,
248
- 9.523242950439453,
249
- 10.484042167663574,
250
- 9.522797584533691,
251
- 8.75236988067627,
252
- 10.083819389343262,
253
- 13.073421478271484,
254
- 13.001571655273438,
255
- 9.905550003051758,
256
  9.318197250366211,
257
- 11.141549110412598,
258
- 9.754105567932129,
259
- 9.013923645019531,
260
- 11.429242134094238,
261
- 10.375783920288086,
262
- 10.526394844055176,
263
- 10.307140350341797,
264
- 10.169934272766113,
265
- 9.429258346557617,
266
- 9.29328441619873,
 
 
267
  11.136444091796875,
268
- 11.040485382080078,
269
- 9.723966598510742,
270
- 12.936074256896973,
271
- 10.913898468017578,
272
- 11.255935668945312,
273
- 11.032815933227539,
274
- 12.95362663269043,
275
- 10.942233085632324,
276
- 11.014484405517578,
277
- 10.47386646270752,
278
- 11.207697868347168,
279
- 9.531013488769531,
280
- 10.512401580810547,
281
- 10.791257858276367,
282
- 10.385677337646484,
283
- 10.393269538879395,
284
- 11.13322639465332,
 
 
 
 
 
 
 
285
  10.893503189086914,
286
- 11.24067497253418,
287
- 8.767911911010742,
288
- 9.76015853881836,
289
- 13.095734596252441,
290
- 11.651636123657227,
291
- 11.08572006225586,
292
- 10.958650588989258,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  9.548912048339844,
294
- 10.243309020996094,
295
- 13.102086067199707,
296
- 10.579414367675781,
297
- 10.406577110290527,
298
- 11.255165100097656,
299
- 8.494292259216309,
300
- 8.890151023864746,
301
- 11.146952629089355,
302
- 9.766341209411621,
303
- 11.163339614868164,
304
- 11.502073287963867,
305
- 11.408285140991211,
306
- 10.383015632629395,
307
- 9.54578971862793,
308
- 10.56948184967041,
309
- 10.558614730834961,
310
- 9.794357299804688,
311
- 10.885274887084961,
312
- 10.377969741821289,
313
- 11.410195350646973,
314
- 11.537992477416992,
315
- 8.826037406921387,
316
- 12.070415496826172,
317
- 10.559798240661621,
318
- 9.605077743530273,
319
- 9.737533569335938,
320
- 9.520374298095703
321
  ],
322
  "individual_predictions": {
323
  "efficientnet_b3_transformer": [
324
- 10.619565963745117,
325
- 9.285565376281738,
326
- 11.017762184143066,
327
- 8.358080863952637,
328
- 9.92147159576416,
329
- 10.68340015411377,
330
- 11.023524284362793,
331
- 10.292417526245117,
332
- 10.513864517211914,
333
- 10.958821296691895,
334
- 10.322061538696289,
335
- 10.383071899414062,
336
- 10.330121040344238,
337
- 10.344510078430176,
338
- 11.309442520141602,
339
- 10.321882247924805,
340
- 10.974185943603516,
341
- 9.367315292358398,
342
- 11.474529266357422,
343
- 9.296891212463379,
344
- 10.27892780303955,
345
- 10.14356803894043,
346
- 9.155308723449707,
347
- 10.249421119689941,
348
- 9.534292221069336,
349
- 11.197205543518066,
350
- 9.988767623901367,
351
- 10.485107421875,
352
- 9.040623664855957,
353
- 10.171326637268066,
354
- 10.153056144714355,
355
- 9.17545223236084,
356
- 10.604523658752441,
357
- 8.7711763381958,
358
- 10.127464294433594,
359
- 11.29480266571045,
360
- 10.326626777648926,
361
- 13.54947566986084,
362
- 10.142123222351074,
363
- 9.914827346801758,
364
- 7.935253620147705,
365
- 10.513096809387207,
366
- 9.79228687286377,
367
- 11.721403121948242,
368
- 7.996966361999512,
369
- 8.011720657348633,
370
- 10.551737785339355,
371
- 8.663973808288574,
372
- 8.74413776397705,
373
- 10.276195526123047,
374
- 10.136805534362793,
375
- 11.221556663513184,
376
- 8.912840843200684,
377
- 8.619383811950684,
378
- 10.178643226623535,
379
- 10.311914443969727,
380
- 10.487189292907715,
381
  10.548056602478027,
382
- 11.258485794067383,
383
- 9.288726806640625,
384
- 8.140922546386719,
385
- 10.216073989868164,
386
- 9.068129539489746,
387
- 10.33917236328125,
388
- 9.11395263671875,
389
- 10.140262603759766,
390
- 8.864439010620117,
391
- 9.560175895690918,
392
- 10.1554594039917,
393
- 10.011631965637207,
394
- 10.838635444641113,
395
- 13.890799522399902,
396
- 13.743374824523926,
397
- 10.119439125061035,
398
- 11.073603630065918,
399
- 7.99126672744751,
400
- 10.012906074523926,
401
- 10.309550285339355,
402
- 10.537038803100586,
403
- 9.361739158630371,
404
- 9.594813346862793,
405
- 10.32430362701416,
406
- 11.0283842086792,
407
- 11.271435737609863,
408
- 9.267289161682129,
409
- 10.143651962280273,
410
- 9.201630592346191,
411
- 8.489853858947754,
412
- 9.663308143615723,
413
- 13.539351463317871,
414
- 13.890753746032715,
415
- 9.300865173339844,
416
  8.978877067565918,
417
- 10.455121994018555,
418
- 9.145268440246582,
419
- 8.390588760375977,
420
- 10.97396183013916,
421
- 10.023279190063477,
422
- 10.194899559020996,
423
- 9.974883079528809,
424
- 10.101761817932129,
425
- 9.511059761047363,
426
- 8.89189624786377,
 
 
427
  10.77907657623291,
428
- 10.7083158493042,
429
- 9.067532539367676,
430
- 13.406800270080566,
431
- 10.60212516784668,
432
- 10.704161643981934,
433
- 11.133363723754883,
434
- 13.293631553649902,
435
- 9.996685981750488,
436
- 10.766114234924316,
437
- 10.15234088897705,
438
- 11.180027961730957,
439
- 8.875227928161621,
440
- 10.376603126525879,
441
- 10.074305534362793,
442
- 10.001667022705078,
443
- 10.027312278747559,
444
- 10.606922149658203,
 
 
 
 
 
 
 
445
  10.565585136413574,
446
- 10.699769020080566,
447
- 8.507576942443848,
448
- 9.084380149841309,
449
- 13.500945091247559,
450
- 11.240296363830566,
451
- 10.65023136138916,
452
- 10.248372077941895,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  9.269180297851562,
454
- 9.840892791748047,
455
- 13.547538757324219,
456
- 9.992758750915527,
457
- 10.026358604431152,
458
- 10.71567440032959,
459
- 8.320480346679688,
460
- 8.000975608825684,
461
- 10.548954963684082,
462
- 9.176098823547363,
463
- 11.098072052001953,
464
- 11.02483081817627,
465
- 11.12319278717041,
466
- 9.996392250061035,
467
- 9.263312339782715,
468
- 10.517735481262207,
469
- 9.8799409866333,
470
- 9.319127082824707,
471
- 9.990796089172363,
472
- 9.982155799865723,
473
- 11.105603218078613,
474
- 10.747210502624512,
475
- 8.343344688415527,
476
- 11.73001480102539,
477
- 10.511062622070312,
478
- 9.331645965576172,
479
- 9.131060600280762,
480
- 8.956952095031738
481
  ],
482
  "efficientnet_b0_transformer": [
483
- 11.040512084960938,
484
- 9.555410385131836,
485
- 11.689399719238281,
486
- 8.434002876281738,
487
- 11.386773109436035,
488
- 10.940624237060547,
489
- 11.708887100219727,
490
- 11.056541442871094,
491
- 12.392988204956055,
492
- 11.619367599487305,
493
- 10.591476440429688,
494
- 11.15828800201416,
495
- 11.810995101928711,
496
- 11.26023006439209,
497
- 11.246732711791992,
498
- 11.448994636535645,
499
- 11.935430526733398,
500
- 10.085470199584961,
501
- 12.768455505371094,
502
- 10.39224910736084,
503
- 10.590924263000488,
504
- 10.642997741699219,
505
- 9.948995590209961,
506
- 11.38804817199707,
507
- 10.38807487487793,
508
- 11.55557632446289,
509
- 10.514514923095703,
510
- 10.37149429321289,
511
- 9.95881462097168,
512
- 10.645825386047363,
513
- 10.480897903442383,
514
- 9.64439868927002,
515
- 11.213277816772461,
516
- 9.551204681396484,
517
- 10.929215431213379,
518
- 11.268585205078125,
519
- 11.799053192138672,
520
- 12.975137710571289,
521
- 10.657550811767578,
522
- 9.907003402709961,
523
- 9.108478546142578,
524
- 10.350242614746094,
525
- 10.475027084350586,
526
- 12.249593734741211,
527
- 9.311214447021484,
528
- 9.402128219604492,
529
- 11.460792541503906,
530
- 8.638538360595703,
531
- 10.098196029663086,
532
- 10.429000854492188,
533
- 10.63322639465332,
534
- 11.521190643310547,
535
- 9.934067726135254,
536
- 9.390719413757324,
537
- 10.85897445678711,
538
- 10.96368408203125,
539
- 10.440620422363281,
540
  11.39995002746582,
541
- 11.138040542602539,
542
- 9.738420486450195,
543
- 9.13027286529541,
544
- 10.834165573120117,
545
- 9.734615325927734,
546
- 10.535043716430664,
547
- 9.7576904296875,
548
- 10.504064559936523,
549
- 9.726502418518066,
550
- 10.391711235046387,
551
- 10.526286125183105,
552
- 10.450986862182617,
553
- 10.732028007507324,
554
- 13.047806739807129,
555
- 12.901583671569824,
556
- 10.609762191772461,
557
- 11.112765312194824,
558
- 9.227752685546875,
559
- 10.403764724731445,
560
- 10.97991943359375,
561
- 12.400298118591309,
562
- 9.740009307861328,
563
- 10.546162605285645,
564
- 11.811308860778809,
565
- 12.024316787719727,
566
- 11.304412841796875,
567
- 9.642568588256836,
568
- 10.770721435546875,
569
- 9.673535346984863,
570
- 8.692492485046387,
571
- 10.140533447265625,
572
- 13.103691101074219,
573
- 12.987236022949219,
574
- 9.978914260864258,
575
  9.647960662841797,
576
- 11.465564727783203,
577
- 9.91793155670166,
578
- 8.99271011352539,
579
- 11.874197959899902,
580
- 10.875059127807617,
581
- 10.751541137695312,
582
- 10.586625099182129,
583
- 10.616861343383789,
584
- 9.251531600952148,
585
- 9.575355529785156,
 
 
586
  11.49870777130127,
587
- 11.352771759033203,
588
- 9.970162391662598,
589
- 12.869828224182129,
590
- 11.021011352539062,
591
- 11.830097198486328,
592
- 10.895241737365723,
593
- 13.477546691894531,
594
- 11.435956001281738,
595
- 11.21767807006836,
596
- 10.8616361618042,
597
- 11.25930404663086,
598
- 9.386629104614258,
599
- 10.510151863098145,
600
- 11.104487419128418,
601
- 10.017858505249023,
602
- 10.365488052368164,
603
- 11.206178665161133,
 
 
 
 
 
 
 
604
  11.027682304382324,
605
- 11.81328010559082,
606
- 8.614967346191406,
607
- 10.088481903076172,
608
- 12.978555679321289,
609
- 11.964248657226562,
610
- 11.287935256958008,
611
- 11.514422416687012,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  9.758452415466309,
613
- 10.500945091247559,
614
- 12.95924186706543,
615
- 10.438175201416016,
616
- 10.364145278930664,
617
- 11.490489959716797,
618
- 8.45285415649414,
619
- 9.380582809448242,
620
- 11.404769897460938,
621
- 10.42972183227539,
622
- 11.568924903869629,
623
- 11.746879577636719,
624
- 11.68482780456543,
625
- 10.019561767578125,
626
- 9.662923812866211,
627
- 10.360588073730469,
628
- 10.901131629943848,
629
- 10.128849029541016,
630
- 11.287601470947266,
631
- 10.017107009887695,
632
- 11.725995063781738,
633
- 11.726645469665527,
634
- 8.865287780761719,
635
- 12.030455589294434,
636
- 10.348114013671875,
637
- 9.747005462646484,
638
- 9.905638694763184,
639
- 9.855661392211914
640
  ],
641
  "resnet50_transformer": [
642
- 11.059370040893555,
643
- 9.399184226989746,
644
- 12.046213150024414,
645
- 9.089208602905273,
646
- 10.357281684875488,
647
- 11.284062385559082,
648
- 12.057510375976562,
649
- 10.44680118560791,
650
- 12.231982231140137,
651
- 11.667984008789062,
652
- 10.58936595916748,
653
- 11.275989532470703,
654
- 11.5515718460083,
655
- 11.181893348693848,
656
- 10.986615180969238,
657
- 10.645844459533691,
658
- 11.427197456359863,
659
- 9.94046688079834,
660
- 11.769749641418457,
661
- 10.08527660369873,
662
- 10.387595176696777,
663
- 10.593070030212402,
664
- 10.218421936035156,
665
- 11.31839656829834,
666
- 10.221070289611816,
667
- 12.48313045501709,
668
- 10.069729804992676,
669
- 10.72489070892334,
670
- 9.746464729309082,
671
- 10.631884574890137,
672
- 10.4718656539917,
673
- 9.532330513000488,
674
- 11.773228645324707,
675
- 9.196700096130371,
676
- 10.46361255645752,
677
- 11.024060249328613,
678
- 11.556159019470215,
679
- 12.794964790344238,
680
- 10.599808692932129,
681
- 9.851262092590332,
682
- 9.222484588623047,
683
- 10.855106353759766,
684
- 10.37644100189209,
685
- 12.02999210357666,
686
- 9.35372257232666,
687
- 9.37324333190918,
688
- 11.150269508361816,
689
- 9.2609224319458,
690
- 9.703102111816406,
691
- 10.735487937927246,
692
- 9.811766624450684,
693
- 12.44902515411377,
694
- 9.46609115600586,
695
- 9.391528129577637,
696
- 10.850428581237793,
697
- 10.836078643798828,
698
- 10.690573692321777,
699
  11.45797348022461,
700
- 11.194649696350098,
701
- 9.649679183959961,
702
- 9.42435359954834,
703
- 10.835038185119629,
704
- 10.169693946838379,
705
- 10.665839195251465,
706
- 9.50593376159668,
707
- 10.43022632598877,
708
- 9.70718002319336,
709
- 10.282594680786133,
710
- 11.035321235656738,
711
- 11.280767440795898,
712
- 11.161433219909668,
713
- 12.223311424255371,
714
- 12.273221015930176,
715
- 9.783215522766113,
716
- 11.48925495147705,
717
- 9.398808479309082,
718
- 9.856684684753418,
719
- 10.883062362670898,
720
- 12.250321388244629,
721
- 9.903286933898926,
722
- 9.940712928771973,
723
- 11.563044548034668,
724
- 11.503388404846191,
725
- 11.208588600158691,
726
- 9.659869194030762,
727
- 10.537753105163574,
728
- 9.693224906921387,
729
- 9.074763298034668,
730
- 10.447615623474121,
731
- 12.577223777770996,
732
- 12.126725196838379,
733
- 10.436871528625488,
734
  9.327754020690918,
735
- 11.503960609436035,
736
- 10.199116706848145,
737
- 9.658470153808594,
738
- 11.43956470489502,
739
- 10.229013442993164,
740
- 10.632741928100586,
741
- 10.35991096496582,
742
- 9.791178703308105,
743
- 9.52518367767334,
744
- 9.412601470947266,
 
 
745
  11.131546974182129,
746
- 11.0603666305542,
747
- 10.13420295715332,
748
- 12.53159236907959,
749
- 11.118557929992676,
750
- 11.233548164367676,
751
- 11.069842338562012,
752
- 12.089702606201172,
753
- 11.394057273864746,
754
- 11.059659957885742,
755
- 10.407622337341309,
756
- 11.183761596679688,
757
- 10.331181526184082,
758
- 10.6504487991333,
759
- 11.194979667663574,
760
- 11.137504577636719,
761
- 10.787008285522461,
762
- 11.586577415466309,
 
 
 
 
 
 
 
763
  11.08724308013916,
764
- 11.208975791931152,
765
- 9.181191444396973,
766
- 10.107614517211914,
767
- 12.807703018188477,
768
- 11.750362396240234,
769
- 11.31899356842041,
770
- 11.11315631866455,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
771
  9.619100570678711,
772
- 10.388087272644043,
773
- 12.79947566986084,
774
- 11.307307243347168,
775
- 10.82922649383545,
776
- 11.55932903289795,
777
- 8.709542274475098,
778
- 9.288893699645996,
779
- 11.48713207244873,
780
- 9.693202018737793,
781
- 10.82302188873291,
782
- 11.73450756072998,
783
- 11.416834831237793,
784
- 11.133091926574707,
785
- 9.71113109588623,
786
- 10.830121040344238,
787
- 10.894770622253418,
788
- 9.935094833374023,
789
- 11.377425193786621,
790
- 11.13464641571045,
791
- 11.39898681640625,
792
- 12.140122413635254,
793
- 9.269479751586914,
794
- 12.450774192810059,
795
- 10.820216178894043,
796
- 9.736580848693848,
797
- 10.17590045928955,
798
- 9.74850845336914
799
  ]
800
  }
801
  }
 
1
  {
2
+ "moe_test_mae": 0.2067064680159092,
3
+ "moe_test_mse": 0.06013735262677074,
4
  "true_labels": [
 
 
 
 
 
 
 
 
 
5
  11.100000381469727,
 
 
 
 
 
 
 
 
 
 
 
 
6
  9.399999618530273,
 
 
 
 
 
7
  9.399999618530273,
 
8
  10.300000190734863,
9
+ 10.300000190734863,
10
+ 8.699999809265137,
11
+ 9.600000381469727,
12
  9.399999618530273,
 
 
13
  10.300000190734863,
14
+ 12.699999809265137,
15
  10.899999618530273,
16
+ 12.699999809265137,
17
  11.0,
18
  12.699999809265137,
 
 
19
  8.699999809265137,
 
 
20
  11.600000381469727,
 
 
21
  11.0,
22
+ 11.0,
 
 
 
 
 
 
23
  10.300000190734863,
24
+ 11.0,
25
+ 9.600000381469727,
26
+ 11.100000381469727,
27
  10.5,
28
+ 9.699999809265137,
29
+ 9.0,
30
+ 10.199999809265137,
31
  10.399999618530273,
 
 
 
 
32
  10.300000190734863,
33
  9.699999809265137,
34
+ 10.399999618530273,
35
+ 12.699999809265137,
36
+ 9.399999618530273,
37
+ 9.399999618530273,
38
+ 9.600000381469727,
39
  9.399999618530273,
40
  10.300000190734863,
41
  9.399999618530273,
42
+ 10.300000190734863,
 
 
43
  11.0,
44
  12.699999809265137,
45
+ 9.399999618530273,
 
46
  11.0,
47
+ 8.699999809265137,
48
+ 10.800000190734863,
49
+ 10.300000190734863,
50
+ 10.899999618530273,
51
+ 11.0,
52
+ 10.899999618530273,
53
+ 10.300000190734863,
54
+ 11.0,
55
+ 11.100000381469727,
56
+ 9.399999618530273,
57
  10.5,
58
  11.600000381469727,
59
+ 10.300000190734863,
60
+ 9.0,
61
+ 9.399999618530273,
62
  9.399999618530273,
63
+ 11.600000381469727,
64
  10.0,
 
 
 
65
  9.399999618530273,
66
+ 10.399999618530273,
67
  10.300000190734863,
 
68
  8.699999809265137,
 
69
  12.699999809265137,
70
+ 10.300000190734863,
 
71
  9.399999618530273,
72
+ 10.300000190734863,
73
+ 10.199999809265137,
74
  9.399999618530273,
75
+ 9.600000381469727,
76
+ 11.600000381469727,
77
+ 10.5,
78
  9.0,
79
+ 11.0,
80
+ 11.600000381469727,
81
+ 11.0,
82
+ 8.699999809265137,
83
+ 9.399999618530273,
84
+ 12.699999809265137,
85
  11.100000381469727,
 
 
 
 
86
  9.399999618530273,
87
  9.399999618530273,
88
  10.899999618530273,
89
+ 10.300000190734863,
90
+ 9.699999809265137,
91
+ 11.600000381469727,
92
  11.0,
93
  9.699999809265137,
94
+ 8.699999809265137,
95
+ 10.399999618530273,
96
+ 10.300000190734863,
97
+ 10.399999618530273,
98
+ 10.399999618530273,
99
+ 10.199999809265137,
100
  11.0,
101
+ 10.5,
102
  12.699999809265137,
 
103
  11.0,
104
+ 10.800000190734863,
105
+ 10.5,
 
106
  10.300000190734863,
107
  10.300000190734863,
108
+ 9.399999618530273,
 
 
 
109
  11.0,
 
110
  9.699999809265137,
111
+ 10.300000190734863,
112
  12.699999809265137,
 
 
113
  11.0,
114
  9.399999618530273,
115
+ 11.100000381469727,
116
+ 9.600000381469727,
117
+ 10.5,
118
+ 10.0,
119
+ 10.5,
120
+ 9.600000381469727,
121
+ 11.600000381469727,
122
+ 11.0,
123
  8.699999809265137,
 
124
  11.0,
125
+ 11.0,
126
+ 9.699999809265137,
127
  10.800000190734863,
 
 
 
128
  9.399999618530273,
 
 
129
  9.399999618530273,
130
  10.899999618530273,
131
  10.199999809265137,
132
+ 8.699999809265137,
133
+ 10.399999618530273,
134
+ 9.399999618530273,
135
+ 9.0,
136
  11.100000381469727,
 
137
  8.699999809265137,
138
+ 10.300000190734863,
139
  11.600000381469727,
140
+ 10.0,
141
+ 10.899999618530273,
142
+ 11.0,
143
+ 9.699999809265137,
144
+ 10.0,
145
+ 11.100000381469727,
146
+ 9.699999809265137,
147
+ 10.5,
148
+ 8.699999809265137,
149
+ 9.600000381469727,
150
+ 10.399999618530273,
151
+ 11.0,
152
+ 11.100000381469727,
153
+ 10.800000190734863,
154
+ 9.0,
155
+ 10.0,
156
+ 11.0,
157
+ 10.300000190734863,
158
  9.399999618530273,
159
  9.699999809265137,
160
+ 12.699999809265137,
161
+ 9.0
162
  ],
163
  "moe_predictions": [
164
+ 11.608917236328125,
165
+ 9.741426467895508,
166
+ 9.461359024047852,
167
+ 10.487305641174316,
168
+ 10.319334983825684,
169
+ 8.653582572937012,
170
+ 9.749049186706543,
171
+ 9.319536209106445,
172
+ 10.338312149047852,
173
+ 12.966812133789062,
174
+ 11.055685043334961,
175
+ 13.093341827392578,
176
+ 11.134803771972656,
177
+ 13.054267883300781,
178
+ 9.044750213623047,
179
+ 12.060381889343262,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  11.135326385498047,
181
+ 11.014484405517578,
182
+ 10.392723083496094,
183
+ 11.37826156616211,
184
+ 10.060087203979492,
185
+ 11.353907585144043,
186
+ 10.72860050201416,
187
+ 9.777619361877441,
188
+ 9.150984764099121,
189
+ 10.573850631713867,
190
+ 10.46796989440918,
191
+ 10.479241371154785,
192
+ 9.75227165222168,
193
+ 10.527164459228516,
194
+ 13.064764976501465,
195
+ 9.474852561950684,
196
+ 9.668087005615234,
197
+ 9.823186874389648,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  9.318197250366211,
199
+ 10.484042167663574,
200
+ 9.54578971862793,
201
+ 10.413134574890137,
202
+ 11.154340744018555,
203
+ 13.079666137695312,
204
+ 9.558942794799805,
205
+ 11.153170585632324,
206
+ 8.779823303222656,
207
+ 11.030976295471191,
208
+ 10.56929874420166,
209
+ 11.015460968017578,
210
+ 11.146952629089355,
211
  11.136444091796875,
212
+ 10.356411933898926,
213
+ 11.381966590881348,
214
+ 11.53759765625,
215
+ 9.69221305847168,
216
+ 10.891069412231445,
217
+ 11.705709457397461,
218
+ 10.507513999938965,
219
+ 9.079387664794922,
220
+ 9.473494529724121,
221
+ 9.648874282836914,
222
+ 11.7305908203125,
223
+ 9.888289451599121,
224
+ 9.338244438171387,
225
+ 10.491485595703125,
226
+ 10.797355651855469,
227
+ 8.876679420471191,
228
+ 12.945722579956055,
229
+ 10.505922317504883,
230
+ 9.545509338378906,
231
+ 10.245137214660645,
232
+ 10.609914779663086,
233
+ 9.690855026245117,
234
+ 9.788698196411133,
235
+ 11.52328109741211,
236
  10.893503189086914,
237
+ 9.418478965759277,
238
+ 11.218090057373047,
239
+ 11.710685729980469,
240
+ 10.888498306274414,
241
+ 8.951180458068848,
242
+ 9.556252479553223,
243
+ 12.008685111999512,
244
+ 11.203088760375977,
245
+ 9.37525463104248,
246
+ 9.686023712158203,
247
+ 11.137346267700195,
248
+ 10.356472969055176,
249
+ 9.560345649719238,
250
+ 11.539974212646484,
251
+ 11.24638557434082,
252
+ 9.592302322387695,
253
+ 8.74775505065918,
254
+ 10.552587509155273,
255
+ 10.164124488830566,
256
+ 10.536083221435547,
257
+ 10.612926483154297,
258
+ 10.58446979522705,
259
+ 11.010236740112305,
260
+ 10.861842155456543,
261
+ 12.990730285644531,
262
+ 11.20481014251709,
263
+ 11.203653335571289,
264
+ 10.694746017456055,
265
+ 10.50363826751709,
266
+ 10.627494812011719,
267
+ 9.526586532592773,
268
+ 11.152572631835938,
269
+ 9.644195556640625,
270
+ 10.509271621704102,
271
+ 12.95602035522461,
272
+ 11.141549110412598,
273
+ 9.429258346557617,
274
+ 11.232805252075195,
275
+ 9.700346946716309,
276
+ 10.68587875366211,
277
+ 10.229130744934082,
278
+ 10.715401649475098,
279
+ 9.776931762695312,
280
+ 11.698503494262695,
281
+ 10.898889541625977,
282
+ 8.892599105834961,
283
+ 11.125198364257812,
284
+ 10.992132186889648,
285
+ 9.235944747924805,
286
+ 11.17458724975586,
287
+ 9.79542064666748,
288
+ 9.371628761291504,
289
+ 11.255684852600098,
290
+ 10.605937957763672,
291
+ 9.060511589050293,
292
+ 10.476083755493164,
293
  9.548912048339844,
294
+ 9.350934982299805,
295
+ 11.556468963623047,
296
+ 8.781621932983398,
297
+ 10.605558395385742,
298
+ 11.742720603942871,
299
+ 10.156621932983398,
300
+ 11.166330337524414,
301
+ 11.228448867797852,
302
+ 9.908857345581055,
303
+ 10.191004753112793,
304
+ 11.530580520629883,
305
+ 9.941258430480957,
306
+ 10.884675025939941,
307
+ 9.074652671813965,
308
+ 9.77452278137207,
309
+ 10.470745086669922,
310
+ 11.077189445495605,
311
+ 11.514217376708984,
312
+ 11.264935493469238,
313
+ 9.093061447143555,
314
+ 10.090995788574219,
315
+ 11.240152359008789,
316
+ 10.398412704467773,
317
+ 9.46157169342041,
318
+ 9.580022811889648,
319
+ 13.064597129821777,
320
+ 9.182878494262695
321
  ],
322
  "individual_predictions": {
323
  "efficientnet_b3_transformer": [
324
+ 11.339303016662598,
325
+ 9.160046577453613,
326
+ 9.141955375671387,
327
+ 10.123801231384277,
328
+ 9.975425720214844,
329
+ 8.028714179992676,
330
+ 9.226607322692871,
331
+ 9.101340293884277,
332
+ 10.290902137756348,
333
+ 13.305447578430176,
334
+ 10.197638511657715,
335
+ 13.537657737731934,
336
+ 10.54432201385498,
337
+ 13.52890396118164,
338
+ 8.314106941223145,
339
+ 11.723322868347168,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  10.548056602478027,
341
+ 10.766114234924316,
342
+ 10.293607711791992,
343
+ 10.927001953125,
344
+ 9.803337097167969,
345
+ 11.071410179138184,
346
+ 10.097264289855957,
347
+ 9.165467262268066,
348
+ 8.166515350341797,
349
+ 10.0133056640625,
350
+ 10.137511253356934,
351
+ 9.890531539916992,
352
+ 9.145689964294434,
353
+ 10.485107421875,
354
+ 13.939330101013184,
355
+ 9.209654808044434,
356
+ 9.333880424499512,
357
+ 9.570420265197754,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  8.978877067565918,
359
+ 10.143651962280273,
360
+ 9.263312339782715,
361
+ 10.041259765625,
362
+ 10.457343101501465,
363
+ 13.546338081359863,
364
+ 9.288726806640625,
365
+ 10.456021308898926,
366
+ 8.490445137023926,
367
+ 10.460243225097656,
368
+ 10.012919425964355,
369
+ 11.114548683166504,
370
+ 10.548954963684082,
371
  10.77907657623291,
372
+ 10.15251350402832,
373
+ 10.923174858093262,
374
+ 11.296109199523926,
375
+ 9.368083000183105,
376
+ 10.545008659362793,
377
+ 11.159947395324707,
378
+ 10.038147926330566,
379
+ 8.497147560119629,
380
+ 9.207659721374512,
381
+ 9.170109748840332,
382
+ 11.221556663513184,
383
+ 9.174721717834473,
384
+ 8.752867698669434,
385
+ 10.336318969726562,
386
+ 10.116740226745605,
387
+ 8.14444637298584,
388
+ 13.291146278381348,
389
+ 10.12454891204834,
390
+ 9.043634414672852,
391
+ 9.82880687713623,
392
+ 9.841523170471191,
393
+ 9.366087913513184,
394
+ 9.41323471069336,
395
+ 10.771563529968262,
396
  10.565585136413574,
397
+ 8.822549819946289,
398
+ 11.126303672790527,
399
+ 11.17785358428955,
400
+ 10.847918510437012,
401
+ 8.105504035949707,
402
+ 9.042283058166504,
403
+ 11.476466178894043,
404
+ 10.669010162353516,
405
+ 8.949850082397461,
406
+ 9.371846199035645,
407
+ 11.209992408752441,
408
+ 10.284793853759766,
409
+ 8.732993125915527,
410
+ 10.731574058532715,
411
+ 10.698369026184082,
412
+ 8.777587890625,
413
+ 8.237317085266113,
414
+ 10.505911827087402,
415
+ 9.840256690979004,
416
+ 10.486929893493652,
417
+ 10.697690963745117,
418
+ 10.00699520111084,
419
+ 10.793766975402832,
420
+ 10.49045467376709,
421
+ 14.00195026397705,
422
+ 10.92188835144043,
423
+ 11.09973430633545,
424
+ 10.037339210510254,
425
+ 10.13139533996582,
426
+ 10.012660026550293,
427
+ 8.973554611206055,
428
+ 10.546631813049316,
429
+ 9.004876136779785,
430
+ 10.006653785705566,
431
+ 13.256916999816895,
432
+ 10.455121994018555,
433
+ 9.511059761047363,
434
+ 10.602723121643066,
435
+ 9.374435424804688,
436
+ 10.019323348999023,
437
+ 9.987650871276855,
438
+ 10.076990127563477,
439
+ 9.511448860168457,
440
+ 11.17209243774414,
441
+ 10.794194221496582,
442
+ 8.425066947937012,
443
+ 10.724698066711426,
444
+ 10.763283729553223,
445
+ 8.875535011291504,
446
+ 10.71423625946045,
447
+ 9.314862251281738,
448
+ 8.985882759094238,
449
+ 11.253849983215332,
450
+ 9.853181838989258,
451
+ 8.331829071044922,
452
+ 10.341578483581543,
453
  9.269180297851562,
454
+ 8.643234252929688,
455
+ 11.096152305603027,
456
+ 8.507393836975098,
457
+ 10.021732330322266,
458
+ 11.22731876373291,
459
+ 9.608407974243164,
460
+ 10.589388847351074,
461
+ 10.327948570251465,
462
+ 9.291131019592285,
463
+ 9.988767623901367,
464
+ 11.252240180969238,
465
+ 9.299224853515625,
466
+ 10.54757022857666,
467
+ 8.331646919250488,
468
+ 9.259908676147461,
469
+ 10.290452003479004,
470
+ 10.381683349609375,
471
+ 11.2520170211792,
472
+ 10.613112449645996,
473
+ 8.515460014343262,
474
+ 9.602897644042969,
475
+ 10.69603443145752,
476
+ 9.902280807495117,
477
+ 9.069375038146973,
478
+ 8.936785697937012,
479
+ 13.536062240600586,
480
+ 8.750259399414062
481
  ],
482
  "efficientnet_b0_transformer": [
483
+ 11.985595703125,
484
+ 10.36156177520752,
485
+ 9.784621238708496,
486
+ 10.845489501953125,
487
+ 10.589117050170898,
488
+ 8.855218887329102,
489
+ 9.884342193603516,
490
+ 9.477718353271484,
491
+ 10.49835205078125,
492
+ 13.499242782592773,
493
+ 11.343120574951172,
494
+ 12.953442573547363,
495
+ 11.393060684204102,
496
+ 13.101705551147461,
497
+ 9.225322723388672,
498
+ 12.007133483886719,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  11.39995002746582,
500
+ 11.21767807006836,
501
+ 10.430135726928711,
502
+ 11.690134048461914,
503
+ 9.993916511535645,
504
+ 11.647773742675781,
505
+ 11.226818084716797,
506
+ 9.928828239440918,
507
+ 9.873790740966797,
508
+ 10.440296173095703,
509
+ 10.689691543579102,
510
+ 10.651750564575195,
511
+ 10.012208938598633,
512
+ 10.37149429321289,
513
+ 13.044389724731445,
514
+ 9.652556419372559,
515
+ 9.76612377166748,
516
+ 10.0393705368042,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  9.647960662841797,
518
+ 10.770721435546875,
519
+ 9.662923812866211,
520
+ 10.81811809539795,
521
+ 11.476736068725586,
522
+ 12.968710899353027,
523
+ 9.738420486450195,
524
+ 11.47037124633789,
525
+ 8.632101058959961,
526
+ 11.185698509216309,
527
+ 11.003364562988281,
528
+ 10.891831398010254,
529
+ 11.404769897460938,
530
  11.49870777130127,
531
+ 10.478404998779297,
532
+ 11.697516441345215,
533
+ 11.908629417419434,
534
+ 9.77428150177002,
535
+ 11.29122543334961,
536
+ 11.484548568725586,
537
+ 10.938173294067383,
538
+ 9.57394027709961,
539
+ 9.642441749572754,
540
+ 9.688291549682617,
541
+ 11.521190643310547,
542
+ 10.456705093383789,
543
+ 9.57772159576416,
544
+ 10.563447952270508,
545
+ 11.08597469329834,
546
+ 9.118146896362305,
547
+ 13.47038745880127,
548
+ 10.913671493530273,
549
+ 9.929350852966309,
550
+ 10.526603698730469,
551
+ 10.606958389282227,
552
+ 9.764165878295898,
553
+ 10.084386825561523,
554
+ 11.6922607421875,
555
  11.027682304382324,
556
+ 9.797820091247559,
557
+ 11.257314682006836,
558
+ 12.267354011535645,
559
+ 10.73689079284668,
560
+ 9.154512405395508,
561
+ 9.921629905700684,
562
+ 12.784350395202637,
563
+ 11.669108390808105,
564
+ 9.659965515136719,
565
+ 9.74787712097168,
566
+ 11.229676246643066,
567
+ 10.430813789367676,
568
+ 9.788354873657227,
569
+ 11.721125602722168,
570
+ 11.825557708740234,
571
+ 9.755647659301758,
572
+ 8.926406860351562,
573
+ 10.426839828491211,
574
+ 10.43403148651123,
575
+ 10.416683197021484,
576
+ 10.326852798461914,
577
+ 10.440661430358887,
578
+ 11.215356826782227,
579
+ 11.223287582397461,
580
+ 13.147810935974121,
581
+ 11.27365779876709,
582
+ 11.516763687133789,
583
+ 11.006742477416992,
584
+ 10.878545761108398,
585
+ 11.415714263916016,
586
+ 9.696914672851562,
587
+ 11.417068481445312,
588
+ 9.799717903137207,
589
+ 11.379979133605957,
590
+ 13.502660751342773,
591
+ 11.465564727783203,
592
+ 9.251531600952148,
593
+ 11.778385162353516,
594
+ 9.734674453735352,
595
+ 10.932029724121094,
596
+ 10.582185745239258,
597
+ 11.00518798828125,
598
+ 9.725820541381836,
599
+ 12.247865676879883,
600
+ 10.734901428222656,
601
+ 8.928577423095703,
602
+ 11.397771835327148,
603
+ 11.13377571105957,
604
+ 9.139379501342773,
605
+ 11.641318321228027,
606
+ 10.129936218261719,
607
+ 9.684531211853027,
608
+ 11.295875549316406,
609
+ 10.605236053466797,
610
+ 9.146963119506836,
611
+ 10.568946838378906,
612
  9.758452415466309,
613
+ 9.650102615356445,
614
+ 11.9966402053833,
615
+ 8.678672790527344,
616
+ 11.002718925476074,
617
+ 11.540517807006836,
618
+ 10.397274017333984,
619
+ 11.197608947753906,
620
+ 11.805412292480469,
621
+ 9.95970630645752,
622
+ 10.514514923095703,
623
+ 11.889755249023438,
624
+ 10.047914505004883,
625
+ 10.953085899353027,
626
+ 9.211109161376953,
627
+ 9.910860061645508,
628
+ 10.400971412658691,
629
+ 11.485074996948242,
630
+ 11.828522682189941,
631
+ 11.306056022644043,
632
+ 9.125839233398438,
633
+ 10.399169921875,
634
+ 11.806390762329102,
635
+ 10.55948543548584,
636
+ 9.855262756347656,
637
+ 9.390632629394531,
638
+ 12.962108612060547,
639
+ 9.551152229309082
640
  ],
641
  "resnet50_transformer": [
642
+ 11.501852989196777,
643
+ 9.702670097351074,
644
+ 9.457501411437988,
645
+ 10.49262523651123,
646
+ 10.393461227416992,
647
+ 9.076814651489258,
648
+ 10.136197090148926,
649
+ 9.379548072814941,
650
+ 10.225680351257324,
651
+ 12.095745086669922,
652
+ 11.626296043395996,
653
+ 12.788922309875488,
654
+ 11.467028617858887,
655
+ 12.532193183898926,
656
+ 9.594820022583008,
657
+ 12.450687408447266,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
658
  11.45797348022461,
659
+ 11.059659957885742,
660
+ 10.454424858093262,
661
+ 11.51764965057373,
662
+ 10.38300609588623,
663
+ 11.342537879943848,
664
+ 10.86171817779541,
665
+ 10.238561630249023,
666
+ 9.412647247314453,
667
+ 11.267950057983398,
668
+ 10.576706886291504,
669
+ 10.895441055297852,
670
+ 10.098917007446289,
671
+ 10.72489070892334,
672
+ 12.210573196411133,
673
+ 9.562345504760742,
674
+ 9.904257774353027,
675
+ 9.859766960144043,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
  9.327754020690918,
677
+ 10.537753105163574,
678
+ 9.71113109588623,
679
+ 10.380024909973145,
680
+ 11.528943061828613,
681
+ 12.723946571350098,
682
+ 9.649679183959961,
683
+ 11.53311824798584,
684
+ 9.216923713684082,
685
+ 11.446986198425293,
686
+ 10.691610336303711,
687
+ 11.040003776550293,
688
+ 11.48713207244873,
689
  11.131546974182129,
690
+ 10.438315391540527,
691
+ 11.525206565856934,
692
+ 11.408055305480957,
693
+ 9.934273719787598,
694
+ 10.836973190307617,
695
+ 12.472630500793457,
696
+ 10.546219825744629,
697
+ 9.167073249816895,
698
+ 9.570382118225098,
699
+ 10.088221549987793,
700
+ 12.44902515411377,
701
+ 10.033440589904785,
702
+ 9.684144020080566,
703
+ 10.574688911437988,
704
+ 11.189352989196777,
705
+ 9.367444038391113,
706
+ 12.075634956359863,
707
+ 10.479545593261719,
708
+ 9.663542747497559,
709
+ 10.379999160766602,
710
+ 11.381260871887207,
711
+ 9.942309379577637,
712
+ 9.8684720993042,
713
+ 12.106017112731934,
714
  11.08724308013916,
715
+ 9.635066986083984,
716
+ 11.270651817321777,
717
+ 11.686848640441895,
718
+ 11.080682754516602,
719
+ 9.593523979187012,
720
+ 9.70484447479248,
721
+ 11.765238761901855,
722
+ 11.271148681640625,
723
+ 9.515948295593262,
724
+ 9.938346862792969,
725
+ 10.972367286682129,
726
+ 10.35381031036377,
727
+ 10.159687995910645,
728
+ 12.167220115661621,
729
+ 11.215229034423828,
730
+ 10.243671417236328,
731
+ 9.07953929901123,
732
+ 10.72500991821289,
733
+ 10.218084335327148,
734
+ 10.704636573791504,
735
+ 10.81423568725586,
736
+ 11.305752754211426,
737
+ 11.021586418151855,
738
+ 10.871784210205078,
739
+ 11.822425842285156,
740
+ 11.418882369995117,
741
+ 10.994462013244629,
742
+ 11.040154457092285,
743
+ 10.500971794128418,
744
+ 10.454109191894531,
745
+ 9.909290313720703,
746
+ 11.494017601013184,
747
+ 10.127991676330566,
748
+ 10.141182899475098,
749
+ 12.10848331451416,
750
+ 11.503960609436035,
751
+ 9.52518367767334,
752
+ 11.317305564880371,
753
+ 9.991930961608887,
754
+ 11.106281280517578,
755
+ 10.117555618286133,
756
+ 11.064026832580566,
757
+ 10.093523979187012,
758
+ 11.675549507141113,
759
+ 11.167573928833008,
760
+ 9.324151039123535,
761
+ 11.253122329711914,
762
+ 11.079337120056152,
763
+ 9.692917823791504,
764
+ 11.168207168579102,
765
+ 9.941462516784668,
766
+ 9.444470405578613,
767
+ 11.217329025268555,
768
+ 11.359396934509277,
769
+ 9.702740669250488,
770
+ 10.517725944519043,
771
  9.619100570678711,
772
+ 9.759465217590332,
773
+ 11.576613426208496,
774
+ 9.158799171447754,
775
+ 10.792224884033203,
776
+ 12.46032428741455,
777
+ 10.46418285369873,
778
+ 11.711991310119629,
779
+ 11.551984786987305,
780
+ 10.47573471069336,
781
+ 10.069729804992676,
782
+ 11.449746131896973,
783
+ 10.476634979248047,
784
+ 11.153368949890137,
785
+ 9.68120002746582,
786
+ 10.152800559997559,
787
+ 10.720810890197754,
788
+ 11.3648099899292,
789
+ 11.462113380432129,
790
+ 11.875636100769043,
791
+ 9.637883186340332,
792
+ 10.270918846130371,
793
+ 11.218029975891113,
794
+ 10.733470916748047,
795
+ 9.460077285766602,
796
+ 10.412650108337402,
797
+ 12.6956205368042,
798
+ 9.247221946716309
799
  ]
800
  }
801
  }
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- torch>=2.0.0
2
- torchaudio>=2.0.0
3
- torchvision>=0.15.0
4
- gradio>=3.50.0
5
- numpy>=1.20.0
6
- pillow>=9.0.0
7
- tensorboard>=2.12.0
8
  pydantic==2.10.6
 
1
+ torch==2.6.0
2
+ torchaudio==2.6.0
3
+ torchvision==0.21.0
4
+ gradio==5.9.1
5
+ numpy==2.0.2
6
+ pillow==10.4.0
7
+ tensorboard==2.19.0
8
  pydantic==2.10.6
test_moe_model.py CHANGED
@@ -37,7 +37,7 @@ TOP_MODELS = [
37
 
38
  # Define class for the MoE model
39
  class WatermelonMoEModel(torch.nn.Module):
40
- def __init__(self, model_configs, model_dir="test_models", weights=None):
41
  """
42
  Mixture of Experts model that combines multiple backbone models.
43
 
@@ -92,13 +92,16 @@ class WatermelonMoEModel(torch.nn.Module):
92
  with torch.no_grad():
93
  for i, model in enumerate(self.models):
94
  output = model(mfcc, image)
 
95
  outputs.append(output * self.weights[i])
96
 
97
  # Return weighted average
98
- return torch.sum(torch.stack(outputs), dim=0)
 
 
99
 
100
 
101
- def evaluate_moe_model(data_dir, model_dir="test_models", weights=None):
102
  """
103
  Evaluate the MoE model on the test set.
104
  """
@@ -153,10 +156,12 @@ def evaluate_moe_model(data_dir, model_dir="test_models", weights=None):
153
  model_name = f"{config['image_backbone']}_{config['audio_backbone']}"
154
  output = model(mfcc, image)
155
  individual_predictions[model_name].extend(output.view(-1).cpu().numpy())
 
156
 
157
  # Get MoE prediction
158
  output = moe_model(mfcc, image)
159
  moe_predictions.extend(output.view(-1).cpu().numpy())
 
160
 
161
  # Store true labels
162
  label = label.view(-1, 1).float()
@@ -243,7 +248,7 @@ if __name__ == "__main__":
243
  parser.add_argument(
244
  "--model_dir",
245
  type=str,
246
- default="test_models",
247
  help="Directory containing model checkpoints"
248
  )
249
  parser.add_argument(
 
37
 
38
  # Define class for the MoE model
39
  class WatermelonMoEModel(torch.nn.Module):
40
+ def __init__(self, model_configs, model_dir="models", weights=None):
41
  """
42
  Mixture of Experts model that combines multiple backbone models.
43
 
 
92
  with torch.no_grad():
93
  for i, model in enumerate(self.models):
94
  output = model(mfcc, image)
95
+ print(f"DEBUG: Model {i} output: {output}")
96
  outputs.append(output * self.weights[i])
97
 
98
  # Return weighted average
99
+ final_output = torch.sum(torch.stack(outputs), dim=0)
100
+ print(f"DEBUG: Raw prediction: {final_output}")
101
+ return final_output
102
 
103
 
104
+ def evaluate_moe_model(data_dir, model_dir="models", weights=None):
105
  """
106
  Evaluate the MoE model on the test set.
107
  """
 
156
  model_name = f"{config['image_backbone']}_{config['audio_backbone']}"
157
  output = model(mfcc, image)
158
  individual_predictions[model_name].extend(output.view(-1).cpu().numpy())
159
+ print(f"DEBUG: Model {j} output: {output}")
160
 
161
  # Get MoE prediction
162
  output = moe_model(mfcc, image)
163
  moe_predictions.extend(output.view(-1).cpu().numpy())
164
+ print(f"DEBUG: MoE prediction: {output}")
165
 
166
  # Store true labels
167
  label = label.view(-1, 1).float()
 
248
  parser.add_argument(
249
  "--model_dir",
250
  type=str,
251
+ default="models",
252
  help="Directory containing model checkpoints"
253
  )
254
  parser.add_argument(