sanket03 commited on
Commit
76c88e2
·
1 Parent(s): 8e4edc6

updated code for top classes in inference code

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -32,7 +32,12 @@ def inference(input_img, transparency = 0.5, target_layer_number = -1, num_top_c
32
  rgb_img = np.transpose(img, (1, 2, 0))
33
  rgb_img = rgb_img.numpy()
34
  visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
35
- return confidences, visualization
 
 
 
 
 
36
 
37
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
38
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
@@ -46,7 +51,7 @@ demo = gr.Interface(
46
  gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"),
47
  gr.Slider(-2, -1, value = -2, step=1, label="Which Layer?"),
48
  gr.Slider(0, 10, value = 1, step=1, label="Number of Top Classes")],
49
- outputs = [gr.Label(num_top_classes=5), gr.Image(shape=(32, 32), label="Output", style={"width": "128px", "height": "128px"})],
50
  title = title,
51
  description = description,
52
  examples = examples,
 
32
  rgb_img = np.transpose(img, (1, 2, 0))
33
  rgb_img = rgb_img.numpy()
34
  visualization = show_cam_on_image(org_img/255, grayscale_cam, use_rgb=True, image_weight=transparency)
35
+
36
+ # Sort confidences dictionary in descending order of values and take top num_top_classes
37
+ sorted_confidences = {k: v for k, v in sorted(confidences.items(), key=lambda item: item[1], reverse=True)}
38
+ top_classes = list(sorted_confidences.keys())[:num_top_classes]
39
+
40
+ return top_classes, visualization
41
 
42
  title = "CIFAR10 trained on ResNet18 Model with GradCAM"
43
  description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
 
51
  gr.Slider(0, 1, value = 0.5, label="Opacity of GradCAM"),
52
  gr.Slider(-2, -1, value = -2, step=1, label="Which Layer?"),
53
  gr.Slider(0, 10, value = 1, step=1, label="Number of Top Classes")],
54
+ outputs = [gr.Label(num_top_classes=10), gr.Image(shape=(32, 32), label="Output", style={"width": "128px", "height": "128px"})],
55
  title = title,
56
  description = description,
57
  examples = examples,