"""FoodVision 101: a Gradio app serving an EfficientNetB2 food classifier.

At import time this script loads the class names, builds the model, loads its
weights onto the CPU and builds the Gradio interface (``demo``). The interface
is launched when the file is run as a script.
"""

import inspect
import logging
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import torch

from model import create_effnetb2_model

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

CLASS_NAMES_PATH = "class_names.txt"
WEIGHTS_PATH = "09_pretrained_effnetb2_feature_extractor_food101.pth"
EXAMPLES_DIR = "examples"

# `import gradio` above already fails if Gradio is missing, so we only log the
# version here (pkg_resources is deprecated and not needed for this).
logger.info("Using Gradio version: %s", getattr(gr, "__version__", "unknown"))

# Load class names (one per line; blank lines such as a trailing newline are skipped
# so they do not turn into bogus classes).
try:
    with open(CLASS_NAMES_PATH, "r", encoding="utf-8") as f:
        class_names = [name for name in (line.strip() for line in f) if name]
except FileNotFoundError as e:
    logger.error("class_names.txt not found")
    raise FileNotFoundError("class_names.txt not found.") from e
if not class_names:
    raise ValueError("class_names.txt contains no class names.")
logger.info("Class names loaded successfully")

# Model and transforms preparation. The output size is derived from the class
# list so the classifier head and the label lookup in `predict` always agree;
# a mismatch with the checkpoint surfaces below as a load_state_dict error.
try:
    effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=len(class_names))
except Exception as e:
    logger.exception("Error creating model")
    raise RuntimeError(f"Error creating model: {e}") from e
logger.info("EfficientNetB2 model created successfully")

# Load weights
try:
    # NOTE: torch.load unpickles the file; only point it at this trusted, bundled checkpoint.
    state_dict = torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"))
    effnetb2.load_state_dict(state_dict)
except FileNotFoundError as e:
    logger.error("Model weights file not found")
    raise FileNotFoundError("Model weights file not found.") from e
except Exception as e:
    logger.exception("Error loading weights")
    raise RuntimeError(f"Error loading weights: {e}") from e
logger.info("Model weights loaded successfully")

# The model is only used for inference, so switch it to eval mode once here
# instead of on every request.
effnetb2.eval()


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a food image.

    Args:
        img: the image from the ``gr.Image(type="pil")`` input, or None when
            the user submitted nothing.

    Returns:
        ``(label_to_probability, seconds)``: the softmax probability for every
        entry of ``class_names`` and the elapsed wall-clock time, rounded to
        5 decimal places.

    Raises:
        gr.Error: if no image was supplied or inference fails. Gradio shows the
            message to the user instead of rendering a malformed label.
    """
    if img is None:
        logger.warning("Prediction requested without an input image")
        raise gr.Error("Input image is None.")

    start_time = timer()
    try:
        batch = effnetb2_transforms(img).unsqueeze(0)  # add a leading batch dimension
        with torch.inference_mode():
            pred_probs = torch.softmax(effnetb2(batch), dim=1)[0]
    except Exception as e:
        logger.exception("Prediction failed")
        raise gr.Error(f"Prediction failed: {e}") from e

    pred_labels_and_probs = dict(zip(class_names, pred_probs.tolist()))
    pred_time = round(timer() - start_time, 5)

    top_label = max(pred_labels_and_probs, key=pred_labels_and_probs.get)
    logger.info(
        "Prediction completed: top=%s (%.4f), Time: %s",
        top_label,
        pred_labels_and_probs[top_label],
        pred_time,
    )
    return pred_labels_and_probs, pred_time


def _disable_flagging_kwargs() -> Dict[str, str]:
    """Return the ``gr.Interface`` keyword that disables flagging.

    Newer Gradio releases take ``flagging_mode`` (deprecating
    ``allow_flagging``); older ones only accept ``allow_flagging``.
    """
    try:
        params = inspect.signature(gr.Interface.__init__).parameters
    except (TypeError, ValueError):
        params = {}
    if "flagging_mode" in params:
        return {"flagging_mode": "never"}
    return {"allow_flagging": "never"}


# Gradio app
title = "FoodVision 101 🍔👁"
description = "An EfficientNetB2 feature extractor to classify 101 food classes."

try:
    # Sorted for a stable example order; hidden files (e.g. .gitkeep) are not images.
    example_list = [
        [os.path.join(EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(EXAMPLES_DIR))
        if not name.startswith(".")
    ]
    logger.info("Examples loaded successfully")
except FileNotFoundError:
    example_list = []
    logger.warning("'examples/' directory not found")

# Simplified Gradio interface. Flagging is disabled to keep the UI minimal.
# NOTE: `api_mode` is intentionally not passed; it is not a gr.Interface parameter.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list or None,
    title=title,
    description=description,
    **_disable_flagging_kwargs(),
)


def main() -> None:
    """Launch the Gradio app; blocks until the server is shut down."""
    # share=True opens a public gradio.live tunnel when run locally. On Hugging
    # Face Spaces the Space itself serves the app, and Gradio is expected to
    # ignore this flag there (with a warning).
    logger.info("Launching Gradio app")
    try:
        demo.launch(share=True)
    except Exception as e:
        logger.exception("Failed to launch Gradio app")
        raise RuntimeError(f"Failed to launch Gradio app: {e}") from e


if __name__ == "__main__":
    main()