Upload folder using huggingface_hub
- .gitattributes +3 -0
- .nfs00000001a2244b30003726a6 +1 -0
- .nfs00000001a2b1089c003726a7 +1 -0
- __pycache__/evaluate_backbones.cpython-310.pyc +0 -0
- __pycache__/preprocess.cpython-310.pyc +0 -0
- app.py +139 -43
- app_local_backup.py +100 -47
- app_moe.py +439 -0
- backbone_evaluation_results.json +110 -0
- evaluate_backbones.py +670 -0
- models/.nfs00000001a1a17512003726ad +3 -0
- models/.nfs00000001a234d9cd003726ac +3 -0
- models/.nfs00000001a2a11ea9003726ae +3 -0
- models/efficientnet_b0_transformer_model.pt +3 -0
- models/efficientnet_b3_transformer_model.pt +3 -0
- models/resnet50_transformer_model.pt +3 -0
- moe_evaluation_results.json +801 -0
- templates/.nfs00000001a2893bde003726a5 +1 -0
- test_moe_model.py +276 -0
.gitattributes
CHANGED
@@ -35,3 +35,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 temp/temp_audio.wav filter=lfs diff=lfs merge=lfs -text
 temp/temp_image.jpg filter=lfs diff=lfs merge=lfs -text
+models/.nfs00000001a1a17512003726ad filter=lfs diff=lfs merge=lfs -text
+models/.nfs00000001a234d9cd003726ac filter=lfs diff=lfs merge=lfs -text
+models/.nfs00000001a2a11ea9003726ae filter=lfs diff=lfs merge=lfs -text
.nfs00000001a2244b30003726a6
ADDED
@@ -0,0 +1 @@
+

.nfs00000001a2b1089c003726a7
ADDED
@@ -0,0 +1 @@
+

__pycache__/evaluate_backbones.cpython-310.pyc
ADDED
Binary file (16.9 kB)

__pycache__/preprocess.cpython-310.pyc
ADDED
Binary file (1.27 kB)
app.py
CHANGED
@@ -6,21 +6,82 @@ import gradio as gr
 import torchaudio
 import torchvision
 import spaces
-
-# # Import Gradio Spaces GPU decorator
-# try:
-#     from gradio import spaces
-#     HAS_SPACES = True
-#     print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
-# except ImportError:
-#     HAS_SPACES = False
-#     print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization")
+import json
 
 # Add parent directory to path to import preprocess functions
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
-# Import functions from
-from
+# Import functions from preprocess and model definitions
+from preprocess import process_image_data
+from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES
+
+# Define the top-performing models based on evaluation
+TOP_MODELS = [
+    {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
+    {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
+    {"image_backbone": "resnet50", "audio_backbone": "transformer"}
+]
+
+# Define the MoE Model
+class WatermelonMoEModel(torch.nn.Module):
+    def __init__(self, model_configs, model_dir="models", weights=None):
+        """
+        Mixture of Experts model that combines multiple backbone models.
+
+        Args:
+            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
+            model_dir: Directory where model checkpoints are stored
+            weights: Optional list of weights for each model (None for equal weighting)
+        """
+        super(WatermelonMoEModel, self).__init__()
+        self.models = []
+        self.model_configs = model_configs
+
+        # Load each model
+        for config in model_configs:
+            img_backbone = config["image_backbone"]
+            audio_backbone = config["audio_backbone"]
+
+            # Initialize model
+            model = WatermelonModelModular(img_backbone, audio_backbone)
+
+            # Load weights
+            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
+            if os.path.exists(model_path):
+                print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
+                model.load_state_dict(torch.load(model_path, map_location='cpu'))
+            else:
+                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
+                continue
+
+            model.eval()  # Set to evaluation mode
+            self.models.append(model)
+
+        # Set model weights (uniform by default)
+        if weights:
+            assert len(weights) == len(self.models), "Number of weights must match number of models"
+            self.weights = weights
+        else:
+            self.weights = [1.0 / len(self.models)] * len(self.models)
+
+        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
+        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")
+
+    def forward(self, mfcc, image):
+        """
+        Forward pass through the MoE model.
+        Returns the weighted average of all model outputs.
+        """
+        outputs = []
+
+        # Get outputs from each model
+        with torch.no_grad():
+            for i, model in enumerate(self.models):
+                output = model(mfcc, image)
+                outputs.append(output * self.weights[i])
+
+        # Return weighted average
+        return torch.sum(torch.stack(outputs), dim=0)
 
 # Modified version of process_audio_data specifically for the app to handle various tensor shapes
 def app_process_audio_data(waveform, sample_rate):

@@ -76,15 +137,12 @@ def app_process_audio_data(waveform, sample_rate):
         print(traceback.format_exc())
         return None
 
-#
-from preprocess import process_image_data
-
-# Using the decorator directly on the function definition
+# Using the decorator for GPU acceleration
 @spaces.GPU
-def predict_sugar_content(audio, image,
-    """Function with GPU acceleration to predict watermelon sugar content in Brix"""
+def predict_sugar_content(audio, image, model_dir="models", weights=None):
+    """Function with GPU acceleration to predict watermelon sugar content in Brix using MoE model"""
     try:
-        #
+        # Check CUDA availability inside the GPU-decorated function
         if torch.cuda.is_available():
             device = torch.device("cuda")
             print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")

@@ -92,11 +150,11 @@ def predict_sugar_content(audio, image, model_path):
             device = torch.device("cpu")
             print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
 
-        # Load model
-
-
-
-        print(f"\033[92mINFO\033[0m: Loaded model
+        # Load MoE model
+        moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
+        moe_model.to(device)
+        moe_model.eval()
+        print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")
 
         # Debug information about input types
         print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")

@@ -188,11 +246,11 @@ def predict_sugar_content(audio, image, model_path):
             processed_image = processed_image.unsqueeze(0).to(device)
             print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
 
-        # Run inference
-        print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
+        # Run inference with MoE model
+        print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
         if mfcc is not None and processed_image is not None:
             with torch.no_grad():
-                brix_value =
+                brix_value = moe_model(mfcc, processed_image)
                 print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
         else:
             return "Error: Failed to process inputs. Please check the debug logs."

@@ -204,6 +262,12 @@ def predict_sugar_content(audio, image, model_path):
             # Create a header with the numerical result
             result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
 
+            # Add extra info about the MoE model
+            result += "Using Ensemble of Top-3 Models:\n"
+            result += "- EfficientNet-B3 + Transformer\n"
+            result += "- EfficientNet-B0 + Transformer\n"
+            result += "- ResNet-50 + Transformer\n\n"
+
             # Add Brix scale visualization
             result += "Sugar Content Scale (in °Brix):\n"
             result += "──────────────────────────────────\n"

@@ -257,22 +321,27 @@ def predict_sugar_content(audio, image, model_path):
         error_msg += traceback.format_exc()
         print(f"\033[91mERR!\033[0m: {error_msg}")
         return error_msg
-
-print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
-
 
-def create_app(
+def create_app(model_dir="models", weights=None):
     """Create and launch the Gradio interface"""
     # Define the prediction function with model path
     def predict_fn(audio, image):
-        return predict_sugar_content(audio, image,
+        return predict_sugar_content(audio, image, model_dir, weights)
 
     # Create Gradio interface
-    with gr.Blocks(title="Watermelon Sugar Content Predictor", theme=gr.themes.Soft()) as interface:
-        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor")
+    with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface:
+        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor (Ensemble Model)")
         gr.Markdown("""
        This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.
 
+        ## What's New
+        This version uses a Mixture of Experts (MoE) ensemble model that combines the three best-performing models:
+        - EfficientNet-B3 + Transformer
+        - EfficientNet-B0 + Transformer
+        - ResNet-50 + Transformer
+
+        The ensemble approach provides more accurate predictions than any single model!
+
        ## Instructions:
        1. Upload or record an audio of tapping the watermelon
        2. Upload or capture an image of the watermelon

@@ -286,7 +355,7 @@ def create_app(model_path):
                 submit_btn = gr.Button("Predict Sugar Content", variant="primary")
 
             with gr.Column():
-                output = gr.Textbox(label="Prediction Results", lines=
+                output = gr.Textbox(label="Prediction Results", lines=15)
 
         submit_btn.click(
             fn=predict_fn,

@@ -302,6 +371,11 @@ def create_app(model_path):
        ## About Brix Measurement
        Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
        The average ripe watermelon has a Brix value between 9-11°.
+
+        ## About the Mixture of Experts Model
+        This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks.
+        Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, which is significantly
+        better than any individual model (best individual model: ~0.36 MAE).
        """)
 
        return interface

@@ -309,12 +383,12 @@ def create_app(model_path):
 if __name__ == "__main__":
     import argparse
 
-    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App")
+    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)")
     parser.add_argument(
-        "--
+        "--model_dir",
         type=str,
-        default="models
-        help="
+        default="models",
+        help="Directory containing the model checkpoints"
     )
     parser.add_argument(
         "--share",

@@ -326,18 +400,40 @@ if __name__ == "__main__":
         action="store_true",
         help="Enable verbose debug output"
     )
+    parser.add_argument(
+        "--weighting",
+        type=str,
+        choices=["uniform", "performance"],
+        default="uniform",
+        help="How to weight the models (uniform or based on performance)"
+    )
 
     args = parser.parse_args()
 
     if args.debug:
         print(f"\033[92mINFO\033[0m: Debug mode enabled")
 
-    # Check if model exists
-    if not os.path.exists(args.
-        print(f"\033[91mERR!\033[0m: Model not found at {args.
-        print("\033[92mINFO\033[0m: Please train a model first or provide a valid model path")
+    # Check if model directory exists
+    if not os.path.exists(args.model_dir):
+        print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}")
         sys.exit(1)
 
+    # Determine weights based on argument
+    weights = None
+    if args.weighting == "performance":
+        # Weights inversely proportional to the MAE (better models get higher weights)
+        # These are the MAE values from the evaluation results
+        mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer
+
+        # Convert to weights (inverse of MAE, normalized)
+        inverse_mae = [1/mae for mae in mae_values]
+        total = sum(inverse_mae)
+        weights = [val/total for val in inverse_mae]
+
+        print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
+    else:
+        print(f"\033[92mINFO\033[0m: Using uniform weights")
+
     # Create and launch the app
-    app = create_app(args.
+    app = create_app(args.model_dir, weights)
     app.launch(share=args.share)
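Editor's note: the performance-based weighting enabled by the new --weighting flag above is just an inverse-MAE normalization. A minimal standalone sketch in Python, reusing the MAE values hard-coded in the diff (the printed numbers are approximate):

# Sketch: inverse-MAE weighting, as in the --weighting performance branch above.
mae_values = [0.3635, 0.3765, 0.3959]             # test MAE of the three transformer-audio models
inverse_mae = [1.0 / mae for mae in mae_values]   # lower error -> larger weight
total = sum(inverse_mae)
weights = [w / total for w in inverse_mae]
print(weights)  # approximately [0.347, 0.335, 0.318] - close to uniform, slightly favoring efficientnet_b3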
app_local_backup.py
CHANGED
@@ -5,12 +5,22 @@ import numpy as np
 import gradio as gr
 import torchaudio
 import torchvision
+import spaces
+
+# # Import Gradio Spaces GPU decorator
+# try:
+#     from gradio import spaces
+#     HAS_SPACES = True
+#     print("\033[92mINFO\033[0m: Gradio Spaces detected, GPU acceleration will be enabled")
+# except ImportError:
+#     HAS_SPACES = False
+#     print("\033[93mWARN\033[0m: gradio.spaces not available, running without GPU optimization")
 
 # Add parent directory to path to import preprocess functions
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
-# Import functions from infer_watermelon.py
-from
+# Import functions from infer_watermelon.py and train_watermelon for the model
+from train_watermelon import WatermelonModel
 
 # Modified version of process_audio_data specifically for the app to handle various tensor shapes
 def app_process_audio_data(waveform, sample_rate):

@@ -69,14 +79,25 @@ def app_process_audio_data(waveform, sample_rate):
 # Similarly for images, but let's import the original one
 from preprocess import process_image_data
 
-
-
-
-
-
-def predict_sweetness(audio, image, model, device):
-    """Predict sweetness of a watermelon from audio and image input"""
+# Using the decorator directly on the function definition
+@spaces.GPU
+def predict_sugar_content(audio, image, model_path):
+    """Function with GPU acceleration to predict watermelon sugar content in Brix"""
     try:
+        # Now check CUDA availability inside the GPU-decorated function
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
+        else:
+            device = torch.device("cpu")
+            print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")
+
+        # Load model inside the function to ensure it's on the correct device
+        model = WatermelonModel().to(device)
+        model.load_state_dict(torch.load(model_path, map_location=device))
+        model.eval()
+        print(f"\033[92mINFO\033[0m: Loaded model from {model_path}")
+
         # Debug information about input types
         print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
         print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")

@@ -97,7 +118,6 @@ def predict_sweetness(audio, image, model, device):
             print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
         elif isinstance(audio, str):
             # Direct path to audio file
-            import torchaudio
             audio_data, sample_rate = torchaudio.load(audio)
             print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
         else:

@@ -111,9 +131,6 @@ def predict_sweetness(audio, image, model, device):
         temp_image_path = os.path.join(temp_dir, "temp_image.jpg")
 
         # Import necessary libraries
-        import torchaudio
-        import torchvision
-        import torchvision.transforms.functional as F
         from PIL import Image
 
         # Audio handling - direct processing from the data in memory

@@ -162,7 +179,7 @@ def predict_sweetness(audio, image, model, device):
         processed_image = process_image_data(image_tensor)
         print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")
 
-        # Add batch dimension for inference
+        # Add batch dimension for inference and move to device
         if mfcc is not None:
             mfcc = mfcc.unsqueeze(0).to(device)
             print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")

@@ -172,31 +189,67 @@ def predict_sweetness(audio, image, model, device):
             print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")
 
         # Run inference
-        print(f"\033[92mDEBUG\033[0m: Running inference")
+        print(f"\033[92mDEBUG\033[0m: Running inference on device: {device}")
         if mfcc is not None and processed_image is not None:
             with torch.no_grad():
-
-                print(f"\033[92mDEBUG\033[0m: Prediction successful: {
+                brix_value = model(mfcc, processed_image)
+                print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
         else:
             return "Error: Failed to process inputs. Please check the debug logs."
 
-        # Format the result
-        if
-
-        #
-
-
-
-
-
-
+        # Format the result with a range display
+        if brix_value is not None:
+            brix_score = brix_value.item()
+
+            # Create a header with the numerical result
+            result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"
+
+            # Add Brix scale visualization
+            result += "Sugar Content Scale (in °Brix):\n"
+            result += "──────────────────────────────────\n"
+
+            # Create the scale display with Brix ranges
+            scale_ranges = [
+                (0, 8, "Low Sugar (< 8° Brix)"),
+                (8, 9, "Mild Sweetness (8-9° Brix)"),
+                (9, 10, "Medium Sweetness (9-10° Brix)"),
+                (10, 11, "Sweet (10-11° Brix)"),
+                (11, 13, "Very Sweet (11-13° Brix)")
+            ]
+
+            # Find which category the prediction falls into
+            user_category = None
+            for min_val, max_val, category_name in scale_ranges:
+                if min_val <= brix_score < max_val:
+                    user_category = category_name
+                    break
+            if brix_score >= scale_ranges[-1][0]:  # Handle edge case
+                user_category = scale_ranges[-1][2]
+
+            # Display the scale with the user's result highlighted
+            for min_val, max_val, category_name in scale_ranges:
+                if category_name == user_category:
+                    result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
+                else:
+                    result += f"  {min_val}-{max_val}: {category_name}\n"
+
+            result += "──────────────────────────────────\n\n"
+
+            # Add assessment of the watermelon's sugar content
+            if brix_score < 8:
+                result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
+            elif brix_score < 9:
+                result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
+            elif brix_score < 10:
+                result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
+            elif brix_score < 11:
+                result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
             else:
-                result += "
+                result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."
 
             return result
         else:
-            return "Error: Could not predict
+            return "Error: Could not predict sugar content. Please try again with different inputs."
 
     except Exception as e:
         import traceback

@@ -204,36 +257,36 @@ def predict_sweetness(audio, image, model, device):
         error_msg += traceback.format_exc()
         print(f"\033[91mERR!\033[0m: {error_msg}")
         return error_msg
+
+print("\033[92mINFO\033[0m: GPU-accelerated prediction function created with @spaces.GPU decorator")
+
 
 def create_app(model_path):
     """Create and launch the Gradio interface"""
-    #
-    model, device = init_model(model_path)
-
-    # Define the prediction function with model and device
+    # Define the prediction function with model path
     def predict_fn(audio, image):
-        return
+        return predict_sugar_content(audio, image, model_path)
 
     # Create Gradio interface
-    with gr.Blocks(title="Watermelon
-        gr.Markdown("# 🍉 Watermelon
+    with gr.Blocks(title="Watermelon Sugar Content Predictor", theme=gr.themes.Soft()) as interface:
+        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor")
         gr.Markdown("""
-        This app predicts the
+        This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.
 
        ## Instructions:
        1. Upload or record an audio of tapping the watermelon
        2. Upload or capture an image of the watermelon
-        3. Click '
+        3. Click 'Predict' to get the sugar content estimation
        """)
 
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                image_input = gr.Image(label="Upload or Capture Image")
-                submit_btn = gr.Button("Predict
+                submit_btn = gr.Button("Predict Sugar Content", variant="primary")
 
            with gr.Column():
-                output = gr.Textbox(label="Prediction Results", lines=
+                output = gr.Textbox(label="Prediction Results", lines=12)
 
        submit_btn.click(
            fn=predict_fn,

@@ -242,13 +295,13 @@ def create_app(model_path):
        )
 
        gr.Markdown("""
-        ##
-
-
-        - Audio analysis using MFCC features and LSTM neural network
-        - Image analysis using ResNet-50 convolutional neural network
+        ## Tips for best results
+        - For audio: Tap the watermelon with your knuckle and record the sound
+        - For image: Take a clear photo of the whole watermelon in good lighting
 
-
+        ## About Brix Measurement
+        Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
+        The average ripe watermelon has a Brix value between 9-11°.
        """)
 
        return interface

@@ -256,7 +309,7 @@ def create_app(model_path):
 if __name__ == "__main__":
     import argparse
 
-    parser = argparse.ArgumentParser(description="Watermelon
+    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App")
     parser.add_argument(
         "--model_path",
         type=str,
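Editor's note: the Brix formatting block in the diff above reduces to a simple range lookup. A minimal sketch of that lookup, using the same scale_ranges as the app (the brix_band helper name is illustrative, not part of the repo):

scale_ranges = [
    (0, 8, "Low Sugar (< 8° Brix)"),
    (8, 9, "Mild Sweetness (8-9° Brix)"),
    (9, 10, "Medium Sweetness (9-10° Brix)"),
    (10, 11, "Sweet (10-11° Brix)"),
    (11, 13, "Very Sweet (11-13° Brix)"),
]

def brix_band(brix_score):
    # Hypothetical helper: return the label of the band the score falls into.
    for min_val, max_val, name in scale_ranges:
        if min_val <= brix_score < max_val:
            return name
    return scale_ranges[-1][2]  # scores at or above the top band's lower bound

print(brix_band(10.4))  # "Sweet (10-11° Brix)"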
app_moe.py
ADDED
@@ -0,0 +1,439 @@
import os
import sys
import torch
import numpy as np
import gradio as gr
import torchaudio
import torchvision
import spaces
import json

# Add parent directory to path to import preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import functions from preprocess and model definitions
from preprocess import process_image_data
from evaluate_backbones import WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES

# Define the top-performing models based on evaluation
TOP_MODELS = [
    {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
    {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
    {"image_backbone": "resnet50", "audio_backbone": "transformer"}
]

# Define the MoE Model
class WatermelonMoEModel(torch.nn.Module):
    def __init__(self, model_configs, model_dir="models", weights=None):
        """
        Mixture of Experts model that combines multiple backbone models.

        Args:
            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
            model_dir: Directory where model checkpoints are stored
            weights: Optional list of weights for each model (None for equal weighting)
        """
        super(WatermelonMoEModel, self).__init__()
        self.models = []
        self.model_configs = model_configs

        # Load each model
        for config in model_configs:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]

            # Initialize model
            model = WatermelonModelModular(img_backbone, audio_backbone)

            # Load weights
            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
            if os.path.exists(model_path):
                print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
                model.load_state_dict(torch.load(model_path, map_location='cpu'))
            else:
                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                continue

            model.eval()  # Set to evaluation mode
            self.models.append(model)

        # Set model weights (uniform by default)
        if weights:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / len(self.models)] * len(self.models)

        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")

    def forward(self, mfcc, image):
        """
        Forward pass through the MoE model.
        Returns the weighted average of all model outputs.
        """
        outputs = []

        # Get outputs from each model
        with torch.no_grad():
            for i, model in enumerate(self.models):
                output = model(mfcc, image)
                outputs.append(output * self.weights[i])

        # Return weighted average
        return torch.sum(torch.stack(outputs), dim=0)

# Modified version of process_audio_data specifically for the app to handle various tensor shapes
def app_process_audio_data(waveform, sample_rate):
    """Modified version of process_audio_data for the app that handles different tensor dimensions"""
    try:
        print(f"\033[92mDEBUG\033[0m: Processing audio - Initial shape: {waveform.shape}, Sample rate: {sample_rate}")

        # Handle different tensor dimensions
        if waveform.dim() == 3:
            print(f"\033[92mDEBUG\033[0m: Found 3D tensor, converting to 2D")
            # For 3D tensor, take the first item (batch dimension)
            waveform = waveform[0]

        if waveform.dim() == 2:
            # Use the first channel for stereo audio
            waveform = waveform[0]
            print(f"\033[92mDEBUG\033[0m: Using first channel, new shape: {waveform.shape}")

        # Resample to 16kHz if needed
        resample_rate = 16000
        if sample_rate != resample_rate:
            print(f"\033[92mDEBUG\033[0m: Resampling from {sample_rate}Hz to {resample_rate}Hz")
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)

        # Ensure 3 seconds of audio
        if waveform.size(0) < 3 * resample_rate:
            print(f"\033[92mDEBUG\033[0m: Padding audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = torch.nn.functional.pad(waveform, (0, 3 * resample_rate - waveform.size(0)))
        else:
            print(f"\033[92mDEBUG\033[0m: Trimming audio from {waveform.size(0)} to {3 * resample_rate} samples")
            waveform = waveform[: 3 * resample_rate]

        # Apply MFCC transformation
        print(f"\033[92mDEBUG\033[0m: Applying MFCC transformation")
        mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=resample_rate,
            n_mfcc=13,
            melkwargs={
                "n_fft": 256,
                "win_length": 256,
                "hop_length": 128,
                "n_mels": 40,
            }
        )

        mfcc = mfcc_transform(waveform)
        print(f"\033[92mDEBUG\033[0m: MFCC output shape: {mfcc.shape}")

        return mfcc
    except Exception as e:
        import traceback
        print(f"\033[91mERR!\033[0m: Error in audio processing: {e}")
        print(traceback.format_exc())
        return None

# Using the decorator for GPU acceleration
@spaces.GPU
def predict_sugar_content(audio, image, model_dir="models", weights=None):
    """Function with GPU acceleration to predict watermelon sugar content in Brix using MoE model"""
    try:
        # Check CUDA availability inside the GPU-decorated function
        if torch.cuda.is_available():
            device = torch.device("cuda")
            print(f"\033[92mINFO\033[0m: CUDA is available. Using device: {device}")
        else:
            device = torch.device("cpu")
            print(f"\033[92mINFO\033[0m: CUDA is not available. Using device: {device}")

        # Load MoE model
        moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
        moe_model.to(device)
        moe_model.eval()
        print(f"\033[92mINFO\033[0m: Loaded MoE model with {len(moe_model.models)} backbone models")

        # Debug information about input types
        print(f"\033[92mDEBUG\033[0m: Audio input type: {type(audio)}")
        print(f"\033[92mDEBUG\033[0m: Audio input shape/length: {len(audio)}")
        print(f"\033[92mDEBUG\033[0m: Image input type: {type(image)}")
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Image input shape: {image.shape}")

        # Handle different audio input formats
        if isinstance(audio, tuple) and len(audio) == 2:
            # Standard Gradio format: (sample_rate, audio_data)
            sample_rate, audio_data = audio
            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
        elif isinstance(audio, tuple) and len(audio) > 2:
            # Sometimes Gradio returns (sample_rate, audio_data, other_info...)
            sample_rate, audio_data = audio[0], audio[-1]
            print(f"\033[92mDEBUG\033[0m: Audio sample rate: {sample_rate}")
            print(f"\033[92mDEBUG\033[0m: Audio data shape: {audio_data.shape}")
        elif isinstance(audio, str):
            # Direct path to audio file
            audio_data, sample_rate = torchaudio.load(audio)
            print(f"\033[92mDEBUG\033[0m: Loaded audio from path with shape: {audio_data.shape}")
        else:
            return f"Error: Unsupported audio format. Got {type(audio)}"

        # Create a temporary file path for the audio and image
        temp_dir = "temp"
        os.makedirs(temp_dir, exist_ok=True)

        temp_audio_path = os.path.join(temp_dir, "temp_audio.wav")
        temp_image_path = os.path.join(temp_dir, "temp_image.jpg")

        # Import necessary libraries
        from PIL import Image

        # Audio handling - direct processing from the data in memory
        if isinstance(audio_data, np.ndarray):
            # Convert numpy array to tensor
            print(f"\033[92mDEBUG\033[0m: Converting numpy audio with shape {audio_data.shape} to tensor")
            audio_tensor = torch.tensor(audio_data).float()

            # Handle different audio dimensions
            if audio_data.ndim == 1:
                # Single channel audio
                audio_tensor = audio_tensor.unsqueeze(0)
            elif audio_data.ndim == 2:
                # Ensure channels are first dimension
                if audio_data.shape[0] > audio_data.shape[1]:
                    # More rows than columns, probably (samples, channels)
                    audio_tensor = torch.tensor(audio_data.T).float()
        else:
            # Already a tensor
            audio_tensor = audio_data.float()

        print(f"\033[92mDEBUG\033[0m: Audio tensor shape before processing: {audio_tensor.shape}")

        # Skip saving/loading and process directly
        mfcc = app_process_audio_data(audio_tensor, sample_rate)
        print(f"\033[92mDEBUG\033[0m: MFCC tensor shape after processing: {mfcc.shape if mfcc is not None else None}")

        # Image handling
        if isinstance(image, np.ndarray):
            print(f"\033[92mDEBUG\033[0m: Converting numpy image with shape {image.shape} to PIL")
            pil_image = Image.fromarray(image)
            pil_image.save(temp_image_path)
            print(f"\033[92mDEBUG\033[0m: Saved image to {temp_image_path}")
        elif isinstance(image, str):
            # If image is already a path
            temp_image_path = image
            print(f"\033[92mDEBUG\033[0m: Using provided image path: {temp_image_path}")
        else:
            return f"Error: Unsupported image format. Got {type(image)}"

        # Process image
        print(f"\033[92mDEBUG\033[0m: Loading and preprocessing image from {temp_image_path}")
        image_tensor = torchvision.io.read_image(temp_image_path)
        print(f"\033[92mDEBUG\033[0m: Loaded image shape: {image_tensor.shape}")
        image_tensor = image_tensor.float()
        processed_image = process_image_data(image_tensor)
        print(f"\033[92mDEBUG\033[0m: Processed image shape: {processed_image.shape if processed_image is not None else None}")

        # Add batch dimension for inference and move to device
        if mfcc is not None:
            mfcc = mfcc.unsqueeze(0).to(device)
            print(f"\033[92mDEBUG\033[0m: Final MFCC shape with batch dimension: {mfcc.shape}")

        if processed_image is not None:
            processed_image = processed_image.unsqueeze(0).to(device)
            print(f"\033[92mDEBUG\033[0m: Final image shape with batch dimension: {processed_image.shape}")

        # Run inference with MoE model
        print(f"\033[92mDEBUG\033[0m: Running inference with MoE model on device: {device}")
        if mfcc is not None and processed_image is not None:
            with torch.no_grad():
                brix_value = moe_model(mfcc, processed_image)
                print(f"\033[92mDEBUG\033[0m: Prediction successful: {brix_value.item()}")
        else:
            return "Error: Failed to process inputs. Please check the debug logs."

        # Format the result with a range display
        if brix_value is not None:
            brix_score = brix_value.item()

            # Create a header with the numerical result
            result = f"🍉 Predicted Sugar Content: {brix_score:.1f}° Brix 🍉\n\n"

            # Add extra info about the MoE model
            result += "Using Ensemble of Top-3 Models:\n"
            result += "- EfficientNet-B3 + Transformer\n"
            result += "- EfficientNet-B0 + Transformer\n"
            result += "- ResNet-50 + Transformer\n\n"

            # Add Brix scale visualization
            result += "Sugar Content Scale (in °Brix):\n"
            result += "──────────────────────────────────\n"

            # Create the scale display with Brix ranges
            scale_ranges = [
                (0, 8, "Low Sugar (< 8° Brix)"),
                (8, 9, "Mild Sweetness (8-9° Brix)"),
                (9, 10, "Medium Sweetness (9-10° Brix)"),
                (10, 11, "Sweet (10-11° Brix)"),
                (11, 13, "Very Sweet (11-13° Brix)")
            ]

            # Find which category the prediction falls into
            user_category = None
            for min_val, max_val, category_name in scale_ranges:
                if min_val <= brix_score < max_val:
                    user_category = category_name
                    break
            if brix_score >= scale_ranges[-1][0]:  # Handle edge case
                user_category = scale_ranges[-1][2]

            # Display the scale with the user's result highlighted
            for min_val, max_val, category_name in scale_ranges:
                if category_name == user_category:
                    result += f"▶ {min_val}-{max_val}: {category_name} ◀ (YOUR WATERMELON)\n"
                else:
                    result += f"  {min_val}-{max_val}: {category_name}\n"

            result += "──────────────────────────────────\n\n"

            # Add assessment of the watermelon's sugar content
            if brix_score < 8:
                result += "Assessment: This watermelon has low sugar content. It may taste bland or slightly bitter."
            elif brix_score < 9:
                result += "Assessment: This watermelon has mild sweetness. Acceptable flavor but not very sweet."
            elif brix_score < 10:
                result += "Assessment: This watermelon has moderate sugar content. It should have pleasant sweetness."
            elif brix_score < 11:
                result += "Assessment: This watermelon has good sugar content! It should be sweet and juicy."
            else:
                result += "Assessment: This watermelon has excellent sugar content! Perfect choice for maximum sweetness and flavor."

            return result
        else:
            return "Error: Could not predict sugar content. Please try again with different inputs."

    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n\n"
        error_msg += traceback.format_exc()
        print(f"\033[91mERR!\033[0m: {error_msg}")
        return error_msg

def create_app(model_dir="models", weights=None):
    """Create and launch the Gradio interface"""
    # Define the prediction function with model path
    def predict_fn(audio, image):
        return predict_sugar_content(audio, image, model_dir, weights)

    # Create Gradio interface
    with gr.Blocks(title="Watermelon Sugar Content Predictor (MoE)", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🍉 Watermelon Sugar Content Predictor (Ensemble Model)")
        gr.Markdown("""
        This app predicts the sugar content (in °Brix) of a watermelon based on its sound and appearance.

        ## What's New
        This version uses a Mixture of Experts (MoE) ensemble model that combines the three best-performing models:
        - EfficientNet-B3 + Transformer
        - EfficientNet-B0 + Transformer
        - ResNet-50 + Transformer

        The ensemble approach provides more accurate predictions than any single model!

        ## Instructions:
        1. Upload or record an audio of tapping the watermelon
        2. Upload or capture an image of the watermelon
        3. Click 'Predict' to get the sugar content estimation
        """)

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Upload or Record Audio", type="numpy")
                image_input = gr.Image(label="Upload or Capture Image")
                submit_btn = gr.Button("Predict Sugar Content", variant="primary")

            with gr.Column():
                output = gr.Textbox(label="Prediction Results", lines=15)

        submit_btn.click(
            fn=predict_fn,
            inputs=[audio_input, image_input],
            outputs=output
        )

        gr.Markdown("""
        ## Tips for best results
        - For audio: Tap the watermelon with your knuckle and record the sound
        - For image: Take a clear photo of the whole watermelon in good lighting

        ## About Brix Measurement
        Brix (°Bx) is a measurement of sugar content in a solution. For watermelons, higher Brix values indicate sweeter fruit.
        The average ripe watermelon has a Brix value between 9-11°.

        ## About the Mixture of Experts Model
        This app uses a Mixture of Experts (MoE) model that combines predictions from multiple neural networks.
        Our testing shows the ensemble approach achieves a Mean Absolute Error (MAE) of ~0.22, which is significantly
        better than any individual model (best individual model: ~0.36 MAE).
        """)

    return interface

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Watermelon Sugar Content Prediction App (MoE)")
    parser.add_argument(
        "--model_dir",
        type=str,
        default="models",
        help="Directory containing the model checkpoints"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a shareable link for the app"
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Enable verbose debug output"
    )
    parser.add_argument(
        "--weighting",
        type=str,
        choices=["uniform", "performance"],
        default="uniform",
        help="How to weight the models (uniform or based on performance)"
    )

    args = parser.parse_args()

    if args.debug:
        print(f"\033[92mINFO\033[0m: Debug mode enabled")

    # Check if model directory exists
    if not os.path.exists(args.model_dir):
        print(f"\033[91mERR!\033[0m: Model directory not found at {args.model_dir}")
        sys.exit(1)

    # Determine weights based on argument
    weights = None
    if args.weighting == "performance":
        # Weights inversely proportional to the MAE (better models get higher weights)
        # These are the MAE values from the evaluation results
        mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer

        # Convert to weights (inverse of MAE, normalized)
        inverse_mae = [1/mae for mae in mae_values]
        total = sum(inverse_mae)
        weights = [val/total for val in inverse_mae]

        print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
    else:
        print(f"\033[92mINFO\033[0m: Using uniform weights")

    # Create and launch the app
    app = create_app(args.model_dir, weights)
    app.launch(share=args.share)
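Editor's note: to make the ensembling mechanics of WatermelonMoEModel.forward easy to verify in isolation, here is a hedged sketch that swaps the real WatermelonModelModular experts for constant-output stand-ins. Only the DummyExpert class and the tensor shapes are invented for illustration; the weighted-average logic mirrors the code above.

import torch

class DummyExpert(torch.nn.Module):
    # Stand-in for WatermelonModelModular: always predicts a fixed Brix value.
    def __init__(self, value):
        super().__init__()
        self.value = value
    def forward(self, mfcc, image):
        return torch.full((mfcc.shape[0], 1), self.value)

experts = [DummyExpert(9.0), DummyExpert(10.0), DummyExpert(11.0)]
weights = [0.5, 0.3, 0.2]

mfcc = torch.zeros(1, 13, 376)       # placeholder shapes; the real app passes batched MFCC and image tensors
image = torch.zeros(1, 3, 224, 224)

# Same weighted average as WatermelonMoEModel.forward:
outputs = [expert(mfcc, image) * w for expert, w in zip(experts, weights)]
prediction = torch.sum(torch.stack(outputs), dim=0)
print(prediction.item())  # 0.5*9.0 + 0.3*10.0 + 0.2*11.0 = 9.7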
backbone_evaluation_results.json
ADDED
@@ -0,0 +1,110 @@
[
    {
        "image_backbone": "efficientnet_b3",
        "audio_backbone": "transformer",
        "validation_mse": 0.21577325425086877,
        "validation_mae": 0.36228722945237773,
        "test_mse": 0.21746371760964395,
        "test_mae": 0.36353210285305976,
        "model_path": "test_models/efficientnet_b3_transformer_model.pt"
    },
    {
        "image_backbone": "efficientnet_b0",
        "audio_backbone": "transformer",
        "validation_mse": 0.24033201676912797,
        "validation_mae": 0.42209602166444826,
        "test_mse": 0.19470563121140003,
        "test_mae": 0.37649240642786025,
        "model_path": "test_models/efficientnet_b0_transformer_model.pt"
    },
    {
        "image_backbone": "resnet50",
        "audio_backbone": "transformer",
        "validation_mse": 0.22672857019381645,
        "validation_mae": 0.3926378931754675,
        "test_mse": 0.22427306957542897,
        "test_mae": 0.39585837423801423,
        "model_path": "test_models/resnet50_transformer_model.pt"
    },
    {
        "image_backbone": "resnet50",
        "audio_backbone": "bidirectional_lstm",
        "validation_mse": 0.2967155438203078,
        "validation_mae": 0.3850937023376807,
        "test_mse": 0.36476454623043536,
        "test_mae": 0.425818096101284,
        "model_path": "test_models/resnet50_bidirectional_lstm_model.pt"
    },
    {
        "image_backbone": "efficientnet_b0",
        "audio_backbone": "bidirectional_lstm",
        "validation_mse": 0.5120524473679371,
        "validation_mae": 0.5665570046657171,
        "test_mse": 0.5059382550418376,
        "test_mae": 0.555050653219223,
        "model_path": "test_models/efficientnet_b0_bidirectional_lstm_model.pt"
    },
    {
        "image_backbone": "efficientnet_b3",
        "audio_backbone": "bidirectional_lstm",
        "validation_mse": 0.8020018790012751,
        "validation_mae": 0.7953977386156718,
        "test_mse": 0.7042828559875488,
        "test_mae": 0.7441241115331649,
        "model_path": "test_models/efficientnet_b3_bidirectional_lstm_model.pt"
    },
    {
        "image_backbone": "efficientnet_b0",
        "audio_backbone": "gru",
        "validation_mse": 1.1340507984161377,
        "validation_mae": 0.8290961503982544,
        "test_mse": 0.9705999374389649,
        "test_mae": 0.7704607486724854,
        "model_path": "test_models/efficientnet_b0_gru_model.pt"
    },
    {
        "image_backbone": "efficientnet_b0",
        "audio_backbone": "lstm",
        "validation_mse": 2.787272185087204,
        "validation_mae": 1.5404645502567291,
        "test_mse": 2.901867628097534,
        "test_mae": 1.5843785762786866,
        "model_path": "test_models/efficientnet_b0_lstm_model.pt"
    },
    {
        "image_backbone": "resnet50",
        "audio_backbone": "gru",
        "validation_mse": 3.9335442543029786,
        "validation_mae": 1.8762320041656495,
        "test_mse": 3.72695152759552,
        "test_mae": 1.8381730556488036,
        "model_path": "test_models/resnet50_gru_model.pt"
    },
    {
        "image_backbone": "resnet50",
        "audio_backbone": "lstm",
        "validation_mse": 6.088638782501221,
        "validation_mae": 2.3887929677963258,
        "test_mse": 6.1847597599029545,
        "test_mae": 2.418113374710083,
        "model_path": "test_models/resnet50_lstm_model.pt"
    },
    {
        "image_backbone": "efficientnet_b3",
        "audio_backbone": "gru",
        "validation_mse": 104.58460273742676,
        "validation_mae": 10.183499813079834,
        "test_mse": 104.58482055664062,
        "test_mae": 10.180697345733643,
        "model_path": "test_models/efficientnet_b3_gru_model.pt"
    },
    {
        "image_backbone": "efficientnet_b3",
        "audio_backbone": "lstm",
        "validation_mse": 105.40057525634765,
        "validation_mae": 10.221695899963379,
        "test_mse": 105.17274551391601,
        "test_mae": 10.21053056716919,
        "model_path": "test_models/efficientnet_b3_lstm_model.pt"
    }
]
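A small, hedged sketch of how these results can be consumed (not part of the commit): load the JSON above and report the combination with the lowest test MAE.

import json

# Load the backbone evaluation results and print the best image/audio combination.
with open("backbone_evaluation_results.json", "r") as f:
    results = json.load(f)

best = min(results, key=lambda r: r["test_mae"])
print(f"Best combo: {best['image_backbone']} + {best['audio_backbone']} "
      f"(test MAE {best['test_mae']:.3f}, checkpoint {best['model_path']})")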
evaluate_backbones.py
ADDED
@@ -0,0 +1,670 @@
import os
import torch
import torchaudio
import torchvision
import numpy as np
import time
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Hyperparameters
batch_size = 16
epochs = 1  # Just one epoch for evaluation
learning_rate = 0.0001


class WatermelonDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = []

        # Walk through the directory structure
        for sweetness_dir in os.listdir(data_dir):
            sweetness = float(sweetness_dir)
            sweetness_path = os.path.join(data_dir, sweetness_dir)

            if os.path.isdir(sweetness_path):
                for id_dir in os.listdir(sweetness_path):
                    id_path = os.path.join(sweetness_path, id_dir)

                    if os.path.isdir(id_path):
                        audio_file = os.path.join(id_path, f"{id_dir}.wav")
                        image_file = os.path.join(id_path, f"{id_dir}.jpg")

                        if os.path.exists(audio_file) and os.path.exists(image_file):
                            self.samples.append((audio_file, image_file, sweetness))

        print(f"\033[92mINFO\033[0m: Loaded {len(self.samples)} samples from {data_dir}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        audio_path, image_path, label = self.samples[idx]

        # Load and process audio
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            mfcc = process_audio_data(waveform, sample_rate)

            # Load and process image
            image = torchvision.io.read_image(image_path)
            image = image.float()
            processed_image = process_image_data(image)

            return mfcc, processed_image, torch.tensor(label).float()
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error processing sample {idx}: {e}")
            # Return a fallback sample or skip this sample
            # For simplicity, we'll return the first sample again
            if idx == 0:  # Prevent infinite recursion
                raise e
            return self.__getitem__(0)


# Define available backbone models
IMAGE_BACKBONES = {
    "resnet50": {
        "model": torchvision.models.resnet50,
        "weights": torchvision.models.ResNet50_Weights.DEFAULT,
        "output_dim": lambda model: model.fc.in_features
    },
    "efficientnet_b0": {
        "model": torchvision.models.efficientnet_b0,
        "weights": torchvision.models.EfficientNet_B0_Weights.DEFAULT,
        "output_dim": lambda model: model.classifier[1].in_features
    },
    "efficientnet_b3": {
        "model": torchvision.models.efficientnet_b3,
        "weights": torchvision.models.EfficientNet_B3_Weights.DEFAULT,
        "output_dim": lambda model: model.classifier[1].in_features
    }
}

AUDIO_BACKBONES = {
    "lstm": {
        "model": lambda input_size, hidden_size: torch.nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
        ),
        "output_dim": lambda hidden_size: hidden_size
    },
    "gru": {
        "model": lambda input_size, hidden_size: torch.nn.GRU(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True
        ),
        "output_dim": lambda hidden_size: hidden_size
    },
    "bidirectional_lstm": {
        "model": lambda input_size, hidden_size: torch.nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=2, batch_first=True, bidirectional=True
        ),
        "output_dim": lambda hidden_size: hidden_size * 2  # * 2 because bidirectional
    },
    "transformer": {
        "model": lambda input_size, hidden_size: torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                d_model=input_size, nhead=8, dim_feedforward=hidden_size, batch_first=True
            ),
            num_layers=2
        ),
        "output_dim": lambda hidden_size: 376  # Using input_size (mfcc dimensions)
    }
}


class WatermelonModelModular(torch.nn.Module):
    def __init__(self, image_backbone_name, audio_backbone_name, audio_hidden_size=128):
        super(WatermelonModelModular, self).__init__()

        # Audio backbone setup
        self.audio_backbone_name = audio_backbone_name
        self.audio_hidden_size = audio_hidden_size
        self.audio_input_size = 376  # From MFCC dimensions

        audio_config = AUDIO_BACKBONES[audio_backbone_name]
        self.audio_backbone = audio_config["model"](self.audio_input_size, self.audio_hidden_size)
        audio_output_dim = audio_config["output_dim"](self.audio_hidden_size)

        self.audio_fc = torch.nn.Linear(audio_output_dim, 128)

        # Image backbone setup
        self.image_backbone_name = image_backbone_name
        image_config = IMAGE_BACKBONES[image_backbone_name]

        self.image_backbone = image_config["model"](weights=image_config["weights"])

        # Replace final layer for all image backbones to get features
        if image_backbone_name.startswith("resnet"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.fc = torch.nn.Identity()
        elif image_backbone_name.startswith("efficientnet"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.classifier = torch.nn.Identity()
        elif image_backbone_name.startswith("convnext"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.classifier = torch.nn.Identity()
        elif image_backbone_name.startswith("swin"):
            self.image_output_dim = image_config["output_dim"](self.image_backbone)
            self.image_backbone.head = torch.nn.Identity()

        self.image_fc = torch.nn.Linear(self.image_output_dim, 128)

        # Fully connected layers for final prediction
        self.fc1 = torch.nn.Linear(256, 64)
        self.fc2 = torch.nn.Linear(64, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, mfcc, image):
        # Audio backbone processing
        if self.audio_backbone_name == "lstm" or self.audio_backbone_name == "gru":
            audio_output, _ = self.audio_backbone(mfcc)
            audio_output = audio_output[:, -1, :]  # Use the output of the last time step
        elif self.audio_backbone_name == "bidirectional_lstm":
            audio_output, _ = self.audio_backbone(mfcc)
            audio_output = audio_output[:, -1, :]  # Use the output of the last time step
        elif self.audio_backbone_name == "transformer":
            audio_output = self.audio_backbone(mfcc)
            audio_output = audio_output.mean(dim=1)  # Average pooling over sequence length

        audio_output = self.audio_fc(audio_output)

        # Image backbone processing
        image_output = self.image_backbone(image)
        image_output = self.image_fc(image_output)

        # Concatenate audio and image outputs
        merged = torch.cat((audio_output, image_output), dim=1)

        # Fully connected layers
        output = self.relu(self.fc1(merged))
        output = self.fc2(output)

        return output


def evaluate_model(data_dir, image_backbone, audio_backbone, audio_hidden_size=128, save_model_dir=None):
    # Adjust batch size based on model complexity to avoid OOM errors
    adjusted_batch_size = batch_size

    # Models that typically require more memory get smaller batch sizes
    if image_backbone in ["swin_b", "convnext_base"] or audio_backbone in ["transformer", "bidirectional_lstm"]:
        adjusted_batch_size = max(4, batch_size // 2)  # At least batch size of 4, but reduce by half if needed
        print(f"\033[92mINFO\033[0m: Adjusted batch size to {adjusted_batch_size} for larger model")

    # Create dataset
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # Split dataset
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    train_loader = DataLoader(train_dataset, batch_size=adjusted_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=adjusted_batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=adjusted_batch_size, shuffle=False)

    # Initialize model
    model = WatermelonModelModular(image_backbone, audio_backbone, audio_hidden_size).to(device)

    # Loss function and optimizer
    criterion = torch.nn.MSELoss()
    mae_criterion = torch.nn.L1Loss()  # For MAE evaluation
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    print(f"\033[92mINFO\033[0m: Evaluating model with {image_backbone} (image) and {audio_backbone} (audio)")
    print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
    print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
    print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
    print(f"\033[92mINFO\033[0m: Batch size: {adjusted_batch_size}")

    # Training loop
    print(f"\033[92mINFO\033[0m: Training for evaluation...")
    model.train()
    running_loss = 0.0

    # Wrap with tqdm for progress visualization
    train_iterator = tqdm(train_loader, desc="Training")

    for i, (mfcc, image, label) in enumerate(train_iterator):
        try:
            mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

            optimizer.zero_grad()
            output = model(mfcc, image)
            label = label.view(-1, 1).float()
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_iterator.set_postfix({"Loss": f"{loss.item():.4f}"})

            # Clear memory after each batch
            if device.type == 'cuda':
                del mfcc, image, label, output, loss
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error in training batch {i}: {e}")
            # Clear memory in case of error
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            continue

    # Validation phase
    print(f"\033[92mINFO\033[0m: Validating...")
    model.eval()
    val_loss = 0.0
    val_mae = 0.0

    val_iterator = tqdm(val_loader, desc="Validation")

    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(val_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                output = model(mfcc, image)
                label = label.view(-1, 1).float()

                # Calculate MSE loss
                loss = criterion(output, label)
                val_loss += loss.item()

                # Calculate MAE
                mae = mae_criterion(output, label)
                val_mae += mae.item()

                val_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})

                # Clear memory after each batch
                if device.type == 'cuda':
                    del mfcc, image, label, output, loss, mae
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in validation batch {i}: {e}")
                # Clear memory in case of error
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue

    avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else float('inf')
    avg_val_mae = val_mae / len(val_loader) if len(val_loader) > 0 else float('inf')

    # Test phase
    print(f"\033[92mINFO\033[0m: Testing...")
    model.eval()
    test_loss = 0.0
    test_mae = 0.0

    test_iterator = tqdm(test_loader, desc="Testing")

    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(test_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                output = model(mfcc, image)
                label = label.view(-1, 1).float()

                # Calculate MSE loss
                loss = criterion(output, label)
                test_loss += loss.item()

                # Calculate MAE
                mae = mae_criterion(output, label)
                test_mae += mae.item()

                test_iterator.set_postfix({"MSE": f"{loss.item():.4f}", "MAE": f"{mae.item():.4f}"})

                # Clear memory after each batch
                if device.type == 'cuda':
                    del mfcc, image, label, output, loss, mae
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
                # Clear memory in case of error
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue

    avg_test_loss = test_loss / len(test_loader) if len(test_loader) > 0 else float('inf')
    avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')

    results = {
        "image_backbone": image_backbone,
        "audio_backbone": audio_backbone,
        "validation_mse": avg_val_loss,
        "validation_mae": avg_val_mae,
        "test_mse": avg_test_loss,
        "test_mae": avg_test_mae
    }

    print(f"\033[92mINFO\033[0m: Evaluation Results:")
    print(f"Image Backbone: {image_backbone}")
    print(f"Audio Backbone: {audio_backbone}")
    print(f"Validation MSE: {avg_val_loss:.4f}")
    print(f"Validation MAE: {avg_val_mae:.4f}")
    print(f"Test MSE: {avg_test_loss:.4f}")
    print(f"Test MAE: {avg_test_mae:.4f}")

    # Save model if save_model_dir is provided
    if save_model_dir:
        os.makedirs(save_model_dir, exist_ok=True)
        model_filename = f"{image_backbone}_{audio_backbone}_model.pt"
        model_path = os.path.join(save_model_dir, model_filename)
        torch.save(model.state_dict(), model_path)
        print(f"\033[92mINFO\033[0m: Model saved to {model_path}")

        # Add model path to results
        results["model_path"] = model_path

    # Clean up memory before returning
    if device.type == 'cuda':
        del model, optimizer, criterion, mae_criterion
        torch.cuda.empty_cache()

    return results


def evaluate_all_combinations(data_dir, image_backbones=None, audio_backbones=None, save_model_dir="test_models", results_file="backbone_evaluation_results.json"):
    if image_backbones is None:
        image_backbones = list(IMAGE_BACKBONES.keys())

    if audio_backbones is None:
        audio_backbones = list(AUDIO_BACKBONES.keys())

    # Create directory for saving models
    if save_model_dir:
        os.makedirs(save_model_dir, exist_ok=True)

    # Load previous results if the file exists
    results = []
    evaluated_combinations = set()

    if os.path.exists(results_file):
        try:
            with open(results_file, 'r') as f:
                results = json.load(f)
            evaluated_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in results}
            print(f"\033[92mINFO\033[0m: Loaded {len(results)} previous results from {results_file}")
        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error loading previous results from {results_file}: {e}")
            results = []
            evaluated_combinations = set()
    else:
        print(f"\033[93mWARN\033[0m: Results file '{results_file}' does not exist. Starting with empty results.")

    # Create combinations to evaluate, skipping any that have already been evaluated
    combinations = [(img, aud) for img in image_backbones for aud in audio_backbones
                    if (img, aud) not in evaluated_combinations]

    if len(combinations) < len(image_backbones) * len(audio_backbones):
        print(f"\033[92mINFO\033[0m: Skipping {len(evaluated_combinations)} already evaluated combinations")

    print(f"\033[92mINFO\033[0m: Will evaluate {len(combinations)} combinations")

    for image_backbone, audio_backbone in combinations:
        print(f"\033[92mINFO\033[0m: Evaluating {image_backbone} + {audio_backbone}")
        try:
            # Clean GPU memory before each model evaluation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
                # Print memory usage for debugging
                print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

            result = evaluate_model(data_dir, image_backbone, audio_backbone, save_model_dir=save_model_dir)
            results.append(result)

            # Save results after each evaluation
            save_results(results, results_file)
            print(f"\033[92mINFO\033[0m: Updated results saved to {results_file}")

            # Force garbage collection to free memory
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
                # Print memory usage for debugging
                print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

        except Exception as e:
            print(f"\033[91mERR!\033[0m: Error evaluating {image_backbone} + {audio_backbone}: {e}")
            print(f"\033[91mERR!\033[0m: To continue from this point, use --start_from={image_backbone}:{audio_backbone}")

            # Force garbage collection to free memory even if there's an error
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")

            continue

    # Sort results by test MAE (ascending)
    results.sort(key=lambda x: x["test_mae"])

    # Save final sorted results
    save_results(results, results_file)

    print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
    print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
    print("="*60)

    for result in results:
        print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")

    return results


def save_results(results, filename="backbone_evaluation_results.json"):
    """Save evaluation results to a JSON file."""
    with open(filename, 'w') as f:
        json.dump(results, f, indent=4)
    print(f"\033[92mINFO\033[0m: Results saved to {filename}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate Different Backbones for Watermelon Sweetness Prediction")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="../cleaned",
        help="Path to the cleaned dataset directory"
    )
    parser.add_argument(
        "--image_backbone",
        type=str,
        default=None,
        help="Specific image backbone to evaluate (leave empty to evaluate all available)"
    )
    parser.add_argument(
        "--audio_backbone",
        type=str,
        default=None,
        help="Specific audio backbone to evaluate (leave empty to evaluate all available)"
    )
    parser.add_argument(
        "--evaluate_all",
        action="store_true",
        help="Evaluate all combinations of backbones"
    )
    parser.add_argument(
        "--start_from",
        type=str,
        default=None,
        help="Start evaluation from a specific combination, format: 'image_backbone:audio_backbone'"
    )
    parser.add_argument(
        "--prioritize_efficient",
        action="store_true",
        help="Prioritize more efficient models first to avoid memory issues"
    )
    parser.add_argument(
        "--results_file",
        type=str,
        default="backbone_evaluation_results.json",
        help="File to save the evaluation results"
    )
    parser.add_argument(
        "--load_previous_results",
        action="store_true",
        help="Load previous results from results_file if it exists"
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="test_models",
        help="Directory to save model checkpoints"
    )

    args = parser.parse_args()

    # Create model directory if it doesn't exist
    if args.model_dir:
        os.makedirs(args.model_dir, exist_ok=True)

    print(f"\033[92mINFO\033[0m: === Available Image Backbones ===")
    for name in IMAGE_BACKBONES.keys():
        print(f"- {name}")

    print(f"\033[92mINFO\033[0m: === Available Audio Backbones ===")
    for name in AUDIO_BACKBONES.keys():
        print(f"- {name}")

    if args.evaluate_all:
        evaluate_all_combinations(args.data_dir, results_file=args.results_file, save_model_dir=args.model_dir)
    elif args.image_backbone and args.audio_backbone:
        result = evaluate_model(args.data_dir, args.image_backbone, args.audio_backbone, save_model_dir=args.model_dir)
        save_results([result], args.results_file)
    else:
        # Define a default set of backbones to evaluate if not specified
        if args.prioritize_efficient:
            # Start with less memory-intensive models
            image_backbones = ["resnet50", "efficientnet_b0", "resnet101", "efficientnet_b3", "convnext_base", "swin_b"]
            audio_backbones = ["lstm", "gru", "bidirectional_lstm", "transformer"]
        else:
            # Default selection focusing on better performance models
            image_backbones = ["resnet101", "efficientnet_b3", "swin_b"]
            audio_backbones = ["lstm", "bidirectional_lstm", "transformer"]

        # Create all combinations
        combinations = [(img, aud) for img in image_backbones for aud in audio_backbones]

        # Load previous results if requested and file exists
        previous_results = []
        previous_combinations = set()
        if args.load_previous_results:
            try:
                if os.path.exists(args.results_file):
                    with open(args.results_file, 'r') as f:
                        previous_results = json.load(f)
                    previous_combinations = {(r["image_backbone"], r["audio_backbone"]) for r in previous_results}
                    print(f"\033[92mINFO\033[0m: Loaded {len(previous_results)} previous results")
                else:
                    print(f"\033[93mWARN\033[0m: Results file '{args.results_file}' does not exist. Starting with empty results.")
            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error loading previous results: {e}")
                previous_results = []
                previous_combinations = set()

        # If starting from a specific point
        if args.start_from:
            try:
                start_img, start_aud = args.start_from.split(':')
                start_idx = combinations.index((start_img, start_aud))
                combinations = combinations[start_idx:]
                print(f"\033[92mINFO\033[0m: Starting from combination: {start_img} (image) + {start_aud} (audio)")
            except (ValueError, IndexError):
                print(f"\033[91mERR!\033[0m: Invalid start_from format or combination not found. Format should be 'image_backbone:audio_backbone'")
                print(f"\033[91mERR!\033[0m: Continuing with all combinations.")

        # Skip combinations that have already been evaluated
        if previous_combinations:
            original_count = len(combinations)
            combinations = [(img, aud) for img, aud in combinations if (img, aud) not in previous_combinations]
            print(f"\033[92mINFO\033[0m: Skipping {original_count - len(combinations)} already evaluated combinations")

        # Evaluate each combination
        results = previous_results.copy()

        for img_backbone, audio_backbone in combinations:
            print(f"\033[92mINFO\033[0m: Evaluating {img_backbone} + {audio_backbone}")
            try:
                # Clean GPU memory before each model evaluation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print(f"\033[92mINFO\033[0m: CUDA memory cleared before evaluation")
                    print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                    print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

                result = evaluate_model(args.data_dir, img_backbone, audio_backbone, save_model_dir=args.model_dir)
                results.append(result)

                # Save results after each evaluation
                save_results(results, args.results_file)

                # Force garbage collection to free memory
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print(f"\033[92mINFO\033[0m: CUDA memory cleared after evaluation")
                    print(f"\033[92mINFO\033[0m: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
                    print(f"\033[92mINFO\033[0m: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error evaluating {img_backbone} + {audio_backbone}: {e}")
                print(f"\033[91mERR!\033[0m: To continue from this point later, use --start_from={img_backbone}:{audio_backbone}")

                # Force garbage collection to free memory even if there's an error
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    print(f"\033[92mINFO\033[0m: CUDA memory cleared after error")

                continue

        # Sort results by test MAE (ascending)
        results.sort(key=lambda x: x["test_mae"])

        # Save final sorted results
        save_results(results, args.results_file)

        print("\n\033[92mINFO\033[0m: === FINAL RESULTS (Sorted by Test MAE) ===")
        print(f"{'Image Backbone':<20} {'Audio Backbone':<20} {'Val MAE':<10} {'Test MAE':<10}")
        print("="*60)

        for result in results:
            print(f"{result['image_backbone']:<20} {result['audio_backbone']:<20} {result['validation_mae']:<10.4f} {result['test_mae']:<10.4f}")
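A minimal usage sketch for the script above (not part of the commit; the dataset path matches the script's default, everything else is illustrative):

# Illustrative programmatic use of evaluate_backbones.py
from evaluate_backbones import evaluate_model, evaluate_all_combinations

# Evaluate a single combination; on success this also writes
# test_models/efficientnet_b3_transformer_model.pt and returns the metrics dict.
result = evaluate_model("../cleaned", "efficientnet_b3", "transformer", save_model_dir="test_models")
print(result["test_mae"])

# Or sweep every registered image/audio backbone pair, resuming from any
# previously saved results file.
evaluate_all_combinations("../cleaned", results_file="backbone_evaluation_results.json")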
models/.nfs00000001a1a17512003726ad
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:02999bd33592de717dc1ec8054dc570193074c3f25a7283b3daa580b727b7134
size 96095572
models/.nfs00000001a234d9cd003726ac
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5df632222fa87e09e635f90e5cce14bdd9fd34b442bf18daaf13e54dedfed132
size 96095572
models/.nfs00000001a2a11ea9003726ae
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:80f999a1540c42ed74491692aa66c3b5a6171f972bdf47c9d52556fe1673c8dd
size 96095572
models/efficientnet_b0_transformer_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eec8d23f6454198e147db3ff31e497a0fed8cc0fa690f58e2576e9190ca54aa7
size 22597034
models/efficientnet_b3_transformer_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da70bf6bef70cfa3795e566fd58523a9b41b01c151fb37fd3b255262c2b47451
size 49751930
models/resnet50_transformer_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cec4fe964defc58fea1f6c26c714c27680a4aa81b131795e8cbeadb6e7be9bd5
size 101004668
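The .pt files above are plain state_dicts saved by evaluate_model(), so loading one requires rebuilding the matching architecture first. A hedged sketch (backbone names taken from the file name; map_location and eval mode are ordinary PyTorch usage, not something this commit prescribes):

import torch
from evaluate_backbones import WatermelonModelModular

# Rebuild the efficientnet_b3 + transformer model and load its saved weights (CPU shown for safety).
model = WatermelonModelModular("efficientnet_b3", "transformer")
state = torch.load("models/efficientnet_b3_transformer_model.pt", map_location="cpu")
model.load_state_dict(state)
model.eval()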
moe_evaluation_results.json
ADDED
@@ -0,0 +1,801 @@
{
    "moe_test_mae": 0.19680618420243262,
    "moe_test_mse": 0.05606407420709729,
    "true_labels": [ ... per-sample ground-truth Brix values for the test split ... ],
    "moe_predictions": [ ... per-sample MoE ensemble predictions ... ],
    "individual_predictions": {
        "efficientnet_b3_transformer": [ ... per-sample expert predictions ... ],
        "efficientnet_b0_transformer": [ ... per-sample expert predictions ... ],
        "resnet50_transformer": [ ... per-sample expert predictions; diff view truncated here ... ]
    }
}
|
775 |
+
10.82922649383545,
|
776 |
+
11.55932903289795,
|
777 |
+
8.709542274475098,
|
778 |
+
9.288893699645996,
|
779 |
+
11.48713207244873,
|
780 |
+
9.693202018737793,
|
781 |
+
10.82302188873291,
|
782 |
+
11.73450756072998,
|
783 |
+
11.416834831237793,
|
784 |
+
11.133091926574707,
|
785 |
+
9.71113109588623,
|
786 |
+
10.830121040344238,
|
787 |
+
10.894770622253418,
|
788 |
+
9.935094833374023,
|
789 |
+
11.377425193786621,
|
790 |
+
11.13464641571045,
|
791 |
+
11.39898681640625,
|
792 |
+
12.140122413635254,
|
793 |
+
9.269479751586914,
|
794 |
+
12.450774192810059,
|
795 |
+
10.820216178894043,
|
796 |
+
9.736580848693848,
|
797 |
+
10.17590045928955,
|
798 |
+
9.74850845336914
|
799 |
+
]
|
800 |
+
}
|
801 |
+
}
|
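The file above stores the ground-truth Brix values alongside the MoE and per-expert predictions, so the reported errors can be re-derived offline. A minimal sketch of that check (assumes the JSON sits in the working directory; numpy is already used by the scripts in this repo):

import json
import numpy as np

with open("moe_evaluation_results.json") as f:
    results = json.load(f)

y_true = np.array(results["true_labels"])

# MAE of the ensemble, recomputed from the stored per-sample predictions
moe_pred = np.array(results["moe_predictions"])
print(f"MoE MAE: {np.mean(np.abs(moe_pred - y_true)):.4f}")

# MAE of each individual expert (keys such as "resnet50_transformer")
for name, preds in results["individual_predictions"].items():
    print(f"{name} MAE: {np.mean(np.abs(np.array(preds) - y_true)):.4f}")

Note that "moe_test_mae" in the file is an average of per-batch MAEs, so it can differ slightly from the per-sample MAE computed here when the final batch is smaller than the rest.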
templates/.nfs00000001a2893bde003726a5
ADDED
@@ -0,0 +1 @@
test_moe_model.py
ADDED
@@ -0,0 +1,276 @@
import os
import torch
import torchaudio
import torchvision
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import sys
from tqdm import tqdm

# Add parent directory to path to import the preprocess functions
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from preprocess import process_audio_data, process_image_data

# Import the WatermelonDataset and WatermelonModelModular from the evaluate_backbones.py file
from evaluate_backbones import WatermelonDataset, WatermelonModelModular, IMAGE_BACKBONES, AUDIO_BACKBONES

# Print library versions
print(f"\033[92mINFO\033[0m: PyTorch version: {torch.__version__}")
print(f"\033[92mINFO\033[0m: Torchaudio version: {torchaudio.__version__}")
print(f"\033[92mINFO\033[0m: Torchvision version: {torchvision.__version__}")

# Device selection
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Define the top-performing models based on the previous evaluation
TOP_MODELS = [
    {"image_backbone": "efficientnet_b3", "audio_backbone": "transformer"},
    {"image_backbone": "efficientnet_b0", "audio_backbone": "transformer"},
    {"image_backbone": "resnet50", "audio_backbone": "transformer"}
]

# Define class for the MoE model
class WatermelonMoEModel(torch.nn.Module):
    def __init__(self, model_configs, model_dir="test_models", weights=None):
        """
        Mixture of Experts model that combines multiple backbone models.

        Args:
            model_configs: List of dictionaries with 'image_backbone' and 'audio_backbone' keys
            model_dir: Directory where model checkpoints are stored
            weights: Optional list of weights for each model (None for equal weighting)
        """
        super(WatermelonMoEModel, self).__init__()
        self.models = []
        self.model_configs = model_configs

        # Load each model
        for config in model_configs:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]

            # Initialize model
            model = WatermelonModelModular(img_backbone, audio_backbone)

            # Load weights
            model_path = os.path.join(model_dir, f"{img_backbone}_{audio_backbone}_model.pt")
            if os.path.exists(model_path):
                print(f"\033[92mINFO\033[0m: Loading model {img_backbone}_{audio_backbone} from {model_path}")
                model.load_state_dict(torch.load(model_path, map_location=device))
            else:
                print(f"\033[91mERR!\033[0m: Model checkpoint not found at {model_path}")
                continue

            model.to(device)
            model.eval()  # Set to evaluation mode
            self.models.append(model)

        # Set model weights (uniform by default)
        if weights:
            assert len(weights) == len(self.models), "Number of weights must match number of models"
            self.weights = weights
        else:
            self.weights = [1.0 / len(self.models)] * len(self.models)

        print(f"\033[92mINFO\033[0m: Loaded {len(self.models)} models for MoE ensemble")
        print(f"\033[92mINFO\033[0m: Model weights: {self.weights}")

    def forward(self, mfcc, image):
        """
        Forward pass through the MoE model.
        Returns the weighted average of all model outputs.
        """
        outputs = []

        # Get outputs from each model
        with torch.no_grad():
            for i, model in enumerate(self.models):
                output = model(mfcc, image)
                outputs.append(output * self.weights[i])

        # Return weighted average
        return torch.sum(torch.stack(outputs), dim=0)

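# The forward pass above is a fixed (non-learned) mixture: with weights w_i summing to 1,
# the ensemble prediction is y_hat = sum_i w_i * f_i(mfcc, image). A hypothetical usage
# sketch, assuming the three checkpoints exist under the default model_dir and the inputs
# have already been preprocessed into MFCC and image tensors:
#
#     moe = WatermelonMoEModel(TOP_MODELS, model_dir="test_models")
#     brix = moe(mfcc.to(device), image.to(device))   # shape (batch, 1), predicted sweetness
#
# With uniform weighting over the three experts, each w_i is 1/3.
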
def evaluate_moe_model(data_dir, model_dir="test_models", weights=None):
    """
    Evaluate the MoE model on the test set.
    """
    # Load dataset
    print(f"\033[92mINFO\033[0m: Loading dataset from {data_dir}")
    dataset = WatermelonDataset(data_dir)
    n_samples = len(dataset)

    # Split dataset
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size

    _, _, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )

    # Use a reasonable batch size
    batch_size = 8
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize MoE model
    moe_model = WatermelonMoEModel(TOP_MODELS, model_dir, weights)
    moe_model.eval()

    # Evaluation metrics
    mae_criterion = torch.nn.L1Loss()
    mse_criterion = torch.nn.MSELoss()

    test_mae = 0.0
    test_mse = 0.0

    print(f"\033[92mINFO\033[0m: Evaluating MoE model on {len(test_dataset)} test samples")

    # Individual model predictions for analysis
    individual_predictions = {f"{config['image_backbone']}_{config['audio_backbone']}": []
                              for config in TOP_MODELS}
    true_labels = []
    moe_predictions = []

    # Evaluation loop
    test_iterator = tqdm(test_loader, desc="Testing MoE")

    with torch.no_grad():
        for i, (mfcc, image, label) in enumerate(test_iterator):
            try:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)

                # Store individual model outputs for analysis
                for j, model in enumerate(moe_model.models):
                    config = TOP_MODELS[j]
                    model_name = f"{config['image_backbone']}_{config['audio_backbone']}"
                    output = model(mfcc, image)
                    individual_predictions[model_name].extend(output.view(-1).cpu().numpy())

                # Get MoE prediction
                output = moe_model(mfcc, image)
                moe_predictions.extend(output.view(-1).cpu().numpy())

                # Store true labels
                label = label.view(-1, 1).float()
                true_labels.extend(label.view(-1).cpu().numpy())

                # Calculate metrics
                mae = mae_criterion(output, label)
                mse = mse_criterion(output, label)

                test_mae += mae.item()
                test_mse += mse.item()

                test_iterator.set_postfix({"MAE": f"{mae.item():.4f}", "MSE": f"{mse.item():.4f}"})

                # Clean up memory
                if device.type == 'cuda':
                    del mfcc, image, label, output, mae, mse
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"\033[91mERR!\033[0m: Error in test batch {i}: {e}")
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                continue

    # Calculate average metrics
    avg_test_mae = test_mae / len(test_loader) if len(test_loader) > 0 else float('inf')
    avg_test_mse = test_mse / len(test_loader) if len(test_loader) > 0 else float('inf')

    print(f"\n\033[92mINFO\033[0m: === MoE Model Results ===")
    print(f"Test MAE: {avg_test_mae:.4f}")
    print(f"Test MSE: {avg_test_mse:.4f}")

    # Compare with individual models
    print(f"\n\033[92mINFO\033[0m: === Comparison with Individual Models ===")
    print(f"{'Model':<30} {'Test MAE':<15}")
    print("="*45)

    # Load previous results
    results_file = "backbone_evaluation_results.json"
    if os.path.exists(results_file):
        with open(results_file, 'r') as f:
            previous_results = json.load(f)

        # Filter results for our top models
        for config in TOP_MODELS:
            img_backbone = config["image_backbone"]
            audio_backbone = config["audio_backbone"]

            for result in previous_results:
                if result["image_backbone"] == img_backbone and result["audio_backbone"] == audio_backbone:
                    print(f"{img_backbone}_{audio_backbone:<20} {result['test_mae']:<15.4f}")

    print(f"MoE (Ensemble) {avg_test_mae:<15.4f}")

    # Save results and predictions
    results = {
        "moe_test_mae": float(avg_test_mae),
        "moe_test_mse": float(avg_test_mse),
        "true_labels": [float(x) for x in true_labels],
        "moe_predictions": [float(x) for x in moe_predictions],
        "individual_predictions": {key: [float(x) for x in values]
                                   for key, values in individual_predictions.items()}
    }

    with open("moe_evaluation_results.json", 'w') as f:
        json.dump(results, f, indent=4)

    print(f"\033[92mINFO\033[0m: Results saved to moe_evaluation_results.json")

    return avg_test_mae, avg_test_mse

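# Worked example of the performance-based weighting used below: the inverse MAEs are
# 1/0.3635 ≈ 2.751, 1/0.3765 ≈ 2.656 and 1/0.3959 ≈ 2.526, which normalize to roughly
# [0.347, 0.335, 0.318]. The strongest expert (efficientnet_b3 + transformer) receives the
# largest share, but the weights stay close to uniform because the three MAEs are similar.
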
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test Mixture of Experts (MoE) Model for Watermelon Sweetness Prediction")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="../cleaned",
        help="Path to the cleaned dataset directory"
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default="test_models",
        help="Directory containing model checkpoints"
    )
    parser.add_argument(
        "--weighting",
        type=str,
        choices=["uniform", "performance"],
        default="uniform",
        help="How to weight the models (uniform or based on performance)"
    )

    args = parser.parse_args()

    # Determine weights based on argument
    weights = None
    if args.weighting == "performance":
        # Weights inversely proportional to the MAE (better models get higher weights)
        # These are the MAE values from the provided results
        mae_values = [0.3635, 0.3765, 0.3959]  # efficientnet_b3+transformer, efficientnet_b0+transformer, resnet50+transformer

        # Convert to weights (inverse of MAE, normalized)
        inverse_mae = [1/mae for mae in mae_values]
        total = sum(inverse_mae)
        weights = [val/total for val in inverse_mae]

        print(f"\033[92mINFO\033[0m: Using performance-based weights: {weights}")
    else:
        print(f"\033[92mINFO\033[0m: Using uniform weights")

    # Evaluate the MoE model
    evaluate_moe_model(args.data_dir, args.model_dir, weights)
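
For reference, a typical invocation of this script would look something like the lines below (paths are illustrative; the expert checkpoints named {image_backbone}_{audio_backbone}_model.pt must exist in the directory passed as --model_dir):

python test_moe_model.py --data_dir ../cleaned --weighting uniform
python test_moe_model.py --data_dir ../cleaned --model_dir test_models --weighting performance

With --weighting performance the ensemble weights are derived from the hard-coded per-expert MAEs as in the worked example above; with the default uniform setting each expert contributes equally.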