# peeyushsinghal's picture
# changing model path
# 0d8f893 verified
import torch
import torchvision.transforms as transforms
from PIL import Image
import json
import gradio as gr
from model import ResNet50 # Import the model definition
# Load class index mapping
def load_class_labels(label_path):
    """Read the class-index-to-label mapping from a JSON file.

    Args:
        label_path: Path to the JSON file holding the mapping.

    Returns:
        The decoded JSON content, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; an explicit encoding keeps non-ASCII
    # labels from being mis-decoded with the platform's default codec.
    with open(label_path, 'r', encoding='utf-8') as f:
        return json.load(f)
# Load the model
def load_model(model_path):
    """Build a ResNet-50 and load trained weights from a checkpoint file.

    Args:
        model_path: Path to a file readable by ``torch.load``. It may hold
            either a checkpoint dict with a ``'model_state_dict'`` entry or
            a bare state dict.

    Returns:
        The ``ResNet50`` instance (100 output classes) with the weights
        loaded, switched to evaluation mode.
    """
    model = ResNet50(num_classes=100)

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded here; consider weights_only=True if the installed
    # PyTorch version and the checkpoint contents allow it.
    checkpoint = torch.load(model_path, map_location='cpu')

    # Accept both a full training checkpoint and a bare state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    # Strip only a *leading* "module." from each key. A plain str.replace
    # would also mangle keys that merely contain the text, e.g.
    # "submodule.weight" -> "subweight".
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    model.load_state_dict(state_dict)
    model.eval()  # Inference only: switch the model to evaluation mode.
    return model
# Preprocess the input image
def preprocess_image(image):
    """Turn an image into a normalized, batched tensor for the model.

    The image is converted to RGB, resized to 256, center-cropped to 224,
    turned into a tensor, normalized per channel, and finally given a
    leading batch dimension.

    Args:
        image: An object exposing ``convert("RGB")`` (a PIL image in this app).

    Returns:
        The transformed tensor with a batch dimension added at position 0.
    """
    # NOTE(review): these look like the usual ImageNet channel statistics;
    # confirm they match what the model was trained with.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)

    rgb_image = image.convert("RGB")
    return pipeline(rgb_image).unsqueeze(0)
# Prediction function
def predict(image):
    """Classify an image and return its most probable class labels.

    Args:
        image: The input image; it is handed to ``preprocess_image``. Gradio
            passes ``None`` here when the image input is left empty.

    Returns:
        A dict mapping class label to probability for up to five classes,
        inserted from most to least probable.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Fail with a readable message instead of an AttributeError deep
        # inside preprocessing.
        raise gr.Error("Please upload an image before requesting a prediction.")

    image_tensor = preprocess_image(image)
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        # Never ask for more entries than the model has classes.
        top_k = min(5, probabilities.shape[1])
        top_probabilities, top_indices = probabilities.topk(top_k)

    results = {}
    for class_index, probability in zip(top_indices[0].tolist(),
                                        top_probabilities[0].tolist()):
        # Keep the index in the fallback label so that several unmapped
        # classes do not overwrite each other under one shared key.
        # NOTE(review): assumes the mapped label values are strings — confirm
        # against the label file.
        fallback_label = f"Unknown class ({class_index})"
        class_label = class_labels.get(str(class_index), fallback_label)
        results[class_label] = probability
    return results
# Load the model and the class labels once, at import time. predict() reads
# the module-level names `model` and `class_labels` defined below.
model_path = 'model.pth' # Path to the trained model
label_path = 'subset_imagenet_class_index.json' # Path to the class index mapping
# Alternative label file, kept for reference:
# label_path = 'imagenet_class_index_everything.json' # Path to the class index mapping
model = load_model(model_path)
class_labels = load_class_labels(label_path)
# Create Gradio interface
# Input component: the uploaded picture is handed to predict() as a PIL image.
image_input = gr.Image(type="pil")
# Output component: shows the five highest-scoring labels.
label_output = gr.Label(num_top_classes=5)

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Image Classification using ResNet-50 Model",
    description="Upload an image to get the top-5 predictions from the ResNet-50 model.",
)

# Start serving the app.
iface.launch()