Xalphinions committed on
Commit 6f4e394 · verified · 1 Parent(s): 83c2e18

Upload folder using huggingface_hub

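The commit message indicates the folder was pushed with the huggingface_hub library. As a rough illustration only (the repo id, repo type, and folder path below are placeholders, not values taken from this commit), such an upload is typically a single upload_folder call:

    from huggingface_hub import HfApi

    api = HfApi()  # picks up the token stored by `huggingface-cli login`
    api.upload_folder(
        folder_path=".",                        # local folder to push (placeholder)
        repo_id="Xalphinions/watermelon-app",   # placeholder repo id
        repo_type="space",                      # assumed: a Gradio Space
        commit_message="Upload folder using huggingface_hub",
    )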
.gitattributes CHANGED
@@ -35,3 +35,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  temp/temp_audio.wav filter=lfs diff=lfs merge=lfs -text
37
  temp/temp_image.jpg filter=lfs diff=lfs merge=lfs -text
38
+ models/.nfs00000001a1a17512003726ad filter=lfs diff=lfs merge=lfs -text
39
+ models/.nfs00000001a234d9cd003726ac filter=lfs diff=lfs merge=lfs -text
40
+ models/.nfs00000001a2a11ea9003726ae filter=lfs diff=lfs merge=lfs -text
.nfs00000001a2244b30003726a6 ADDED
@@ -0,0 +1 @@
1
+
.nfs00000001a2b1089c003726a7 ADDED
@@ -0,0 +1 @@
1
+
__pycache__/evaluate_backbones.cpython-310.pyc ADDED
Binary file (16.9 kB).
 
__pycache__/preprocess.cpython-310.pyc ADDED
Binary file (1.27 kB).
 
app.py CHANGED
@@ -6,21 +6,82 @@ import gradio as gr
6
  import torchaudio
7
  import torchvision
8
  import spaces
9
-
10
- # # Import Gradio Spaces GPU decorator
11
- # try:
12
- # from gradio import spaces
13
- # HAS_SPACES = True
14
- # print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
15
- # except ImportError:
16
- # HAS_SPACES = False
17
- # print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization")
18
 
19
  # Add parent directory to path to import preprocess functions
20
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
 
22
- # Import functions from infer_watermelon.py and train_watermelon for the model
23
- from train_watermelon import WatermelonModel
24
 
25
  # Modified version of process_audio_data specifically for the app to handle various tensor shapes
26
  def app_process_audio_data(waveform, sample_rate):
@@ -76,15 +137,12 @@ def app_process_audio_data(waveform, sample_rate):
76
  print(traceback.format_exc())
77
  return None
78
 
79
- # Similarly for images, but let's import the original one
80
- from preprocess import process_image_data
81
-
82
- # Using the decorator directly on the function definition
83
  @spaces.GPU
84
- def predict_sugar_content(audio, image, model_path):
85
- """Function with GPU acceleration to predict watermelon sugar content in Brix"""
86
  try:
87
- # Now check CUDA availability inside the GPU-decorated function
88
  if torch.cuda.is_available():
89
  device = torch.device("cuda")
90
  print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
@@ -92,11 +150,11 @@ def predict_sugar_content(audio, image, model_path):
92
  device = torch.device("cpu")
93
  print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
94
 
95
- # Load model inside the function to ensure it's on the correct device
96
- model = WatermelonModel().to(device)
97
- model.load_state_dict(torch.load(model_path, map_location=device))
98
- model.eval()
99
- print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
100
 
101
  # Debug information about input types
102
  print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
@@ -188,11 +246,11 @@ def predict_sugar_content(audio, image, model_path):
188
  processed_image = processed_image.unsqueeze(0).to(device)
189
  print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
190
 
191
- # Run inference
192
- print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
193
  if mfcc is not None and processed_image is not None:
194
  with torch.no_grad():
195
- brix_value = model(mfcc, processed_image)
196
  print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
197
  else:
198
  return "Error: Failed to process inputs. Please check the debug logs."
@@ -204,6 +262,12 @@ def predict_sugar_content(audio, image, model_path):
204
  # Create a header with the numerical result
205
  result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
206
 
207
  # Add Brix scale visualization
208
  result += "Sugar Content Scale (in °Brix):\n"
209
  result += "──────────────────────────────────\n"
@@ -257,22 +321,27 @@ def predict_sugar_content(audio, image, model_path):
257
  error_msg += traceback.format_exc()
258
  print(f"\033[91mERR!\033[0m: {error_msg}")
259
  return error_msg
260
-
261
- print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
262
-
263
 
264
- def create_app(model_path):
265
  """Create and launch the Gradio interface"""
266
  # Define the prediction function with model path
267
  def predict_fn(audio, image):
268
- return predict_sugar_content(audio, image, model_path)
269
 
270
  # Create Gradio interface
271
- with gr.Blocks(title="Watermelon Sugar Content Predictor", theme=gr.themes.Soft()) as interface:
272
- gr.Markdown("# 🍉 Watermelon Sugar Content Predictor")
273
  gr.Markdown("""
274
  This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.
275
 
276
  ## Instructions:
277
  1. Upload or record an audio of tapping the watermelon
278
  2. Upload or capture an image of the watermelon
@@ -286,7 +355,7 @@ def create_app(model_path):
286
  submit_btn = gr.Button("Predict Sugar Content", variant="primary")
287
 
288
  with gr.Column():
289
- output = gr.Textbox(label="Prediction Results", lines=12)
290
 
291
  submit_btn.click(
292
  fn=predict_fn,
@@ -302,6 +371,11 @@ def create_app(model_path):
302
  ## About Brix Measurement
303
  Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
304
  The average ripe watermelon has a Brix value between 9-11°.
305
  """)
306
 
307
  return interface
@@ -309,12 +383,12 @@ def create_app(model_path):
309
  if __name__ == "__main__":
310
  import argparse
311
 
312
- parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App")
313
  parser.add_argument(
314
- "--model_path",
315
  type=str,
316
- default="models/watermelon_model_final.pt",
317
- help="Path to the trained model file"
318
  )
319
  parser.add_argument(
320
  "--share",
@@ -326,18 +400,40 @@ if __name__ == "__main__":
326
  action="store_true",
327
  help="Enable verbose debug output"
328
  )
329
 
330
  args = parser.parse_args()
331
 
332
  if args.debug:
333
  print(f"\033[92mINFO\033[0m: Debug mode enabled")
334
 
335
- # Check if model exists
336
- if not os.path.exists(args.model_path):
337
- print(f"\033[91mERR!\033[0m: Model not found at {args.model_path}")
338
- print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
339
  sys.exit(1)
340
 
341
  # Create and launch the app
342
- app = create_app(args.model_path)
343
  app.launch(share=args.share)
 
6
  import torchaudio
7
  import torchvision
8
  import spaces
9
+ import json
10
 
11
  # Add parent directory to path to import preprocess functions
12
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
 
14
+ # Import functions from preprocess and model definitions
15
+ from preprocess import process_image_data
16
+ from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
17
+
18
+ # Define the top-performing models based on evaluation
19
+ TOP_MODELS = [
20
+ {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
21
+ {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
22
+ {"image_backbone": "resnet50", "audio_backbone": "transformer"}
23
+ ]
24
+
25
+ # Define the MoE Model
26
+ class WatermelonMoEModel(torch.nn.Module):
27
+ def __init__(self, model_configs, model_dir="models", weights=None):
28
+ """
29
+ Mixture of Experts model that combines multiple backbone models.
30
+
31
+ Args:
32
+ model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
33
+ model_dir: Directory where model checkpoints are stored
34
+ weights: Optional list of weights for each model (None for equal weighting)
35
+ """
36
+ super(WatermelonMoEModel, self).__init__()
37
+ self.models = []
38
+ self.model_configs = model_configs
39
+
40
+ # Load each model
41
+ for config in model_configs:
42
+ img_backbone = config["image_backbone"]
43
+ audio_backbone = config["audio_backbone"]
44
+
45
+ # Initialize model
46
+ model = WatermelonModelModular(img_backbone, audio_backbone)
47
+
48
+ # Load weights
49
+ model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
50
+ if os.path.exists(model_path):
51
+ print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
52
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
53
+ else:
54
+ print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
55
+ continue
56
+
57
+ model.eval() # Set to evaluation mode
58
+ self.models.append(model)
59
+
60
+ # Set model weights (uniform by default)
61
+ if weights:
62
+ assert len(weights) == len(self.models), "Number of weights must match number of models"
63
+ self.weights = weights
64
+ else:
65
+ self.weights = [1.0 / len(self.models)] * len(self.models)
66
+
67
+ print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
68
+ print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
69
+
70
+ def forward(self, mfcc, image):
71
+ """
72
+ Forward pass through the MoE model.
73
+ Returns the weighted average of all model outputs.
74
+ """
75
+ outputs = []
76
+
77
+ # Get outputs from each model
78
+ with torch.no_grad():
79
+ for i, model in enumerate(self.models):
80
+ output = model(mfcc, image)
81
+ outputs.append(output * self.weights[i])
82
+
83
+ # Return weighted average
84
+ return torch.sum(torch.stack(outputs), dim=0)
85
 
86
  # Modified version of process_audio_data specifically for the app to handle various tensor shapes
87
  def app_process_audio_data(waveform, sample_rate):
 
137
  print(traceback.format_exc())
138
  return None
139
 
140
+ # Using the decorator for GPU acceleration
141
  @spaces.GPU
142
+ def predict_sugar_content(audio, image, model_dir="models", weights=None):
143
+ """Function with GPU acceleration to predict watermelon sugar content in Brix using MoE model"""
144
  try:
145
+ # Check CUDA availability inside the GPU-decorated function
146
  if torch.cuda.is_available():
147
  device = torch.device("cuda")
148
  print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
 
150
  device = torch.device("cpu")
151
  print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
152
 
153
+ # Load MoE model
154
+ moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
155
+ moe_model.to(device)
156
+ moe_model.eval()
157
+ print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")
158
 
159
  # Debug information about input types
160
  print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
 
246
  processed_image = processed_image.unsqueeze(0).to(device)
247
  print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
248
 
249
+ # Run inference with MoE model
250
+ print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
251
  if mfcc is not None and processed_image is not None:
252
  with torch.no_grad():
253
+ brix_value = moe_model(mfcc, processed_image)
254
  print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
255
  else:
256
  return "Error: Failed to process inputs. Please check the debug logs."
 
262
  # Create a header with the numerical result
263
  result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
264
 
265
+ # Add extra info about the MoE model
266
+ result += "Using Ensemble of Top-3 Models:\n"
267
+ result += "- EfficientNet-B3 + Transformer\n"
268
+ result += "- EfficientNet-B0 + Transformer\n"
269
+ result += "- ResNet-50 + Transformer\n\n"
270
+
271
  # Add Brix scale visualization
272
  result += "Sugar Content Scale (in °Brix):\n"
273
  result += "──────────────────────────────────\n"
 
321
  error_msg += traceback.format_exc()
322
  print(f"\033[91mERR!\033[0m: {error_msg}")
323
  return error_msg
 
 
 
324
 
325
+ def create_app(model_dir="models", weights=None):
326
  """Create and launch the Gradio interface"""
327
  # Define the prediction function with model path
328
  def predict_fn(audio, image):
329
+ return predict_sugar_content(audio, image, model_dir, weights)
330
 
331
  # Create Gradio interface
332
+ with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface:
333
+ gr.Markdown("# 🍉 Watermelon Sugar Content Predictor (Ensemble Model)")
334
  gr.Markdown("""
335
  This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.
336
 
337
+ ## What's New
338
+ This version uses a Mixture of Experts (MoE) ensemble model that combines the three best-performing models:
339
+ - EfficientNet-B3 + Transformer
340
+ - EfficientNet-B0 + Transformer
341
+ - ResNet-50 + Transformer
342
+
343
+ The ensemble approach provides more accurate predictions than any single model!
344
+
345
  ## Instructions:
346
  1. Upload or record an audio of tapping the watermelon
347
  2. Upload or capture an image of the watermelon
 
355
  submit_btn = gr.Button("Predict Sugar Content", variant="primary")
356
 
357
  with gr.Column():
358
+ output = gr.Textbox(label="Prediction Results", lines=15)
359
 
360
  submit_btn.click(
361
  fn=predict_fn,
 
371
  ## About Brix Measurement
372
  Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
373
  The average ripe watermelon has a Brix value between 9-11°.
374
+
375
+ ## About the Mixture of Experts Model
376
+ This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks.
377
+ Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, which is significantly
378
+ better than any individual model (best individual model: ~0.36 MAE).
379
  """)
380
 
381
  return interface
 
383
  if __name__ == "__main__":
384
  import argparse
385
 
386
+ parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)")
387
  parser.add_argument(
388
+ "--model_dir",
389
  type=str,
390
+ default="models",
391
+ help="Directory containing the model checkpoints"
392
  )
393
  parser.add_argument(
394
  "--share",
 
400
  action="store_true",
401
  help="Enable verbose debug output"
402
  )
403
+ parser.add_argument(
404
+ "--weighting",
405
+ type=str,
406
+ choices=["uniform", "performance"],
407
+ default="uniform",
408
+ help="How to weight the models (uniform or based on performance)"
409
+ )
410
 
411
  args = parser.parse_args()
412
 
413
  if args.debug:
414
  print(f"\033[92mINFO\033[0m: Debug mode enabled")
415
 
416
+ # Check if model directory exists
417
+ if not os.path.exists(args.model_dir):
418
+ print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}")
 
419
  sys.exit(1)
420
 
421
+ # Determine weights based on argument
422
+ weights = None
423
+ if args.weighting == "performance":
424
+ # Weights inversely proportional to the MAE (better models get higher weights)
425
+ # These are the MAE values from the evaluation results
426
+ mae_values = [0.3635, 0.3765, 0.3959] # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
427
+
428
+ # Convert to weights (inverse of MAE, normalized)
429
+ inverse_mae = [1/mae for mae in mae_values]
430
+ total = sum(inverse_mae)
431
+ weights = [val/total for val in inverse_mae]
432
+
433
+ print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
434
+ else:
435
+ print(f"\033[92mINFO\033[0m: Using uniform weights")
436
+
437
  # Create and launch the app
438
+ app = create_app(args.model_dir, weights)
439
  app.launch(share=args.share)
app_local_backup.py CHANGED
@@ -5,12 +5,22 @@ import numpy as np
5
  import gradio as gr
6
  import torchaudio
7
  import torchvision
8
 
9
  # Add parent directory to path to import preprocess functions
10
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
 
12
- # Import functions from infer_watermelon.py
13
- from infer_watermelon import load_model
14
 
15
  # Modified version of process_audio_data specifically for the app to handle various tensor shapes
16
  def app_process_audio_data(waveform, sample_rate):
@@ -69,14 +79,25 @@ def app_process_audio_data(waveform, sample_rate):
69
  # Similarly for images, but let's import the original one
70
  from preprocess import process_image_data
71
 
72
- def init_model(model_path):
73
- """Initialize the model for inference"""
74
- model, device = load_model(model_path)
75
- return model, device
76
-
77
- def predict_sweetness(audio, image, model, device):
78
- """Predict sweetness of a watermelon from audio and image input"""
79
  try:
80
  # Debug information about input types
81
  print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
82
  print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
@@ -97,7 +118,6 @@ def predict_sweetness(audio, image, model, device):
97
  print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
98
  elif isinstance(audio, str):
99
  # Direct path to audio file
100
- import torchaudio
101
  audio_data, sample_rate = torchaudio.load(audio)
102
  print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
103
  else:
@@ -111,9 +131,6 @@ def predict_sweetness(audio, image, model, device):
111
  temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
112
 
113
  # Import necessary libraries
114
- import torchaudio
115
- import torchvision
116
- import torchvision.transforms.functional as F
117
  from PIL import Image
118
 
119
  # Audio handling - direct processing from the data in memory
@@ -162,7 +179,7 @@ def predict_sweetness(audio, image, model, device):
162
  processed_image = process_image_data(image_tensor)
163
  print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
164
 
165
- # Add batch dimension for inference
166
  if mfcc is not None:
167
  mfcc = mfcc.unsqueeze(0).to(device)
168
  print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
@@ -172,31 +189,67 @@ def predict_sweetness(audio, image, model, device):
172
  print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
173
 
174
  # Run inference
175
- print(f"\033[92mDEBUG\033[0m: Running inference")
176
  if mfcc is not None and processed_image is not None:
177
  with torch.no_grad():
178
- sweetness = model(mfcc, processed_image)
179
- print(f"\033[92mDEBUG\033[0m: Prediction successful: {sweetness.item()}")
180
  else:
181
  return "Error: Failed to process inputs. Please check the debug logs."
182
 
183
- # Format the result
184
- if sweetness is not None:
185
- result = f"Predicted Sweetness: {sweetness.item():.2f}/13"
186
 
187
- # Add a qualitative description
188
- if sweetness.item() < 9:
189
- result += "\n\nThis watermelon is not very sweet. You might want to choose another one."
190
- elif sweetness.item() < 10:
191
- result += "\n\nThis watermelon has moderate sweetness."
192
- elif sweetness.item() < 11:
193
- result += "\n\nThis watermelon is sweet! A good choice."
194
  else:
195
- result += "\n\nThis watermelon is very sweet! Excellent choice!"
196
 
197
  return result
198
  else:
199
- return "Error: Could not predict sweetness. Please try again with different inputs."
200
 
201
  except Exception as e:
202
  import traceback
@@ -204,36 +257,36 @@ def predict_sweetness(audio, image, model, device):
204
  error_msg += traceback.format_exc()
205
  print(f"\033[91mERR!\033[0m: {error_msg}")
206
  return error_msg
 
 
 
207
 
208
  def create_app(model_path):
209
  """Create and launch the Gradio interface"""
210
- # Initialize model
211
- model, device = init_model(model_path)
212
-
213
- # Define the prediction function with model and device
214
  def predict_fn(audio, image):
215
- return predict_sweetness(audio, image, model, device)
216
 
217
  # Create Gradio interface
218
- with gr.Blocks(title="Watermelon Sweetness Predictor") as interface:
219
- gr.Markdown("# 🍉 Watermelon Sweetness Predictor")
220
  gr.Markdown("""
221
- This app predicts the sweetness of a watermelon based on its sound and appearance.
222
 
223
  ## Instructions:
224
  1. Upload or record an audio of tapping the watermelon
225
  2. Upload or capture an image of the watermelon
226
- 3. Click 'Submit' to get the predicted sweetness
227
  """)
228
 
229
  with gr.Row():
230
  with gr.Column():
231
  audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
232
  image_input = gr.Image(label="Upload or Capture Image")
233
- submit_btn = gr.Button("Predict Sweetness", variant="primary")
234
 
235
  with gr.Column():
236
- output = gr.Textbox(label="Prediction Results", lines=6)
237
 
238
  submit_btn.click(
239
  fn=predict_fn,
@@ -242,13 +295,13 @@ def create_app(model_path):
242
  )
243
 
244
  gr.Markdown("""
245
- ## How it works
246
-
247
- The app uses a deep learning model that combines:
248
- - Audio analysis using MFCC features and LSTM neural network
249
- - Image analysis using ResNet-50 convolutional neural network
250
 
251
- The model was trained on a dataset of watermelons with known sweetness values.
 
 
252
  """)
253
 
254
  return interface
@@ -256,7 +309,7 @@ def create_app(model_path):
256
  if __name__ == "__main__":
257
  import argparse
258
 
259
- parser = argparse.ArgumentParser(description="Watermelon Sweetness Prediction App")
260
  parser.add_argument(
261
  "--model_path",
262
  type=str,
 
5
  import gradio as gr
6
  import torchaudio
7
  import torchvision
8
+ import spaces
9
+
10
+ # # Import Gradio Spaces GPU decorator
11
+ # try:
12
+ # from gradio import spaces
13
+ # HAS_SPACES = True
14
+ # print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
15
+ # except ImportError:
16
+ # HAS_SPACES = False
17
+ # print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization")
18
 
19
  # Add parent directory to path to import preprocess functions
20
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
 
22
+ # Import functions from infer_watermelon.py and train_watermelon for the model
23
+ from train_watermelon import WatermelonModel
24
 
25
  # Modified version of process_audio_data specifically for the app to handle various tensor shapes
26
  def app_process_audio_data(waveform, sample_rate):
 
79
  # Similarly for images, but let's import the original one
80
  from preprocess import process_image_data
81
 
82
+ # Using the decorator directly on the function definition
83
+ @spaces.GPU
84
+ def predict_sugar_content(audio, image, model_path):
85
+ """Function with GPU acceleration to predict watermelon sugar content in Brix"""
 
 
 
86
  try:
87
+ # Now check CUDA availability inside the GPU-decorated function
88
+ if torch.cuda.is_available():
89
+ device = torch.device("cuda")
90
+ print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
91
+ else:
92
+ device = torch.device("cpu")
93
+ print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
94
+
95
+ # Load model inside the function to ensure it's on the correct device
96
+ model = WatermelonModel().to(device)
97
+ model.load_state_dict(torch.load(model_path, map_location=device))
98
+ model.eval()
99
+ print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
100
+
101
  # Debug information about input types
102
  print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
103
  print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
 
118
  print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
119
  elif isinstance(audio, str):
120
  # Direct path to audio file
 
121
  audio_data, sample_rate = torchaudio.load(audio)
122
  print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
123
  else:
 
131
  temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
132
 
133
  # Import necessary libraries
 
 
 
134
  from PIL import Image
135
 
136
  # Audio handling - direct processing from the data in memory
 
179
  processed_image = process_image_data(image_tensor)
180
  print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
181
 
182
+ # Add batch dimension for inference and move to device
183
  if mfcc is not None:
184
  mfcc = mfcc.unsqueeze(0).to(device)
185
  print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
 
189
  print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
190
 
191
  # Run inference
192
+ print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
193
  if mfcc is not None and processed_image is not None:
194
  with torch.no_grad():
195
+ brix_value = model(mfcc, processed_image)
196
+ print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
197
  else:
198
  return "Error: Failed to process inputs. Please check the debug logs."
199
 
200
+ # Format the result with a range display
201
+ if brix_value is not None:
202
+ brix_score = brix_value.item()
203
+
204
+ # Create a header with the numerical result
205
+ result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
206
+
207
+ # Add Brix scale visualization
208
+ result += "Sugar Content Scale (in °Brix):\n"
209
+ result += "──────────────────────────────────\n"
210
+
211
+ # Create the scale display with Brix ranges
212
+ scale_ranges = [
213
+ (0, 8, "Low Sugar (< 8° Brix)"),
214
+ (8, 9, "Mild Sweetness (8-9° Brix)"),
215
+ (9, 10, "Medium Sweetness (9-10° Brix)"),
216
+ (10, 11, "Sweet (10-11° Brix)"),
217
+ (11, 13, "Very Sweet (11-13° Brix)")
218
+ ]
219
 
220
+ # Find which category the prediction falls into
221
+ user_category = None
222
+ for min_val, max_val, category_name in scale_ranges:
223
+ if min_val <= brix_score < max_val:
224
+ user_category = category_name
225
+ break
226
+ if brix_score >= scale_ranges[-1][0]: # Handle edge case
227
+ user_category = scale_ranges[-1][2]
228
+
229
+ # Display the scale with the user's result highlighted
230
+ for min_val, max_val, category_name in scale_ranges:
231
+ if category_name == user_category:
232
+ result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
233
+ else:
234
+ result += f" {min_val}-{max_val}: {category_name}\n"
235
+
236
+ result += "──────────────────────────────────\n\n"
237
+
238
+ # Add assessment of the watermelon's sugar content
239
+ if brix_score < 8:
240
+ result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
241
+ elif brix_score < 9:
242
+ result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
243
+ elif brix_score < 10:
244
+ result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
245
+ elif brix_score < 11:
246
+ result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
247
  else:
248
+ result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."
249
 
250
  return result
251
  else:
252
+ return "Error: Could not predict sugar content. Please try again with different inputs."
253
 
254
  except Exception as e:
255
  import traceback
 
257
  error_msg += traceback.format_exc()
258
  print(f"\033[91mERR!\033[0m: {error_msg}")
259
  return error_msg
260
+
261
+ print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
262
+
263
 
264
  def create_app(model_path):
265
  """Create and launch the Gradio interface"""
266
+ # Define the prediction function with model path
 
 
 
267
  def predict_fn(audio, image):
268
+ return predict_sugar_content(audio, image, model_path)
269
 
270
  # Create Gradio interface
271
+ with gr.Blocks(title="Watermelon Sugar Content Predictor", theme=gr.themes.Soft()) as interface:
272
+ gr.Markdown("# 🍉 Watermelon Sugar Content Predictor")
273
  gr.Markdown("""
274
+ This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.
275
 
276
  ## Instructions:
277
  1. Upload or record an audio of tapping the watermelon
278
  2. Upload or capture an image of the watermelon
279
+ 3. Click 'Predict' to get the sugar content estimation
280
  """)
281
 
282
  with gr.Row():
283
  with gr.Column():
284
  audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
285
  image_input = gr.Image(label="Upload or Capture Image")
286
+ submit_btn = gr.Button("Predict Sugar Content", variant="primary")
287
 
288
  with gr.Column():
289
+ output = gr.Textbox(label="Prediction Results", lines=12)
290
 
291
  submit_btn.click(
292
  fn=predict_fn,
 
295
  )
296
 
297
  gr.Markdown("""
298
+ ## Tips for best results
299
+ - For audio: Tap the watermelon with your knuckle and record the sound
300
+ - For image: Take a clear photo of the whole watermelon in good lighting
 
 
301
 
302
+ ## About Brix Measurement
303
+ Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
304
+ The average ripe watermelon has a Brix value between 9-11°.
305
  """)
306
 
307
  return interface
 
309
  if __name__ == "__main__":
310
  import argparse
311
 
312
+ parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App")
313
  parser.add_argument(
314
  "--model_path",
315
  type=str,
app_moe.py ADDED
@@ -0,0 +1,439 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import torchaudio
7
+ import torchvision
8
+ import spaces
9
+ import json
10
+
11
+ # Add parent directory to path to import preprocess functions
12
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
+
14
+ # Import functions from preprocess and model definitions
15
+ from preprocess import process_image_data
16
+ from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
17
+
18
+ # Define the top-performing models based on evaluation
19
+ TOP_MODELS = [
20
+ {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
21
+ {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
22
+ {"image_backbone": "resnet50", "audio_backbone": "transformer"}
23
+ ]
24
+
25
+ # Define the MoE Model
26
+ class WatermelonMoEModel(torch.nn.Module):
27
+ def __init__(self, model_configs, model_dir="models", weights=None):
28
+ """
29
+ Mixture of Experts model that combines multiple backbone models.
30
+
31
+ Args:
32
+ model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
33
+ model_dir: Directory where model checkpoints are stored
34
+ weights: Optional list of weights for each model (None for equal weighting)
35
+ """
36
+ super(WatermelonMoEModel, self).__init__()
37
+ self.models = []
38
+ self.model_configs = model_configs
39
+
40
+ # Load each model
41
+ for config in model_configs:
42
+ img_backbone = config["image_backbone"]
43
+ audio_backbone = config["audio_backbone"]
44
+
45
+ # Initialize model
46
+ model = WatermelonModelModular(img_backbone, audio_backbone)
47
+
48
+ # Load weights
49
+ model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
50
+ if os.path.exists(model_path):
51
+ print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
52
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
53
+ else:
54
+ print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
55
+ continue
56
+
57
+ model.eval() # Set to evaluation mode
58
+ self.models.append(model)
59
+
60
+ # Set model weights (uniform by default)
61
+ if weights:
62
+ assert len(weights) == len(self.models), "Number of weights must match number of models"
63
+ self.weights = weights
64
+ else:
65
+ self.weights = [1.0 / len(self.models)] * len(self.models)
66
+
67
+ print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
68
+ print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
69
+
70
+ def forward(self, mfcc, image):
71
+ """
72
+ Forward pass through the MoE model.
73
+ Returns the weighted average of all model outputs.
74
+ """
75
+ outputs = []
76
+
77
+ # Get outputs from each model
78
+ with torch.no_grad():
79
+ for i, model in enumerate(self.models):
80
+ output = model(mfcc, image)
81
+ outputs.append(output * self.weights[i])
82
+
83
+ # Return weighted average
84
+ return torch.sum(torch.stack(outputs), dim=0)
85
+
86
+ # Modified version of process_audio_data specifically for the app to handle various tensor shapes
87
+ def app_process_audio_data(waveform, sample_rate):
88
+ """Modified version of process_audio_data for the app that handles different tensor dimensions"""
89
+ try:
90
+ print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")
91
+
92
+ # Handle different tensor dimensions
93
+ if waveform.dim() == 3:
94
+ print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
95
+ # For 3D tensor, take the first item (batch dimension)
96
+ waveform = waveform[0]
97
+
98
+ if waveform.dim() == 2:
99
+ # Use the first channel for stereo audio
100
+ waveform = waveform[0]
101
+ print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")
102
+
103
+ # Resample to 16kHz if needed
104
+ resample_rate = 16000
105
+ if sample_rate != resample_rate:
106
+ print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
107
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)
108
+
109
+ # Ensure 3 seconds of audio
110
+ if waveform.size(0) < 3 * resample_rate:
111
+ print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
112
+ waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
113
+ else:
114
+ print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
115
+ waveform = waveform[: 3 * resample_rate]
116
+
117
+ # Apply MFCC transformation
118
+ print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
119
+ mfcc_transform = torchaudio.transforms.MFCC(
120
+ sample_rate=resample_rate,
121
+ n_mfcc=13,
122
+ melkwargs={
123
+ "n_fft": 256,
124
+ "win_length": 256,
125
+ "hop_length": 128,
126
+ "n_mels": 40,
127
+ }
128
+ )
129
+
130
+ mfcc = mfcc_transform(waveform)
131
+ print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")
132
+
133
+ return mfcc
134
+ except Exception as e:
135
+ import traceback
136
+ print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
137
+ print(traceback.format_exc())
138
+ return None
139
+
140
+ # Using the decorator for GPU acceleration
141
+ @spaces.GPU
142
+ def predict_sugar_content(audio, image, model_dir="models", weights=None):
143
+ """Function with GPU acceleration to predict watermelon sugar content in Brix using MoE model"""
144
+ try:
145
+ # Check CUDA availability inside the GPU-decorated function
146
+ if torch.cuda.is_available():
147
+ device = torch.device("cuda")
148
+ print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
149
+ else:
150
+ device = torch.device("cpu")
151
+ print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
152
+
153
+ # Load MoE model
154
+ moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
155
+ moe_model.to(device)
156
+ moe_model.eval()
157
+ print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")
158
+
159
+ # Debug information about input types
160
+ print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
161
+ print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
162
+ print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
163
+ if isinstance(image, np.ndarray):
164
+ print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")
165
+
166
+ # Handle different audio input formats
167
+ if isinstance(audio, tuple) and len(audio) == 2:
168
+ # Standard Gradio format: (sample_rate, audio_data)
169
+ sample_rate, audio_data = audio
170
+ print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
171
+ print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
172
+ elif isinstance(audio, tuple) and len(audio) > 2:
173
+ # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
174
+ sample_rate, audio_data = audio[0], audio[-1]
175
+ print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
176
+ print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
177
+ elif isinstance(audio, str):
178
+ # Direct path to audio file
179
+ audio_data, sample_rate = torchaudio.load(audio)
180
+ print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
181
+ else:
182
+ return f"Error: Unsupported audio format. Got {type(audio)}"
183
+
184
+ # Create a temporary file path for the audio and image
185
+ temp_dir = "temp"
186
+ os.makedirs(temp_dir, exist_ok=True)
187
+
188
+ temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
189
+ temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
190
+
191
+ # Import necessary libraries
192
+ from PIL import Image
193
+
194
+ # Audio handling - direct processing from the data in memory
195
+ if isinstance(audio_data, np.ndarray):
196
+ # Convert numpy array to tensor
197
+ print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
198
+ audio_tensor = torch.tensor(audio_data).float()
199
+
200
+ # Handle different audio dimensions
201
+ if audio_data.ndim == 1:
202
+ # Single channel audio
203
+ audio_tensor = audio_tensor.unsqueeze(0)
204
+ elif audio_data.ndim == 2:
205
+ # Ensure channels are first dimension
206
+ if audio_data.shape[0] > audio_data.shape[1]:
207
+ # More rows than columns, probably (samples, channels)
208
+ audio_tensor = torch.tensor(audio_data.T).float()
209
+ else:
210
+ # Already a tensor
211
+ audio_tensor = audio_data.float()
212
+
213
+ print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")
214
+
215
+ # Skip saving/loading and process directly
216
+ mfcc = app_process_audio_data(audio_tensor, sample_rate)
217
+ print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")
218
+
219
+ # Image handling
220
+ if isinstance(image, np.ndarray):
221
+ print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
222
+ pil_image = Image.fromarray(image)
223
+ pil_image.save(temp_image_path)
224
+ print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
225
+ elif isinstance(image, str):
226
+ # If image is already a path
227
+ temp_image_path = image
228
+ print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
229
+ else:
230
+ return f"Error: Unsupported image format. Got {type(image)}"
231
+
232
+ # Process image
233
+ print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
234
+ image_tensor = torchvision.io.read_image(temp_image_path)
235
+ print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
236
+ image_tensor = image_tensor.float()
237
+ processed_image = process_image_data(image_tensor)
238
+ print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
239
+
240
+ # Add batch dimension for inference and move to device
241
+ if mfcc is not None:
242
+ mfcc = mfcc.unsqueeze(0).to(device)
243
+ print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")
244
+
245
+ if processed_image is not None:
246
+ processed_image = processed_image.unsqueeze(0).to(device)
247
+ print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
248
+
249
+ # Run inference with MoE model
250
+ print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
251
+ if mfcc is not None and processed_image is not None:
252
+ with torch.no_grad():
253
+ brix_value = moe_model(mfcc, processed_image)
254
+ print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
255
+ else:
256
+ return "Error: Failed to process inputs. Please check the debug logs."
257
+
258
+ # Format the result with a range display
259
+ if brix_value is not None:
260
+ brix_score = brix_value.item()
261
+
262
+ # Create a header with the numerical result
263
+ result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
264
+
265
+ # Add extra info about the MoE model
266
+ result += "Using Ensemble of Top-3 Models:\n"
267
+ result += "- EfficientNet-B3 + Transformer\n"
268
+ result += "- EfficientNet-B0 + Transformer\n"
269
+ result += "- ResNet-50 + Transformer\n\n"
270
+
271
+ # Add Brix scale visualization
272
+ result += "Sugar Content Scale (in °Brix):\n"
273
+ result += "──────────────────────────────────\n"
274
+
275
+ # Create the scale display with Brix ranges
276
+ scale_ranges = [
277
+ (0, 8, "Low Sugar (< 8° Brix)"),
278
+ (8, 9, "Mild Sweetness (8-9° Brix)"),
279
+ (9, 10, "Medium Sweetness (9-10° Brix)"),
280
+ (10, 11, "Sweet (10-11° Brix)"),
281
+ (11, 13, "Very Sweet (11-13° Brix)")
282
+ ]
283
+
284
+ # Find which category the prediction falls into
285
+ user_category = None
286
+ for min_val, max_val, category_name in scale_ranges:
287
+ if min_val <= brix_score < max_val:
288
+ user_category = category_name
289
+ break
290
+ if brix_score >= scale_ranges[-1][0]: # Handle edge case
291
+ user_category = scale_ranges[-1][2]
292
+
293
+ # Display the scale with the user's result highlighted
294
+ for min_val, max_val, category_name in scale_ranges:
295
+ if category_name == user_category:
296
+ result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
297
+ else:
298
+ result += f" {min_val}-{max_val}: {category_name}\n"
299
+
300
+ result += "──────────────────────────────────\n\n"
301
+
302
+ # Add assessment of the watermelon's sugar content
303
+ if brix_score < 8:
304
+ result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
305
+ elif brix_score < 9:
306
+ result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
307
+ elif brix_score < 10:
308
+ result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
309
+ elif brix_score < 11:
310
+ result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
311
+ else:
312
+ result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."
313
+
314
+ return result
315
+ else:
316
+ return "Error: Could not predict sugar content. Please try again with different inputs."
317
+
318
+ except Exception as e:
319
+ import traceback
320
+ error_msg = f"Error: {str(e)}\n\n"
321
+ error_msg += traceback.format_exc()
322
+ print(f"\033[91mERR!\033[0m: {error_msg}")
323
+ return error_msg
324
+
325
+ def create_app(model_dir="models", weights=None):
326
+ """Create and launch the Gradio interface"""
327
+ # Define the prediction function with model path
328
+ def predict_fn(audio, image):
329
+ return predict_sugar_content(audio, image, model_dir, weights)
330
+
331
+ # Create Gradio interface
332
+ with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface:
333
+ gr.Markdown("# 🍉 Watermelon Sugar Content Predictor (Ensemble Model)")
334
+ gr.Markdown("""
335
+ This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.
336
+
337
+ ## What's New
338
+ This version uses a Mixture of Experts (MoE) ensemble model that combines the three best-performing models:
339
+ - EfficientNet-B3 + Transformer
340
+ - EfficientNet-B0 + Transformer
341
+ - ResNet-50 + Transformer
342
+
343
+ The ensemble approach provides more accurate predictions than any single model!
344
+
345
+ ## Instructions:
346
+ 1. Upload or record an audio of tapping the watermelon
347
+ 2. Upload or capture an image of the watermelon
348
+ 3. Click 'Predict' to get the sugar content estimation
349
+ """)
350
+
351
+ with gr.Row():
352
+ with gr.Column():
353
+ audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
354
+ image_input = gr.Image(label="Upload or Capture Image")
355
+ submit_btn = gr.Button("Predict Sugar Content", variant="primary")
356
+
357
+ with gr.Column():
358
+ output = gr.Textbox(label="Prediction Results", lines=15)
359
+
360
+ submit_btn.click(
361
+ fn=predict_fn,
362
+ inputs=[audio_input, image_input],
363
+ outputs=output
364
+ )
365
+
366
+ gr.Markdown("""
367
+ ## Tips for best results
368
+ - For audio: Tap the watermelon with your knuckle and record the sound
369
+ - For image: Take a clear photo of the whole watermelon in good lighting
370
+
371
+ ## About Brix Measurement
372
+ Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
373
+ The average ripe watermelon has a Brix value between 9-11°.
374
+
375
+ ## About the Mixture of Experts Model
376
+ This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks.
377
+ Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, which is significantly
378
+ better than any individual model (best individual model: ~0.36 MAE).
379
+ """)
380
+
381
+ return interface
382
+
383
+ if __name__ == "__main__":
384
+ import argparse
385
+
386
+ parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)")
387
+ parser.add_argument(
388
+ "--model_dir",
389
+ type=str,
390
+ default="models",
391
+ help="Directory containing the model checkpoints"
392
+ )
393
+ parser.add_argument(
394
+ "--share",
395
+ action="store_true",
396
+ help="Create a shareable link for the app"
397
+ )
398
+ parser.add_argument(
399
+ "--debug",
400
+ action="store_true",
401
+ help="Enable verbose debug output"
402
+ )
403
+ parser.add_argument(
404
+ "--weighting",
405
+ type=str,
406
+ choices=["uniform", "performance"],
407
+ default="uniform",
408
+ help="How to weight the models (uniform or based on performance)"
409
+ )
410
+
411
+ args = parser.parse_args()
412
+
413
+ if args.debug:
414
+ print(f"\033[92mINFO\033[0m: Debug mode enabled")
415
+
416
+ # Check if model directory exists
417
+ if not os.path.exists(args.model_dir):
418
+ print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}")
419
+ sys.exit(1)
420
+
421
+ # Determine weights based on argument
422
+ weights = None
423
+ if args.weighting == "performance":
424
+ # Weights inversely proportional to the MAE (better models get higher weights)
425
+ # These are the MAE values from the evaluation results
426
+ mae_values = [0.3635, 0.3765, 0.3959] # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
427
+
428
+ # Convert to weights (inverse of MAE, normalized)
429
+ inverse_mae = [1/mae for mae in mae_values]
430
+ total = sum(inverse_mae)
431
+ weights = [val/total for val in inverse_mae]
432
+
433
+ print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
434
+ else:
435
+ print(f"\033[92mINFO\033[0m: Using uniform weights")
436
+
437
+ # Create and launch the app
438
+ app = create_app(args.model_dir, weights)
439
+ app.launch(share=args.share)
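For reference, with the MAE values hard-coded above (0.3635, 0.3765, 0.3959), the performance-based weighting works out to inverse MAEs of roughly 2.751, 2.656, and 2.526; normalized by their sum (≈7.933) this gives expert weights of about 0.347, 0.335, and 0.318 for EfficientNet-B3, EfficientNet-B0, and ResNet-50 respectively, i.e. only a mild departure from uniform weighting.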
backbone_evaluation_results.json ADDED
@@ -0,0 +1,110 @@
1
+ [
2
+ {
3
+ "image_backbone": "efficientnet_b3",
4
+ "audio_backbone": "transformer",
5
+ "validation_mse": 0.21577325425086877,
6
+ "validation_mae": 0.36228722945237773,
7
+ "test_mse": 0.21746371760964395,
8
+ "test_mae": 0.36353210285305976,
9
+ "model_path": "test_models/efficientnet_b3_transformer_model.pt"
10
+ },
11
+ {
12
+ "image_backbone": "efficientnet_b0",
13
+ "audio_backbone": "transformer",
14
+ "validation_mse": 0.24033201676912797,
15
+ "validation_mae": 0.42209602166444826,
16
+ "test_mse": 0.19470563121140003,
17
+ "test_mae": 0.37649240642786025,
18
+ "model_path": "test_models/efficientnet_b0_transformer_model.pt"
19
+ },
20
+ {
21
+ "image_backbone": "resnet50",
22
+ "audio_backbone": "transformer",
23
+ "validation_mse": 0.22672857019381645,
24
+ "validation_mae": 0.3926378931754675,
25
+ "test_mse": 0.22427306957542897,
26
+ "test_mae": 0.39585837423801423,
27
+ "model_path": "test_models/resnet50_transformer_model.pt"
28
+ },
29
+ {
30
+ "image_backbone": "resnet50",
31
+ "audio_backbone": "bidirectional_lstm",
32
+ "validation_mse": 0.2967155438203078,
33
+ "validation_mae": 0.3850937023376807,
34
+ "test_mse": 0.36476454623043536,
35
+ "test_mae": 0.425818096101284,
36
+ "model_path": "test_models/resnet50_bidirectional_lstm_model.pt"
37
+ },
38
+ {
39
+ "image_backbone": "efficientnet_b0",
40
+ "audio_backbone": "bidirectional_lstm",
41
+ "validation_mse": 0.5120524473679371,
42
+ "validation_mae": 0.5665570046657171,
43
+ "test_mse": 0.5059382550418376,
44
+ "test_mae": 0.555050653219223,
45
+ "model_path": "test_models/efficientnet_b0_bidirectional_lstm_model.pt"
46
+ },
47
+ {
48
+ "image_backbone": "efficientnet_b3",
49
+ "audio_backbone": "bidirectional_lstm",
50
+ "validation_mse": 0.8020018790012751,
51
+ "validation_mae": 0.7953977386156718,
52
+ "test_mse": 0.7042828559875488,
53
+ "test_mae": 0.7441241115331649,
54
+ "model_path": "test_models/efficientnet_b3_bidirectional_lstm_model.pt"
55
+ },
56
+ {
57
+ "image_backbone": "efficientnet_b0",
58
+ "audio_backbone": "gru",
59
+ "validation_mse": 1.1340507984161377,
60
+ "validation_mae": 0.8290961503982544,
61
+ "test_mse": 0.9705999374389649,
62
+ "test_mae": 0.7704607486724854,
63
+ "model_path": "test_models/efficientnet_b0_gru_model.pt"
64
+ },
65
+ {
66
+ "image_backbone": "efficientnet_b0",
67
+ "audio_backbone": "lstm",
68
+ "validation_mse": 2.787272185087204,
69
+ "validation_mae": 1.5404645502567291,
70
+ "test_mse": 2.901867628097534,
71
+ "test_mae": 1.5843785762786866,
72
+ "model_path": "test_models/efficientnet_b0_lstm_model.pt"
73
+ },
74
+ {
75
+ "image_backbone": "resnet50",
76
+ "audio_backbone": "gru",
77
+ "validation_mse": 3.9335442543029786,
78
+ "validation_mae": 1.8762320041656495,
79
+ "test_mse": 3.72695152759552,
80
+ "test_mae": 1.8381730556488036,
81
+ "model_path": "test_models/resnet50_gru_model.pt"
82
+ },
83
+ {
84
+ "image_backbone": "resnet50",
85
+ "audio_backbone": "lstm",
86
+ "validation_mse": 6.088638782501221,
87
+ "validation_mae": 2.3887929677963258,
88
+ "test_mse": 6.1847597599029545,
89
+ "test_mae": 2.418113374710083,
90
+ "model_path": "test_models/resnet50_lstm_model.pt"
91
+ },
92
+ {
93
+ "image_backbone": "efficientnet_b3",
94
+ "audio_backbone": "gru",
95
+ "validation_mse": 104.58460273742676,
96
+ "validation_mae": 10.183499813079834,
97
+ "test_mse": 104.58482055664062,
98
+ "test_mae": 10.180697345733643,
99
+ "model_path": "test_models/efficientnet_b3_gru_model.pt"
100
+ },
101
+ {
102
+ "image_backbone": "efficientnet_b3",
103
+ "audio_backbone": "lstm",
104
+ "validation_mse": 105.40057525634765,
105
+ "validation_mae": 10.221695899963379,
106
+ "test_mse": 105.17274551391601,
107
+ "test_mae": 10.21053056716919,
108
+ "model_path": "test_models/efficientnet_b3_lstm_model.pt"
109
+ }
110
+ ]
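The three lowest test-MAE entries in this file are exactly the configurations hard-coded in TOP_MODELS (and the mae_values list) in app_moe.py. A small illustrative snippet, not part of the commit, that recovers that ranking:

    import json

    with open("backbone_evaluation_results.json") as f:
        results = json.load(f)

    # Sort all evaluated backbone pairs by test MAE and keep the best three
    top3 = sorted(results, key=lambda r: r["test_mae"])[:3]
    for r in top3:
        print(r["image_backbone"], r["audio_backbone"], round(r["test_mae"], 4))
    # -> efficientnet_b3 transformer 0.3635
    #    efficientnet_b0 transformer 0.3765
    #    resnet50 transformer 0.3959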
evaluate_backbones.py ADDED
@@ -0,0 +1,670 @@
+ import os
+ import torch
+ import torchaudio
+ import torchvision
+ import numpy as np
+ import time
+ import json
+ from torch.utils.data import Dataset, DataLoader
+ import sys
+ from tqdm import tqdm
+
+ # Add parent directory to path to import the preprocess functions
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from preprocess import process_audio_data, process_image_data
+
+ # Print library versions
+ print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
+ print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
+ print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")
+
+ # Device selection
+ device = torch.device(
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps" if torch.backends.mps.is_available() else "cpu"
+ )
+ print(f"\033[92mINFO\033[0m: Using device: {device}")
+
+ # Hyperparameters
+ batch_size = 16
+ epochs = 1  # Just one epoch for evaluation
+ learning_rate = 0.0001
+
+
+ class WatermelonDataset(Dataset):
+     def __init__(self, data_dir):
+         self.data_dir = data_dir
+         self.samples = []
+
+         # Walk through the directory structure
+         for sweetness_dir in os.listdir(data_dir):
+             sweetness = float(sweetness_dir)
+             sweetness_path = os.path.join(data_dir, sweetness_dir)
+
+             if os.path.isdir(sweetness_path):
+                 for id_dir in os.listdir(sweetness_path):
+                     id_path = os.path.join(sweetness_path, id_dir)
+
+                     if os.path.isdir(id_path):
+                         audio_file = os.path.join(id_path, f"{id_dir}.wav")
+                         image_file = os.path.join(id_path, f"{id_dir}.jpg")
+
+                         if os.path.exists(audio_file) and os.path.exists(image_file):
+                             self.samples.append((audio_file, image_file, sweetness))
+
+         print(f"\033[92mINFO\033[0m: Loaded {len(self.samples)} samples from {data_dir}")
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         audio_path, image_path, label = self.samples[idx]
+
+         # Load and process audio
+         try:
+             waveform, sample_rate = torchaudio.load(audio_path)
+             mfcc = process_audio_data(waveform, sample_rate)
+
+             # Load and process image
+             image = torchvision.io.read_image(image_path)
+             image = image.float()
+             processed_image = process_image_data(image)
+
+             return mfcc, processed_image, torch.tensor(label).float()
+         except Exception as e:
+             print(f"\033[91mERR!\033[0m: Error processing sample {idx}: {e}")
+             # Return a fallback sample or skip this sample
+             # For simplicity, we'll return the first sample again
+             if idx == 0:  # Prevent infinite recursion
+                 raise e
+             return self.__getitem__(0)
+
+
+ # Define available backbone models
+ IMAGE_BACKBONES = {
+     "resnet50": {
+         "model": torchvision.models.resnet50,
+         "weights": torchvision.models.ResNet50_Weights.DEFAULT,
+         "output_dim": lambda model: model.fc.in_features
+     },
+     "efficientnet_b0": {
+         "model": torchvision.models.efficientnet_b0,
+         "weights": torchvision.models.EfficientNet_B0_Weights.DEFAULT,
+         "output_dim": lambda model: model.classifier[1].in_features
+     },
+     "efficientnet_b3": {
+         "model": torchvision.models.efficientnet_b3,
+         "weights": torchvision.models.EfficientNet_B3_Weights.DEFAULT,
+         "output_dim": lambda model: model.classifier[1].in_features
+     }
+ }
+
+ AUDIO_BACKBONES = {
+     "lstm": {
+         "model": lambda input_size, hidden_size: torch.nn.LSTM(
+             input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
+         ),
+         "output_dim": lambda hidden_size: hidden_size
+     },
+     "gru": {
+         "model": lambda input_size, hidden_size: torch.nn.GRU(
+             input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
+         ),
+         "output_dim": lambda hidden_size: hidden_size
+     },
+     "bidirectional_lstm": {
+         "model": lambda input_size, hidden_size: torch.nn.LSTM(
+             input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True, bidirectional=True
+         ),
+         "output_dim": lambda hidden_size: hidden_size * 2  # * 2 because bidirectional
+     },
+     "transformer": {
+         "model": lambda input_size, hidden_size: torch.nn.TransformerEncoder(
+             torch.nn.TransformerEncoderLayer(
+                 d_model=input_size, nhead=8, dim_feedforward=hidden_size, batch_first=True
+             ),
+             num_layers=2
+         ),
+         "output_dim": lambda hidden_size: 376  # Using input_size (mfcc dimensions)
+     }
+ }
+
+
+ class WatermelonModelModular(torch.nn.Module):
+     def __init__(self, image_backbone_name, audio_backbone_name, audio_hidden_size=128):
+         super(WatermelonModelModular, self).__init__()
+
+         # Audio backbone setup
+         self.audio_backbone_name = audio_backbone_name
+         self.audio_hidden_size = audio_hidden_size
+         self.audio_input_size = 376  # From MFCC dimensions
+
+         audio_config = AUDIO_BACKBONES[audio_backbone_name]
+         self.audio_backbone = audio_config["model"](self.audio_input_size, self.audio_hidden_size)
+         audio_output_dim = audio_config["output_dim"](self.audio_hidden_size)
+
+         self.audio_fc = torch.nn.Linear(audio_output_dim, 128)
+
+         # Image backbone setup
+         self.image_backbone_name = image_backbone_name
+         image_config = IMAGE_BACKBONES[image_backbone_name]
+
+         self.image_backbone = image_config["model"](weights=image_config["weights"])
+
+         # Replace final layer for all image backbones to get features
+         if image_backbone_name.startswith("resnet"):
+             self.image_output_dim = image_config["output_dim"](self.image_backbone)
+             self.image_backbone.fc = torch.nn.Identity()
+         elif image_backbone_name.startswith("efficientnet"):
+             self.image_output_dim = image_config["output_dim"](self.image_backbone)
+             self.image_backbone.classifier = torch.nn.Identity()
+         elif image_backbone_name.startswith("convnext"):
+             self.image_output_dim = image_config["output_dim"](self.image_backbone)
+             self.image_backbone.classifier = torch.nn.Identity()
+         elif image_backbone_name.startswith("swin"):
+             self.image_output_dim = image_config["output_dim"](self.image_backbone)
+             self.image_backbone.head = torch.nn.Identity()
+
+         self.image_fc = torch.nn.Linear(self.image_output_dim, 128)
+
+         # Fully connected layers for final prediction
+         self.fc1 = torch.nn.Linear(256, 64)
+         self.fc2 = torch.nn.Linear(64, 1)
+         self.relu = torch.nn.ReLU()
+
+     def forward(self, mfcc, image):
+         # Audio backbone processing
+         if self.audio_backbone_name == "lstm" or self.audio_backbone_name == "gru":
+             audio_output, _ = self.audio_backbone(mfcc)
+             audio_output = audio_output[:, -1, :]  # Use the output of the last time step
+         elif self.audio_backbone_name == "bidirectional_lstm":
+             audio_output, _ = self.audio_backbone(mfcc)
+             audio_output = audio_output[:, -1, :]  # Use the output of the last time step
+         elif self.audio_backbone_name == "transformer":
+             audio_output = self.audio_backbone(mfcc)
+             audio_output = audio_output.mean(dim=1)  # Average pooling over sequence length
+
+         audio_output = self.audio_fc(audio_output)
+
+         # Image backbone processing
+         image_output = self.image_backbone(image)
+         image_output = self.image_fc(image_output)
+
+         # Concatenate audio and image outputs
+         merged = torch.cat((audio_output, image_output), dim=1)
+
+         # Fully connected layers
+         output = self.relu(self.fc1(merged))
+         output = self.fc2(output)
+
+         return output
+
+
+ def evaluate_model(data_dir, image_backbone, audio_backbone, audio_hidden_size=128, save_model_dir=None):
+     # Adjust batch size based on model complexity to avoid OOM errors
+     adjusted_batch_size = batch_size
+
+     # Models that typically require more memory get smaller batch sizes
+     if image_backbone in ["swin_b", "convnext_base"] or audio_backbone in ["transformer", "bidirectional_lstm"]:
+         adjusted_batch_size = max(4, batch_size // 2)  # At least batch size of 4, but reduce by half if needed
+         print(f"\033[92mINFO\033[0m: Adjusted batch size to {adjusted_batch_size} for larger model")
+
+     # Create dataset
+     dataset = WatermelonDataset(data_dir)
+     n_samples = len(dataset)
+
+     # Split dataset
+     train_size = int(0.7 * n_samples)
+     val_size = int(0.2 * n_samples)
+     test_size = n_samples - train_size - val_size
+
+     train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
+         dataset, [train_size, val_size, test_size]
+     )
+
+     train_loader = DataLoader(train_dataset, batch_size=adjusted_batch_size, shuffle=True)
+     val_loader = DataLoader(val_dataset, batch_size=adjusted_batch_size, shuffle=False)
+     test_loader = DataLoader(test_dataset, batch_size=adjusted_batch_size, shuffle=False)
+
+     # Initialize model
+     model = WatermelonModelModular(image_backbone, audio_backbone, audio_hidden_size).to(device)
+
+     # Loss function and optimizer
+     criterion = torch.nn.MSELoss()
+     mae_criterion = torch.nn.L1Loss()  # For MAE evaluation
+     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+     print(f"\033[92mINFO\033[0m: Evaluating model with {image_backbone} (image) and {audio_backbone} (audio)")
+     print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
+     print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
+     print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
+     print(f"\033[92mINFO\033[0m: Batch size: {adjusted_batch_size}")
+
+     # Training loop
+     print(f"\033[92mINFO\033[0m: Training for evaluation...")
+     model.train()
+     running_loss = 0.0
+
+     # Wrap with tqdm for progress visualization
+     train_iterator = tqdm(train_loader, desc="Training")
+
+     for i, (mfcc, image, label) in enumerate(train_iterator):
+         try:
+             mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
+
+             optimizer.zero_grad()
+             output = model(mfcc, image)
+             label = label.view(-1, 1).float()
+             loss = criterion(output, label)
+             loss.backward()
+             optimizer.step()
+
+             running_loss += loss.item()
+             train_iterator.set_postfix({"Loss": f"{loss.item():.4f}"})
+
+             # Clear memory after each batch
+             if device.type == 'cuda':
+                 del mfcc, image, label, output, loss
+                 torch.cuda.empty_cache()
+
+         except Exception as e:
+             print(f"\033[91mERR!\033[0m: Error in training batch {i}: {e}")
+             # Clear memory in case of error
+             if device.type == 'cuda':
+                 torch.cuda.empty_cache()
+             continue
+
+     # Validation phase
+     print(f"\033[92mINFO\033[0m: Validating...")
+     model.eval()
+     val_loss = 0.0
+     val_mae = 0.0
+
+     val_iterator = tqdm(val_loader, desc="Validation")
+
+     with torch.no_grad():
+         for i, (mfcc, image, label) in enumerate(val_iterator):
+             try:
+                 mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
+                 output = model(mfcc, image)
+                 label = label.view(-1, 1).float()
+
+                 # Calculate MSE loss
+                 loss = criterion(output, label)
+                 val_loss += loss.item()
+
+                 # Calculate MAE
+                 mae = mae_criterion(output, label)
+                 val_mae += mae.item()
+
+                 val_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})
+
+                 # Clear memory after each batch
+                 if device.type == 'cuda':
+                     del mfcc, image, label, output, loss, mae
+                     torch.cuda.empty_cache()
+
+             except Exception as e:
+                 print(f"\033[91mERR!\033[0m: Error in validation batch {i}: {e}")
+                 # Clear memory in case of error
+                 if device.type == 'cuda':
+                     torch.cuda.empty_cache()
+                 continue
+
+     avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else float('inf')
+     avg_val_mae = val_mae / len(val_loader) if len(val_loader) > 0 else float('inf')
+
+     # Test phase
+     print(f"\033[92mINFO\033[0m: Testing...")
+     model.eval()
+     test_loss = 0.0
+     test_mae = 0.0
+
+     test_iterator = tqdm(test_loader, desc="Testing")
+
+     with torch.no_grad():
+         for i, (mfcc, image, label) in enumerate(test_iterator):
+             try:
+                 mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
+                 output = model(mfcc, image)
+                 label = label.view(-1, 1).float()
+
+                 # Calculate MSE loss
+                 loss = criterion(output, label)
+                 test_loss += loss.item()
+
+                 # Calculate MAE
+                 mae = mae_criterion(output, label)
+                 test_mae += mae.item()
+
+                 test_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})
+
+                 # Clear memory after each batch
+                 if device.type == 'cuda':
+                     del mfcc, image, label, output, loss, mae
+                     torch.cuda.empty_cache()
+
+             except Exception as e:
+                 print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
+                 # Clear memory in case of error
+                 if device.type == 'cuda':
+                     torch.cuda.empty_cache()
+                 continue
+
+     avg_test_loss = test_loss / len(test_loader) if len(test_loader) > 0 else float('inf')
+     avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
+
+     results = {
+         "image_backbone": image_backbone,
+         "audio_backbone": audio_backbone,
+         "validation_mse": avg_val_loss,
+         "validation_mae": avg_val_mae,
+         "test_mse": avg_test_loss,
+         "test_mae": avg_test_mae
+     }
+
+     print(f"\033[92mINFO\033[0m: Evaluation Results:")
+     print(f"Image Backbone: {image_backbone}")
+     print(f"Audio Backbone: {audio_backbone}")
+     print(f"Validation MSE: {avg_val_loss:.4f}")
+     print(f"Validation MAE: {avg_val_mae:.4f}")
+     print(f"Test MSE: {avg_test_loss:.4f}")
+     print(f"Test MAE: {avg_test_mae:.4f}")
+
+     # Save model if save_model_dir is provided
+     if save_model_dir:
+         os.makedirs(save_model_dir, exist_ok=True)
+         model_filename = f"{image_backbone}_{audio_backbone}_model.pt"
+         model_path = os.path.join(save_model_dir, model_filename)
+         torch.save(model.state_dict(), model_path)
+         print(f"\033[92mINFO\033[0m: Model saved to {model_path}")
+
+         # Add model path to results
+         results["model_path"] = model_path
+
+     # Clean up memory before returning
+     if device.type == 'cuda':
+         del model, optimizer, criterion, mae_criterion
+         torch.cuda.empty_cache()
+
+     return results
+
+
+ def evaluate_all_combinations(data_dir, image_backbones=None, audio_backbones=None, save_model_dir="test_models", results_file="backbone_evaluation_results.json"):
+     if image_backbones is None:
+         image_backbones = list(IMAGE_BACKBONES.keys())
+
+     if audio_backbones is None:
+         audio_backbones = list(AUDIO_BACKBONES.keys())
+
+     # Create directory for saving models
+     if save_model_dir:
+         os.makedirs(save_model_dir, exist_ok=True)
+
+     # Load previous results if the file exists
+     results = []
+     evaluated_combinations = set()
+
+     if os.path.exists(results_file):
+         try:
+             with open(results_file, 'r') as f:
+                 results = json.load(f)
+             evaluated_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in results}
+             print(f"\033[92mINFO\033[0m: Loaded {len(results)} previous results from {results_file}")
+         except Exception as e:
+             print(f"\033[91mERR!\033[0m: Error loading previous results from {results_file}: {e}")
+             results = []
+             evaluated_combinations = set()
+     else:
+         print(f"\033[93mWARN\033[0m: Results file '{results_file}' does not exist. Starting with empty results.")
+
+     # Create combinations to evaluate, skipping any that have already been evaluated
+     combinations = [(img, aud) for img in image_backbones for aud in audio_backbones
+                     if (img, aud) not in evaluated_combinations]
+
+     if len(combinations) < len(image_backbones) * len(audio_backbones):
+         print(f"\033[92mINFO\033[0m: Skipping {len(evaluated_combinations)} already evaluated combinations")
+
+     print(f"\033[92mINFO\033[0m: Will evaluate {len(combinations)} combinations")
+
+     for image_backbone, audio_backbone in combinations:
+         print(f"\033[92mINFO\033[0m: Evaluating {image_backbone} + {audio_backbone}")
+         try:
+             # Clean GPU memory before each model evaluation
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
+                 # Print memory usage for debugging
+                 print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+                 print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
+
+             result = evaluate_model(data_dir, image_backbone, audio_backbone, save_model_dir=save_model_dir)
+             results.append(result)
+
+             # Save results after each evaluation
+             save_results(results, results_file)
+             print(f"\033[92mINFO\033[0m: Updated results saved to {results_file}")
+
+             # Force garbage collection to free memory
+             import gc
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
+                 # Print memory usage for debugging
+                 print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+                 print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
+
+         except Exception as e:
+             print(f"\033[91mERR!\033[0m: Error evaluating {image_backbone} + {audio_backbone}: {e}")
+             print(f"\033[91mERR!\033[0m: To continue from this point, use --start_from={image_backbone}:{audio_backbone}")
+
+             # Force garbage collection to free memory even if there's an error
+             import gc
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")
+
+             continue
+
+     # Sort results by test MAE (ascending)
+     results.sort(key=lambda x: x["test_mae"])
+
+     # Save final sorted results
+     save_results(results, results_file)
+
+     print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
+     print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
+     print("="*60)
+
+     for result in results:
+         print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")
+
+     return results
+
+
+ def save_results(results, filename="backbone_evaluation_results.json"):
+     """Save evaluation results to a JSON file."""
+     with open(filename, 'w') as f:
+         json.dump(results, f, indent=4)
+     print(f"\033[92mINFO\033[0m: Results saved to {filename}")
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Evaluate Different Backbones for Watermelon Sweetness Prediction")
+     parser.add_argument(
+         "--data_dir",
+         type=str,
+         default="../cleaned",
+         help="Path to the cleaned dataset directory"
+     )
+     parser.add_argument(
+         "--image_backbone",
+         type=str,
+         default=None,
+         help="Specific image backbone to evaluate (leave empty to evaluate all available)"
+     )
+     parser.add_argument(
+         "--audio_backbone",
+         type=str,
+         default=None,
+         help="Specific audio backbone to evaluate (leave empty to evaluate all available)"
+     )
+     parser.add_argument(
+         "--evaluate_all",
+         action="store_true",
+         help="Evaluate all combinations of backbones"
+     )
+     parser.add_argument(
+         "--start_from",
+         type=str,
+         default=None,
+         help="Start evaluation from a specific combination, format: 'image_backbone:audio_backbone'"
+     )
+     parser.add_argument(
+         "--prioritize_efficient",
+         action="store_true",
+         help="Prioritize more efficient models first to avoid memory issues"
+     )
+     parser.add_argument(
+         "--results_file",
+         type=str,
+         default="backbone_evaluation_results.json",
+         help="File to save the evaluation results"
+     )
+     parser.add_argument(
+         "--load_previous_results",
+         action="store_true",
+         help="Load previous results from results_file if it exists"
+     )
+     parser.add_argument(
+         "--model_dir",
+         type=str,
+         default="test_models",
+         help="Directory to save model checkpoints"
+     )
+
+     args = parser.parse_args()
+
+     # Create model directory if it doesn't exist
+     if args.model_dir:
+         os.makedirs(args.model_dir, exist_ok=True)
+
+     print(f"\033[92mINFO\033[0m: === Available Image Backbones ===")
+     for name in IMAGE_BACKBONES.keys():
+         print(f"- {name}")
+
+     print(f"\033[92mINFO\033[0m: === Available Audio Backbones ===")
+     for name in AUDIO_BACKBONES.keys():
+         print(f"- {name}")
+
+     if args.evaluate_all:
+         evaluate_all_combinations(args.data_dir, results_file=args.results_file, save_model_dir=args.model_dir)
+     elif args.image_backbone and args.audio_backbone:
+         result = evaluate_model(args.data_dir, args.image_backbone, args.audio_backbone, save_model_dir=args.model_dir)
+         save_results([result], args.results_file)
+     else:
+         # Define a default set of backbones to evaluate if not specified
+         if args.prioritize_efficient:
+             # Start with less memory-intensive models
+             image_backbones = ["resnet50", "efficientnet_b0", "resnet101", "efficientnet_b3", "convnext_base", "swin_b"]
+             audio_backbones = ["lstm", "gru", "bidirectional_lstm", "transformer"]
+         else:
+             # Default selection focusing on better performance models
+             image_backbones = ["resnet101", "efficientnet_b3", "swin_b"]
+             audio_backbones = ["lstm", "bidirectional_lstm", "transformer"]
+
+         # Create all combinations
+         combinations = [(img, aud) for img in image_backbones for aud in audio_backbones]
+
+         # Load previous results if requested and file exists
+         previous_results = []
+         previous_combinations = set()
+         if args.load_previous_results:
+             try:
+                 if os.path.exists(args.results_file):
+                     with open(args.results_file, 'r') as f:
+                         previous_results = json.load(f)
+                     previous_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in previous_results}
+                     print(f"\033[92mINFO\033[0m: Loaded {len(previous_results)} previous results")
+                 else:
+                     print(f"\033[93mWARN\033[0m: Results file '{args.results_file}' does not exist. Starting with empty results.")
+             except Exception as e:
+                 print(f"\033[91mERR!\033[0m: Error loading previous results: {e}")
+                 previous_results = []
+                 previous_combinations = set()
+
+         # If starting from a specific point
+         if args.start_from:
+             try:
+                 start_img, start_aud = args.start_from.split(':')
+                 start_idx = combinations.index((start_img, start_aud))
+                 combinations = combinations[start_idx:]
+                 print(f"\033[92mINFO\033[0m: Starting from combination: {start_img} (image) + {start_aud} (audio)")
+             except (ValueError, IndexError):
+                 print(f"\033[91mERR!\033[0m: Invalid start_from format or combination not found. Format should be 'image_backbone:audio_backbone'")
+                 print(f"\033[91mERR!\033[0m: Continuing with all combinations.")
+
+         # Skip combinations that have already been evaluated
+         if previous_combinations:
+             original_count = len(combinations)
+             combinations = [(img, aud) for img, aud in combinations if (img, aud) not in previous_combinations]
+             print(f"\033[92mINFO\033[0m: Skipping {original_count - len(combinations)} already evaluated combinations")
+
+         # Evaluate each combination
+         results = previous_results.copy()
+
+         for img_backbone, audio_backbone in combinations:
+             print(f"\033[92mINFO\033[0m: Evaluating {img_backbone} + {audio_backbone}")
+             try:
+                 # Clean GPU memory before each model evaluation
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+                     print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
+                     print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+                     print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
+
+                 result = evaluate_model(args.data_dir, img_backbone, audio_backbone, save_model_dir=args.model_dir)
+                 results.append(result)
+
+                 # Save results after each evaluation
+                 save_results(results, args.results_file)
+
+                 # Force garbage collection to free memory
+                 import gc
+                 gc.collect()
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+                     print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
+                     print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+                     print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
+
+             except Exception as e:
+                 print(f"\033[91mERR!\033[0m: Error evaluating {img_backbone} + {audio_backbone}: {e}")
+                 print(f"\033[91mERR!\033[0m: To continue from this point later, use --start_from={img_backbone}:{audio_backbone}")
+
+                 # Force garbage collection to free memory even if there's an error
+                 import gc
+                 gc.collect()
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+                     print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")
+
+                 continue
+
+         # Sort results by test MAE (ascending)
+         results.sort(key=lambda x: x["test_mae"])
+
+         # Save final sorted results
+         save_results(results, args.results_file)
+
+         print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
+         print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
+         print("="*60)
+
+         for result in results:
+             print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")
models/.nfs00000001a1a17512003726ad ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:02999bd33592de717dc1ec8054dc570193074c3f25a7283b3daa580b727b7134
+ size 96095572
models/.nfs00000001a234d9cd003726ac ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5df632222fa87e09e635f90e5cce14bdd9fd34b442bf18daaf13e54dedfed132
+ size 96095572
models/.nfs00000001a2a11ea9003726ae ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:80f999a1540c42ed74491692aa66c3b5a6171f972bdf47c9d52556fe1673c8dd
+ size 96095572
models/efficientnet_b0_transformer_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eec8d23f6454198e147db3ff31e497a0fed8cc0fa690f58e2576e9190ca54aa7
+ size 22597034
models/efficientnet_b3_transformer_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da70bf6bef70cfa3795e566fd58523a9b41b01c151fb37fd3b255262c2b47451
+ size 49751930
models/resnet50_transformer_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cec4fe964defc58fea1f6c26c714c27680a4aa81b131795e8cbeadb6e7be9bd5
+ size 101004668
moe_evaluation_results.json ADDED
@@ -0,0 +1,801 @@
1
+ {
2
+ "moe_test_mae": 0.19680618420243262,
3
+ "moe_test_mse": 0.05606407420709729,
4
+ "true_labels": [
5
+ 10.5,
6
+ 9.399999618530273,
7
+ 11.600000381469727,
8
+ 8.699999809265137,
9
+ 10.399999618530273,
10
+ 10.800000190734863,
11
+ 11.600000381469727,
12
+ 10.5,
13
+ 11.600000381469727,
14
+ 11.100000381469727,
15
+ 10.399999618530273,
16
+ 10.5,
17
+ 11.0,
18
+ 10.5,
19
+ 10.899999618530273,
20
+ 10.5,
21
+ 11.100000381469727,
22
+ 9.600000381469727,
23
+ 12.699999809265137,
24
+ 10.0,
25
+ 10.300000190734863,
26
+ 10.399999618530273,
27
+ 9.399999618530273,
28
+ 10.800000190734863,
29
+ 10.0,
30
+ 11.600000381469727,
31
+ 10.0,
32
+ 10.399999618530273,
33
+ 9.399999618530273,
34
+ 10.399999618530273,
35
+ 10.300000190734863,
36
+ 9.399999618530273,
37
+ 10.899999618530273,
38
+ 9.0,
39
+ 10.300000190734863,
40
+ 10.899999618530273,
41
+ 11.0,
42
+ 12.699999809265137,
43
+ 10.399999618530273,
44
+ 9.600000381469727,
45
+ 8.699999809265137,
46
+ 10.199999809265137,
47
+ 10.300000190734863,
48
+ 11.600000381469727,
49
+ 9.0,
50
+ 9.0,
51
+ 11.0,
52
+ 8.699999809265137,
53
+ 9.699999809265137,
54
+ 10.399999618530273,
55
+ 10.0,
56
+ 11.600000381469727,
57
+ 9.399999618530273,
58
+ 9.0,
59
+ 10.300000190734863,
60
+ 10.5,
61
+ 10.399999618530273,
62
+ 11.0,
63
+ 10.899999618530273,
64
+ 9.399999618530273,
65
+ 8.699999809265137,
66
+ 10.300000190734863,
67
+ 9.699999809265137,
68
+ 10.300000190734863,
69
+ 9.399999618530273,
70
+ 10.300000190734863,
71
+ 9.399999618530273,
72
+ 10.0,
73
+ 10.399999618530273,
74
+ 10.199999809265137,
75
+ 11.0,
76
+ 12.699999809265137,
77
+ 12.699999809265137,
78
+ 10.0,
79
+ 11.0,
80
+ 9.0,
81
+ 10.0,
82
+ 10.5,
83
+ 11.600000381469727,
84
+ 9.399999618530273,
85
+ 10.0,
86
+ 11.0,
87
+ 11.100000381469727,
88
+ 10.899999618530273,
89
+ 9.399999618530273,
90
+ 10.300000190734863,
91
+ 9.399999618530273,
92
+ 8.699999809265137,
93
+ 10.0,
94
+ 12.699999809265137,
95
+ 12.699999809265137,
96
+ 9.699999809265137,
97
+ 9.399999618530273,
98
+ 11.0,
99
+ 9.399999618530273,
100
+ 9.0,
101
+ 11.100000381469727,
102
+ 10.300000190734863,
103
+ 10.300000190734863,
104
+ 10.300000190734863,
105
+ 10.0,
106
+ 9.399999618530273,
107
+ 9.399999618530273,
108
+ 10.899999618530273,
109
+ 11.0,
110
+ 9.699999809265137,
111
+ 12.699999809265137,
112
+ 10.5,
113
+ 11.0,
114
+ 10.899999618530273,
115
+ 12.699999809265137,
116
+ 10.899999618530273,
117
+ 11.0,
118
+ 10.300000190734863,
119
+ 11.0,
120
+ 9.699999809265137,
121
+ 10.300000190734863,
122
+ 10.300000190734863,
123
+ 10.199999809265137,
124
+ 10.199999809265137,
125
+ 10.899999618530273,
126
+ 10.5,
127
+ 11.0,
128
+ 8.699999809265137,
129
+ 9.699999809265137,
130
+ 12.699999809265137,
131
+ 11.600000381469727,
132
+ 10.899999618530273,
133
+ 11.0,
134
+ 9.399999618530273,
135
+ 10.300000190734863,
136
+ 12.699999809265137,
137
+ 10.199999809265137,
138
+ 10.199999809265137,
139
+ 10.800000190734863,
140
+ 8.699999809265137,
141
+ 9.0,
142
+ 11.0,
143
+ 9.399999618530273,
144
+ 10.800000190734863,
145
+ 11.100000381469727,
146
+ 11.100000381469727,
147
+ 10.199999809265137,
148
+ 9.399999618530273,
149
+ 10.199999809265137,
150
+ 10.199999809265137,
151
+ 9.399999618530273,
152
+ 10.899999618530273,
153
+ 10.199999809265137,
154
+ 11.100000381469727,
155
+ 11.600000381469727,
156
+ 8.699999809265137,
157
+ 11.600000381469727,
158
+ 10.199999809265137,
159
+ 9.399999618530273,
160
+ 9.699999809265137,
161
+ 9.399999618530273
162
+ ],
163
+ "moe_predictions": [
164
+ 10.906482696533203,
165
+ 9.413387298583984,
166
+ 11.58445930480957,
167
+ 8.627098083496094,
168
+ 10.55517578125,
169
+ 10.969362258911133,
170
+ 11.596641540527344,
171
+ 10.598587036132812,
172
+ 11.712945938110352,
173
+ 11.415390968322754,
174
+ 10.500967979431152,
175
+ 10.939116477966309,
176
+ 11.23089599609375,
177
+ 10.928877830505371,
178
+ 11.180931091308594,
179
+ 10.805574417114258,
180
+ 11.44560432434082,
181
+ 9.797750473022461,
182
+ 12.00424575805664,
183
+ 9.924805641174316,
184
+ 10.419149398803711,
185
+ 10.459878921508789,
186
+ 9.774242401123047,
187
+ 10.985288619995117,
188
+ 10.047812461853027,
189
+ 11.745304107666016,
190
+ 10.191004753112793,
191
+ 10.527164459228516,
192
+ 9.581968307495117,
193
+ 10.483012199401855,
194
+ 10.368606567382812,
195
+ 9.450727462768555,
196
+ 11.197010040283203,
197
+ 9.173027038574219,
198
+ 10.50676441192627,
199
+ 11.195816040039062,
200
+ 11.227279663085938,
201
+ 13.106525421142578,
202
+ 10.4664945602417,
203
+ 9.891031265258789,
204
+ 8.75540542602539,
205
+ 10.572815895080566,
206
+ 10.214585304260254,
207
+ 12.000329971313477,
208
+ 8.887301445007324,
209
+ 8.929031372070312,
210
+ 11.054266929626465,
211
+ 8.85447883605957,
212
+ 9.515145301818848,
213
+ 10.480228424072266,
214
+ 10.193933486938477,
215
+ 11.7305908203125,
216
+ 9.437666893005371,
217
+ 9.13387680053711,
218
+ 10.629348754882812,
219
+ 10.703892707824707,
220
+ 10.539461135864258,
221
+ 11.135326385498047,
222
+ 11.19705867767334,
223
+ 9.558942794799805,
224
+ 8.898516654968262,
225
+ 10.628425598144531,
226
+ 9.657480239868164,
227
+ 10.513351440429688,
228
+ 9.459192276000977,
229
+ 10.358184814453125,
230
+ 9.432706832885742,
231
+ 10.078161239624023,
232
+ 10.572355270385742,
233
+ 10.58112907409668,
234
+ 10.910698890686035,
235
+ 13.053973197937012,
236
+ 12.972726821899414,
237
+ 10.170805931091309,
238
+ 11.225208282470703,
239
+ 8.872610092163086,
240
+ 10.091118812561035,
241
+ 10.724177360534668,
242
+ 11.729219436645508,
243
+ 9.66834545135498,
244
+ 10.027229309082031,
245
+ 11.232885360717773,
246
+ 11.518696784973145,
247
+ 11.261479377746582,
248
+ 9.523242950439453,
249
+ 10.484042167663574,
250
+ 9.522797584533691,
251
+ 8.75236988067627,
252
+ 10.083819389343262,
253
+ 13.073421478271484,
254
+ 13.001571655273438,
255
+ 9.905550003051758,
256
+ 9.318197250366211,
257
+ 11.141549110412598,
258
+ 9.754105567932129,
259
+ 9.013923645019531,
260
+ 11.429242134094238,
261
+ 10.375783920288086,
262
+ 10.526394844055176,
263
+ 10.307140350341797,
264
+ 10.169934272766113,
265
+ 9.429258346557617,
266
+ 9.29328441619873,
267
+ 11.136444091796875,
268
+ 11.040485382080078,
269
+ 9.723966598510742,
270
+ 12.936074256896973,
271
+ 10.913898468017578,
272
+ 11.255935668945312,
273
+ 11.032815933227539,
274
+ 12.95362663269043,
275
+ 10.942233085632324,
276
+ 11.014484405517578,
277
+ 10.47386646270752,
278
+ 11.207697868347168,
279
+ 9.531013488769531,
280
+ 10.512401580810547,
281
+ 10.791257858276367,
282
+ 10.385677337646484,
283
+ 10.393269538879395,
284
+ 11.13322639465332,
285
+ 10.893503189086914,
286
+ 11.24067497253418,
287
+ 8.767911911010742,
288
+ 9.76015853881836,
289
+ 13.095734596252441,
290
+ 11.651636123657227,
291
+ 11.08572006225586,
292
+ 10.958650588989258,
293
+ 9.548912048339844,
294
+ 10.243309020996094,
295
+ 13.102086067199707,
296
+ 10.579414367675781,
297
+ 10.406577110290527,
298
+ 11.255165100097656,
299
+ 8.494292259216309,
300
+ 8.890151023864746,
301
+ 11.146952629089355,
302
+ 9.766341209411621,
303
+ 11.163339614868164,
304
+ 11.502073287963867,
305
+ 11.408285140991211,
306
+ 10.383015632629395,
307
+ 9.54578971862793,
308
+ 10.56948184967041,
309
+ 10.558614730834961,
310
+ 9.794357299804688,
311
+ 10.885274887084961,
312
+ 10.377969741821289,
313
+ 11.410195350646973,
314
+ 11.537992477416992,
315
+ 8.826037406921387,
316
+ 12.070415496826172,
317
+ 10.559798240661621,
318
+ 9.605077743530273,
319
+ 9.737533569335938,
320
+ 9.520374298095703
321
+ ],
322
+ "individual_predictions": {
323
+ "efficientnet_b3_transformer": [
324
+ 10.619565963745117,
325
+ 9.285565376281738,
326
+ 11.017762184143066,
327
+ 8.358080863952637,
328
+ 9.92147159576416,
329
+ 10.68340015411377,
330
+ 11.023524284362793,
331
+ 10.292417526245117,
332
+ 10.513864517211914,
333
+ 10.958821296691895,
334
+ 10.322061538696289,
335
+ 10.383071899414062,
336
+ 10.330121040344238,
337
+ 10.344510078430176,
338
+ 11.309442520141602,
339
+ 10.321882247924805,
340
+ 10.974185943603516,
341
+ 9.367315292358398,
342
+ 11.474529266357422,
343
+ 9.296891212463379,
344
+ 10.27892780303955,
345
+ 10.14356803894043,
346
+ 9.155308723449707,
347
+ 10.249421119689941,
348
+ 9.534292221069336,
349
+ 11.197205543518066,
350
+ 9.988767623901367,
351
+ 10.485107421875,
352
+ 9.040623664855957,
353
+ 10.171326637268066,
354
+ 10.153056144714355,
355
+ 9.17545223236084,
356
+ 10.604523658752441,
357
+ 8.7711763381958,
358
+ 10.127464294433594,
359
+ 11.29480266571045,
360
+ 10.326626777648926,
361
+ 13.54947566986084,
362
+ 10.142123222351074,
363
+ 9.914827346801758,
364
+ 7.935253620147705,
365
+ 10.513096809387207,
366
+ 9.79228687286377,
367
+ 11.721403121948242,
368
+ 7.996966361999512,
369
+ 8.011720657348633,
370
+ 10.551737785339355,
371
+ 8.663973808288574,
372
+ 8.74413776397705,
373
+ 10.276195526123047,
374
+ 10.136805534362793,
375
+ 11.221556663513184,
376
+ 8.912840843200684,
377
+ 8.619383811950684,
378
+ 10.178643226623535,
379
+ 10.311914443969727,
380
+ 10.487189292907715,
381
+ 10.548056602478027,
382
+ 11.258485794067383,
383
+ 9.288726806640625,
384
+ 8.140922546386719,
385
+ 10.216073989868164,
386
+ 9.068129539489746,
387
+ 10.33917236328125,
388
+ 9.11395263671875,
389
+ 10.140262603759766,
390
+ 8.864439010620117,
391
+ 9.560175895690918,
392
+ 10.1554594039917,
393
+ 10.011631965637207,
394
+ 10.838635444641113,
395
+ 13.890799522399902,
396
+ 13.743374824523926,
397
+ 10.119439125061035,
398
+ 11.073603630065918,
399
+ 7.99126672744751,
400
+ 10.012906074523926,
401
+ 10.309550285339355,
402
+ 10.537038803100586,
403
+ 9.361739158630371,
404
+ 9.594813346862793,
405
+ 10.32430362701416,
406
+ 11.0283842086792,
407
+ 11.271435737609863,
408
+ 9.267289161682129,
409
+ 10.143651962280273,
410
+ 9.201630592346191,
411
+ 8.489853858947754,
412
+ 9.663308143615723,
413
+ 13.539351463317871,
414
+ 13.890753746032715,
415
+ 9.300865173339844,
416
+ 8.978877067565918,
417
+ 10.455121994018555,
418
+ 9.145268440246582,
419
+ 8.390588760375977,
420
+ 10.97396183013916,
421
+ 10.023279190063477,
422
+ 10.194899559020996,
423
+ 9.974883079528809,
424
+ 10.101761817932129,
425
+ 9.511059761047363,
426
+ 8.89189624786377,
427
+ 10.77907657623291,
428
+ 10.7083158493042,
429
+ 9.067532539367676,
430
+ 13.406800270080566,
431
+ 10.60212516784668,
432
+ 10.704161643981934,
433
+ 11.133363723754883,
434
+ 13.293631553649902,
435
+ 9.996685981750488,
436
+ 10.766114234924316,
437
+ 10.15234088897705,
438
+ 11.180027961730957,
439
+ 8.875227928161621,
440
+ 10.376603126525879,
441
+ 10.074305534362793,
442
+ 10.001667022705078,
443
+ 10.027312278747559,
444
+ 10.606922149658203,
445
+ 10.565585136413574,
446
+ 10.699769020080566,
447
+ 8.507576942443848,
448
+ 9.084380149841309,
449
+ 13.500945091247559,
450
+ 11.240296363830566,
451
+ 10.65023136138916,
452
+ 10.248372077941895,
453
+ 9.269180297851562,
454
+ 9.840892791748047,
455
+ 13.547538757324219,
456
+ 9.992758750915527,
457
+ 10.026358604431152,
458
+ 10.71567440032959,
459
+ 8.320480346679688,
460
+ 8.000975608825684,
461
+ 10.548954963684082,
462
+ 9.176098823547363,
463
+ 11.098072052001953,
464
+ 11.02483081817627,
465
+ 11.12319278717041,
466
+ 9.996392250061035,
467
+ 9.263312339782715,
468
+ 10.517735481262207,
469
+ 9.8799409866333,
470
+ 9.319127082824707,
471
+ 9.990796089172363,
472
+ 9.982155799865723,
473
+ 11.105603218078613,
474
+ 10.747210502624512,
475
+ 8.343344688415527,
476
+ 11.73001480102539,
477
+ 10.511062622070312,
478
+ 9.331645965576172,
479
+ 9.131060600280762,
480
+ 8.956952095031738
481
+ ],
482
+ "efficientnet_b0_transformer": [
483
+ 11.040512084960938,
484
+ 9.555410385131836,
485
+ 11.689399719238281,
486
+ 8.434002876281738,
487
+ 11.386773109436035,
488
+ 10.940624237060547,
489
+ 11.708887100219727,
490
+ 11.056541442871094,
491
+ 12.392988204956055,
492
+ 11.619367599487305,
493
+ 10.591476440429688,
494
+ 11.15828800201416,
495
+ 11.810995101928711,
496
+ 11.26023006439209,
497
+ 11.246732711791992,
498
+ 11.448994636535645,
499
+ 11.935430526733398,
500
+ 10.085470199584961,
501
+ 12.768455505371094,
502
+ 10.39224910736084,
503
+ 10.590924263000488,
504
+ 10.642997741699219,
505
+ 9.948995590209961,
506
+ 11.38804817199707,
507
+ 10.38807487487793,
508
+ 11.55557632446289,
509
+ 10.514514923095703,
510
+ 10.37149429321289,
511
+ 9.95881462097168,
512
+ 10.645825386047363,
513
+ 10.480897903442383,
514
+ 9.64439868927002,
515
+ 11.213277816772461,
516
+ 9.551204681396484,
517
+ 10.929215431213379,
518
+ 11.268585205078125,
519
+ 11.799053192138672,
520
+ 12.975137710571289,
521
+ 10.657550811767578,
522
+ 9.907003402709961,
523
+ 9.108478546142578,
524
+ 10.350242614746094,
525
+ 10.475027084350586,
526
+ 12.249593734741211,
527
+ 9.311214447021484,
528
+ 9.402128219604492,
529
+ 11.460792541503906,
530
+ 8.638538360595703,
531
+ 10.098196029663086,
532
+ 10.429000854492188,
533
+ 10.63322639465332,
534
+ 11.521190643310547,
535
+ 9.934067726135254,
536
+ 9.390719413757324,
537
+ 10.85897445678711,
538
+ 10.96368408203125,
539
+ 10.440620422363281,
540
+ 11.39995002746582,
541
+ 11.138040542602539,
542
+ 9.738420486450195,
543
+ 9.13027286529541,
544
+ 10.834165573120117,
545
+ 9.734615325927734,
546
+ 10.535043716430664,
547
+ 9.7576904296875,
548
+ 10.504064559936523,
549
+ 9.726502418518066,
550
+ 10.391711235046387,
551
+ 10.526286125183105,
552
+ 10.450986862182617,
553
+ 10.732028007507324,
554
+ 13.047806739807129,
555
+ 12.901583671569824,
556
+ 10.609762191772461,
557
+ 11.112765312194824,
558
+ 9.227752685546875,
559
+ 10.403764724731445,
560
+ 10.97991943359375,
561
+ 12.400298118591309,
562
+ 9.740009307861328,
563
+ 10.546162605285645,
564
+ 11.811308860778809,
565
+ 12.024316787719727,
566
+ 11.304412841796875,
567
+ 9.642568588256836,
568
+ 10.770721435546875,
569
+ 9.673535346984863,
570
+ 8.692492485046387,
571
+ 10.140533447265625,
572
+ 13.103691101074219,
573
+ 12.987236022949219,
574
+ 9.978914260864258,
575
+ 9.647960662841797,
576
+ 11.465564727783203,
577
+ 9.91793155670166,
578
+ 8.99271011352539,
579
+ 11.874197959899902,
580
+ 10.875059127807617,
581
+ 10.751541137695312,
582
+ 10.586625099182129,
583
+ 10.616861343383789,
584
+ 9.251531600952148,
585
+ 9.575355529785156,
586
+ 11.49870777130127,
587
+ 11.352771759033203,
588
+ 9.970162391662598,
589
+ 12.869828224182129,
590
+ 11.021011352539062,
591
+ 11.830097198486328,
592
+ 10.895241737365723,
593
+ 13.477546691894531,
594
+ 11.435956001281738,
595
+ 11.21767807006836,
596
+ 10.8616361618042,
597
+ 11.25930404663086,
598
+ 9.386629104614258,
599
+ 10.510151863098145,
600
+ 11.104487419128418,
601
+ 10.017858505249023,
602
+ 10.365488052368164,
603
+ 11.206178665161133,
604
+ 11.027682304382324,
605
+ 11.81328010559082,
606
+ 8.614967346191406,
607
+ 10.088481903076172,
608
+ 12.978555679321289,
609
+ 11.964248657226562,
610
+ 11.287935256958008,
611
+ 11.514422416687012,
612
+ 9.758452415466309,
613
+ 10.500945091247559,
614
+ 12.95924186706543,
615
+ 10.438175201416016,
616
+ 10.364145278930664,
617
+ 11.490489959716797,
618
+ 8.45285415649414,
619
+ 9.380582809448242,
620
+ 11.404769897460938,
621
+ 10.42972183227539,
622
+ 11.568924903869629,
623
+ 11.746879577636719,
624
+ 11.68482780456543,
625
+ 10.019561767578125,
626
+ 9.662923812866211,
627
+ 10.360588073730469,
628
+ 10.901131629943848,
629
+ 10.128849029541016,
630
+ 11.287601470947266,
631
+ 10.017107009887695,
632
+ 11.725995063781738,
633
+ 11.726645469665527,
634
+ 8.865287780761719,
635
+ 12.030455589294434,
636
+ 10.348114013671875,
637
+ 9.747005462646484,
638
+ 9.905638694763184,
639
+ 9.855661392211914
640
+ ],
641
+ "resnet50_transformer": [
642
+ 11.059370040893555,
643
+ 9.399184226989746,
644
+ 12.046213150024414,
645
+ 9.089208602905273,
646
+ 10.357281684875488,
647
+ 11.284062385559082,
648
+ 12.057510375976562,
649
+ 10.44680118560791,
650
+ 12.231982231140137,
651
+ 11.667984008789062,
652
+ 10.58936595916748,
653
+ 11.275989532470703,
654
+ 11.5515718460083,
655
+ 11.181893348693848,
656
+ 10.986615180969238,
657
+ 10.645844459533691,
658
+ 11.427197456359863,
659
+ 9.94046688079834,
660
+ 11.769749641418457,
661
+ 10.08527660369873,
662
+ 10.387595176696777,
663
+ 10.593070030212402,
664
+ 10.218421936035156,
665
+ 11.31839656829834,
666
+ 10.221070289611816,
667
+ 12.48313045501709,
668
+ 10.069729804992676,
669
+ 10.72489070892334,
670
+ 9.746464729309082,
671
+ 10.631884574890137,
672
+ 10.4718656539917,
673
+ 9.532330513000488,
674
+ 11.773228645324707,
675
+ 9.196700096130371,
676
+ 10.46361255645752,
677
+ 11.024060249328613,
678
+ 11.556159019470215,
679
+ 12.794964790344238,
680
+ 10.599808692932129,
681
+ 9.851262092590332,
682
+ 9.222484588623047,
683
+ 10.855106353759766,
684
+ 10.37644100189209,
685
+ 12.02999210357666,
686
+ 9.35372257232666,
687
+ 9.37324333190918,
688
+ 11.150269508361816,
689
+ 9.2609224319458,
690
+ 9.703102111816406,
691
+ 10.735487937927246,
692
+ 9.811766624450684,
693
+ 12.44902515411377,
694
+ 9.46609115600586,
695
+ 9.391528129577637,
696
+ 10.850428581237793,
697
+ 10.836078643798828,
698
+ 10.690573692321777,
699
+ 11.45797348022461,
700
+ 11.194649696350098,
701
+ 9.649679183959961,
702
+ 9.42435359954834,
703
+ 10.835038185119629,
704
+ 10.169693946838379,
705
+ 10.665839195251465,
706
+ 9.50593376159668,
707
+ 10.43022632598877,
708
+ 9.70718002319336,
709
+ 10.282594680786133,
710
+ 11.035321235656738,
711
+ 11.280767440795898,
712
+ 11.161433219909668,
713
+ 12.223311424255371,
714
+ 12.273221015930176,
715
+ 9.783215522766113,
716
+ 11.48925495147705,
717
+ 9.398808479309082,
718
+ 9.856684684753418,
719
+ 10.883062362670898,
720
+ 12.250321388244629,
721
+ 9.903286933898926,
722
+ 9.940712928771973,
723
+ 11.563044548034668,
724
+ 11.503388404846191,
725
+ 11.208588600158691,
726
+ 9.659869194030762,
727
+ 10.537753105163574,
728
+ 9.693224906921387,
729
+ 9.074763298034668,
730
+ 10.447615623474121,
731
+ 12.577223777770996,
732
+ 12.126725196838379,
733
+ 10.436871528625488,
734
+ 9.327754020690918,
735
+ 11.503960609436035,
736
+ 10.199116706848145,
737
+ 9.658470153808594,
738
+ 11.43956470489502,
739
+ 10.229013442993164,
740
+ 10.632741928100586,
741
+ 10.35991096496582,
742
+ 9.791178703308105,
743
+ 9.52518367767334,
744
+ 9.412601470947266,
745
+ 11.131546974182129,
746
+ 11.0603666305542,
747
+ 10.13420295715332,
748
+ 12.53159236907959,
749
+ 11.118557929992676,
750
+ 11.233548164367676,
751
+ 11.069842338562012,
752
+ 12.089702606201172,
753
+ 11.394057273864746,
754
+ 11.059659957885742,
755
+ 10.407622337341309,
756
+ 11.183761596679688,
757
+ 10.331181526184082,
758
+ 10.6504487991333,
759
+ 11.194979667663574,
760
+ 11.137504577636719,
761
+ 10.787008285522461,
762
+ 11.586577415466309,
763
+ 11.08724308013916,
764
+ 11.208975791931152,
765
+ 9.181191444396973,
766
+ 10.107614517211914,
767
+ 12.807703018188477,
768
+ 11.750362396240234,
769
+ 11.31899356842041,
770
+ 11.11315631866455,
771
+ 9.619100570678711,
772
+ 10.388087272644043,
773
+ 12.79947566986084,
774
+ 11.307307243347168,
775
+ 10.82922649383545,
776
+ 11.55932903289795,
777
+ 8.709542274475098,
778
+ 9.288893699645996,
779
+ 11.48713207244873,
780
+ 9.693202018737793,
781
+ 10.82302188873291,
782
+ 11.73450756072998,
783
+ 11.416834831237793,
784
+ 11.133091926574707,
785
+ 9.71113109588623,
786
+ 10.830121040344238,
787
+ 10.894770622253418,
788
+ 9.935094833374023,
789
+ 11.377425193786621,
790
+ 11.13464641571045,
791
+ 11.39898681640625,
792
+ 12.140122413635254,
793
+ 9.269479751586914,
794
+ 12.450774192810059,
795
+ 10.820216178894043,
796
+ 9.736580848693848,
797
+ 10.17590045928955,
798
+ 9.74850845336914
799
+ ]
800
+ }
801
+ }
templates/.nfs00000001a2893bde003726a5 ADDED
@@ -0,0 +1 @@
+
test_moe_model.py ADDED
@@ -0,0 +1,276 @@
+ import os
+ import torch
+ import torchaudio
+ import torchvision
+ import numpy as np
+ import json
+ from torch.utils.data import Dataset, DataLoader
+ import sys
+ from tqdm import tqdm
+
+ # Add parent directory to path to import the preprocess functions
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from preprocess import process_audio_data, process_image_data
+
+ # Import the WatermelonDataset and WatermelonModelModular from the evaluate_backbones.py file
+ from evaluate_backbones import WatermelonDataset, WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
+
+ # Print library versions
+ print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
+ print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
+ print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")
+
+ # Device selection
+ device = torch.device(
+     "cuda" if torch.cuda.is_available()
+     else "mps" if torch.backends.mps.is_available()
+     else "cpu"
+ )
+ print(f"\033[92mINFO\033[0m: Using device: {device}")
+
+ # Define the top-performing models based on the previous evaluation
+ TOP_MODELS = [
+     {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
+     {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
+     {"image_backbone": "resnet50", "audio_backbone": "transformer"}
+ ]
+
+ # Define class for the MoE model
+ class WatermelonMoEModel(torch.nn.Module):
+     def __init__(self, model_configs, model_dir="test_models", weights=None):
+         """
+         Mixture of Experts model that combines multiple backbone models.
+
+         Args:
+             model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
+             model_dir: Directory where model checkpoints are stored
+             weights: Optional list of weights for each model (None for equal weighting)
+         """
+         super(WatermelonMoEModel, self).__init__()
+         self.models = []
+         self.model_configs = model_configs
+
+         # Load each model
+         for config in model_configs:
+             img_backbone = config["image_backbone"]
+             audio_backbone = config["audio_backbone"]
+
+             # Initialize model
+             model = WatermelonModelModular(img_backbone, audio_backbone)
+
+             # Load weights
+             model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
+             if os.path.exists(model_path):
+                 print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
+                 model.load_state_dict(torch.load(model_path, map_location=device))
+             else:
+                 print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
+                 continue
+
+             model.to(device)
+             model.eval()  # Set to evaluation mode
+             self.models.append(model)
+
+         # Set model weights (uniform by default)
+         if weights:
+             assert len(weights) == len(self.models), "Number of weights must match number of models"
+             self.weights = weights
+         else:
+             self.weights = [1.0 / len(self.models)] * len(self.models)
+
+         print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
+         print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
+
+     def forward(self, mfcc, image):
+         """
+         Forward pass through the MoE model.
+         Returns the weighted average of all model outputs.
+         """
+         outputs = []
+
+         # Get outputs from each model
+         with torch.no_grad():
+             for i, model in enumerate(self.models):
+                 output = model(mfcc, image)
+                 outputs.append(output * self.weights[i])
+
+         # Return weighted average
+         return torch.sum(torch.stack(outputs), dim=0)
+
+
+ def evaluate_moe_model(data_dir, model_dir="test_models", weights=None):
+     """
+     Evaluate the MoE model on the test set.
+     """
+     # Load dataset
+     print(f"\033[92mINFO\033[0m: Loading dataset from {data_dir}")
+     dataset = WatermelonDataset(data_dir)
+     n_samples = len(dataset)
+
+     # Split dataset
+     train_size = int(0.7 * n_samples)
+     val_size = int(0.2 * n_samples)
+     test_size = n_samples - train_size - val_size
+
+     _, _, test_dataset = torch.utils.data.random_split(
+         dataset, [train_size, val_size, test_size]
+     )
+
+     # Use a reasonable batch size
+     batch_size = 8
+     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+
+     # Initialize MoE model
+     moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
+     moe_model.eval()
+
+     # Evaluation metrics
+     mae_criterion = torch.nn.L1Loss()
+     mse_criterion = torch.nn.MSELoss()
+
+     test_mae = 0.0
+     test_mse = 0.0
+
+     print(f"\033[92mINFO\033[0m: Evaluating MoE model on {len(test_dataset)} test samples")
+
+     # Individual model predictions for analysis
+     individual_predictions = {f"{config['image_backbone']}_{config['audio_backbone']}": []
+                               for config in TOP_MODELS}
+     true_labels = []
+     moe_predictions = []
+
+     # Evaluation loop
+     test_iterator = tqdm(test_loader, desc="Testing MoE")
+
+     with torch.no_grad():
+         for i, (mfcc, image, label) in enumerate(test_iterator):
+             try:
+                 mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
+
+                 # Store individual model outputs for analysis
+                 for j, model in enumerate(moe_model.models):
+                     config = TOP_MODELS[j]
+                     model_name = f"{config['image_backbone']}_{config['audio_backbone']}"
+                     output = model(mfcc, image)
+                     individual_predictions[model_name].extend(output.view(-1).cpu().numpy())
+
+                 # Get MoE prediction
+                 output = moe_model(mfcc, image)
+                 moe_predictions.extend(output.view(-1).cpu().numpy())
+
+                 # Store true labels
+                 label = label.view(-1, 1).float()
+                 true_labels.extend(label.view(-1).cpu().numpy())
+
+                 # Calculate metrics
+                 mae = mae_criterion(output, label)
+                 mse = mse_criterion(output, label)
+
+                 test_mae += mae.item()
+                 test_mse += mse.item()
+
+                 test_iterator.set_postfix({"MAE": f"{mae.item():.4f}", "MSE": f"{mse.item():.4f}"})
+
+                 # Clean up memory
+                 if device.type == 'cuda':
+                     del mfcc, image, label, output, mae, mse
+                     torch.cuda.empty_cache()
+
+             except Exception as e:
+                 print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
+                 if device.type == 'cuda':
+                     torch.cuda.empty_cache()
+                 continue
+
+     # Calculate average metrics
+     avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
+     avg_test_mse = test_mse / len(test_loader) if len(test_loader) > 0 else float('inf')
+
+     print(f"\n\033[92mINFO\033[0m: === MoE Model Results ===")
+     print(f"Test MAE: {avg_test_mae:.4f}")
+     print(f"Test MSE: {avg_test_mse:.4f}")
+
+     # Compare with individual models
+     print(f"\n\033[92mINFO\033[0m: === Comparison with Individual Models ===")
+     print(f"{'Model':<30} {'Test MAE':<15}")
+     print("="*45)
+
+     # Load previous results
+     results_file = "backbone_evaluation_results.json"
+     if os.path.exists(results_file):
+         with open(results_file, 'r') as f:
+             previous_results = json.load(f)
+
+         # Filter results for our top models
+         for config in TOP_MODELS:
+             img_backbone = config["image_backbone"]
+             audio_backbone = config["audio_backbone"]
+
+             for result in previous_results:
+                 if result["image_backbone"] == img_backbone and result["audio_backbone"] == audio_backbone:
+                     print(f"{img_backbone}_{audio_backbone:<20} {result['test_mae']:<15.4f}")
+
+     print(f"MoE (Ensemble) {avg_test_mae:<15.4f}")
+
+     # Save results and predictions
+     results = {
+         "moe_test_mae": float(avg_test_mae),
+         "moe_test_mse": float(avg_test_mse),
+         "true_labels": [float(x) for x in true_labels],
+         "moe_predictions": [float(x) for x in moe_predictions],
+         "individual_predictions": {key: [float(x) for x in values]
+                                    for key, values in individual_predictions.items()}
+     }
+
+     with open("moe_evaluation_results.json", 'w') as f:
+         json.dump(results, f, indent=4)
+
+     print(f"\033[92mINFO\033[0m: Results saved to moe_evaluation_results.json")
+
+     return avg_test_mae, avg_test_mse
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Test Mixture of Experts (MoE) Model for Watermelon Sweetness Prediction")
+     parser.add_argument(
+         "--data_dir",
+         type=str,
+         default="../cleaned",
+         help="Path to the cleaned dataset directory"
+     )
+     parser.add_argument(
+         "--model_dir",
+         type=str,
+         default="test_models",
+         help="Directory containing model checkpoints"
+     )
+     parser.add_argument(
+         "--weighting",
+         type=str,
+         choices=["uniform", "performance"],
+         default="uniform",
+         help="How to weight the models (uniform or based on performance)"
+     )
+
+     args = parser.parse_args()
+
+     # Determine weights based on argument
+     weights = None
+     if args.weighting == "performance":
+         # Weights inversely proportional to the MAE (better models get higher weights)
+         # These are the MAE values from the provided results
+         mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
+
+         # Convert to weights (inverse of MAE, normalized)
+         inverse_mae = [1/mae for mae in mae_values]
+         total = sum(inverse_mae)
+         weights = [val/total for val in inverse_mae]
+
+         print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
+     else:
+         print(f"\033[92mINFO\033[0m: Using uniform weights")
+
+     # Evaluate the MoE model
+     evaluate_moe_model(args.data_dir, args.model_dir, weights)
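Note: with --weighting performance, the hard-coded MAE values give expert weights proportional to 1/0.3635 : 1/0.3765 : 1/0.3959, i.e. roughly 0.347, 0.335 and 0.318 after normalization, so the ensemble leans only slightly toward efficientnet_b3 + transformer. An illustrative invocation, assuming the checkpoints written by evaluate_backbones.py are in test_models (paths here are assumptions, not part of the commit):

    python test_moe_model.py --data_dir ../cleaned --model_dir test_models --weighting performance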