"""Gradio demo that classifies clothing images into the Clothing1M categories.

Loads trained ``CC_model`` weights from ``CC_net.pt`` and serves a simple
upload-an-image / get-a-label web interface.
"""

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from ResNet_for_CC import CC_model  # Import updated model

# Set device (CPU/GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define Clothing1M Class Labels; list index == model output class index.
class_labels = [
    "T-Shirt", "Shirt", "Knitwear", "Chiffon", "Sweater", "Hoodie", "Windbreaker",
    "Jacket", "Downcoat", "Suit", "Shawl", "Dress", "Vest", "Underwear",
]

# Load the trained CC_model
model_path = "CC_net.pt"  # Ensure correct path
# Size the classifier head from the label list so the two cannot drift apart.
model = CC_model(num_classes1=len(class_labels))
# NOTE(review): torch.load unpickles the checkpoint file; only load files you
# trust. On torch >= 1.13, passing weights_only=True restricts loading to
# tensors/primitive containers, which is sufficient for a state_dict.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Define preprocessing for images: resize, center-crop to 224x224, convert to
# a tensor, and normalize with the standard ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify_image(image):
    """Return the predicted Clothing1M class for an uploaded image.

    Args:
        image: PIL image supplied by the Gradio ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without uploading.

    Returns:
        A string ``"Predicted Class: <label>"``, or a prompt asking for an
        upload when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."
    # Uploaded files may be RGBA, grayscale ("L") or palette ("P") images;
    # Normalize above uses 3-channel statistics, so force exactly 3 channels.
    image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        # Assumes CC_model.forward returns a 2-tuple whose second element holds
        # the class scores (originally described as output_mean) — TODO confirm
        # against ResNet_for_CC.
        _, output = model(batch)
    predicted_class = torch.argmax(output, dim=1).item()  # Get class index
    return f"Predicted Class: {class_labels[predicted_class]}"


# Create Gradio Interface
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Clothing1M Image Classifier",
    description="Upload a clothing image, and the model will classify it into one of the 14 categories.",
)

# Run the Interface
if __name__ == "__main__":
    interface.launch()