"""Gradio demo that classifies kidney-stone presence from six scalar features.

Features (in slider order): specific gravity, pH, osmolality, conductivity,
urea and calcium. They are projected to a 224x224 plane, replicated to three
channels and classified by an EfficientNet-B0 backbone with a sigmoid head.
NOTE: for research purposes only.
"""
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
import gradio as gr

# Trained weights: a state_dict saved from CustomModel.
MODEL_PATH = 'best_model_efficientnet_b0.pth'


# Define the custom model architecture
class CustomModel(nn.Module):
    """Tabular-to-image classifier built on EfficientNet-B0.

    Six input features are mapped by a linear layer to 224 * 224 = 50176
    values, batch-normalised, reshaped to a 224x224 plane, copied into three
    identical channels and passed to EfficientNet-B0 (32 outputs), followed by
    a Linear(32, 1) + Sigmoid head.

    Args:
        pretrained: If True (the default, original behaviour), initialise the
            backbone with ImageNet weights, which downloads them. Pass False
            when a full state_dict is loaded right afterwards anyway: the
            architecture is identical, only the initial weights differ.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Attribute names must not change: they are the state_dict keys.
        self.fc = nn.Linear(6, 50176)  # 50176 == 224 * 224
        self.fc_bn = nn.BatchNorm1d(50176)
        if pretrained:
            self.pretrained_model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=32)
        else:
            self.pretrained_model = EfficientNet.from_name('efficientnet-b0', num_classes=32)
        self.classification_head = nn.Sequential(
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map a (batch, 6) feature tensor to (batch, 1) sigmoid probabilities."""
        x = self.fc_bn(self.fc(x))
        x = x.view(-1, 224, 224)
        # Replicate the single plane into the 3 channels the backbone expects.
        x = torch.stack([x] * 3, dim=1)
        x = self.pretrained_model(x)
        return self.classification_head(x)


# Load the trained model.
# - pretrained=False skips the ImageNet download, since the strict
#   load_state_dict below overwrites every parameter and buffer anyway.
# - map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model = CustomModel(pretrained=False)
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
# eval() is required: BatchNorm1d cannot normalise a one-sample batch in train mode.
model.eval()


# Function to make prediction
def predict(feature1, feature2, feature3, feature4, feature5, feature6):
    """Classify a single sample from the six slider values.

    Args:
        feature1..feature6: gravity, ph, osmo, cond, urea and calc, in the
            order of the ``inputs`` sliders defined below.

    Returns:
        "Kidney Stone Detected" when the model's probability exceeds 0.5,
        otherwise "No Stone Detected".
    """
    features = torch.tensor(
        [[feature1, feature2, feature3, feature4, feature5, feature6]],
        dtype=torch.float32,
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        probability = model(features).item()
    # Same decision as rounding the probability: torch.round rounds half to
    # even, so exactly 0.5 maps to "No Stone Detected".
    return "Kidney Stone Detected" if probability > 0.5 else "No Stone Detected"


light_blue = "#ADD8E6"

# Create a Gradio interface: one slider per feature, in the order predict() takes them.
inputs = [
    gr.Slider(minimum=0.8, maximum=1.5, label="gravity: Specific Gravity"),
    gr.Slider(minimum=3, maximum=8, label="ph: pH (Potential of Hydrogen)"),
    gr.Slider(minimum=200, maximum=1200, label="osmo: Osmolality"),
    gr.Slider(minimum=5, maximum=30, label="cond: Conductivity"),
    gr.Slider(minimum=50, maximum=700, label="urea: Urea"),
    gr.Slider(minimum=0, maximum=20, label="calc: Calcium"),
]
output = gr.Label()  # Output label for the prediction

interface = gr.Interface(
    predict,
    inputs,
    output,
    title="Kidney Stone Detection NOTE- FOR RESEARCH PURPOSE ONLY-",
    description="Enter the values for each feature",
    css=f".gradio-container {{ background-color: {light_blue} }}",  # Inline CSS injection
)

interface.launch()  # Launch the Gradio interface