Update app.py
app.py CHANGED
@@ -10,7 +10,7 @@ from fastai.vision.all import *
 model_multi = load_learner('vit_tiny_patch16.pkl')

 def binary_label(path):
-
+    return 'No-anomaly' if (parent_label(path) == 'No-Anomaly') else 'Anomaly'

 model_binary = load_learner('vit_tiny_patch16_binary.pkl')

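Defining binary_label above the load_learner call matters: fastai's load_learner unpickles the exported Learner, and a custom labelling function referenced by the exported DataLoaders has to exist under the same name in the loading module, otherwise the .pkl cannot be loaded. A minimal sketch of the training-side export this presumably pairs with; the dataset path, resize, architecture string, and training schedule below are illustrative assumptions, not taken from this repo:

# Hypothetical export-side counterpart (illustrative only): the binary learner is assumed
# to have been trained with binary_label as its labelling function and then exported,
# which is what makes the function a hard requirement at load time in app.py.
from fastai.vision.all import *

def binary_label(path):
    return 'No-anomaly' if (parent_label(path) == 'No-Anomaly') else 'Anomaly'

path = 'infrared_images/'                      # assumed layout: class name = parent folder
dls = ImageDataLoaders.from_path_func(
    path, get_image_files(path), binary_label,
    valid_pct=0.2, item_tfms=Resize(224))      # illustrative split and resize
learn = vision_learner(dls, 'vit_tiny_patch16_224', metrics=accuracy)
learn.fine_tune(3)                             # illustrative schedule
learn.export('vit_tiny_patch16_binary.pkl')    # produces the file loaded in app.py
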
@@ -22,17 +22,17 @@ seg_feature_extractor = SegformerFeatureExtractor.from_pretrained('zklee98/segfo
 seg_model = SegformerForSemanticSegmentation.from_pretrained('zklee98/segformer-b1-solarModuleAnomaly-v0.1')

 def get_seg_overlay(image, seg):
-
-
-
-
+    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
+    palette = np.array(sidewalk_palette())
+    for label, color in enumerate(palette):
+        color_seg[seg == label, :] = color

-
-
-
-
+    # Show image + mask
+    img = np.array(image) * 0.5 + color_seg * 0.5
+    img = img.astype(np.uint8)
+    #img = PIL.Image.open(img)

-
+    return img

 #@title `def sidewalk_palette()`

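get_seg_overlay paints each predicted class with its palette colour and alpha-blends the result 50/50 with the input image. A quick standalone check, assuming get_seg_overlay and sidewalk_palette from this app are in scope; the grey test image and random mask are synthetic placeholders, not real model output:

import numpy as np
import PIL.Image

test_image = PIL.Image.new('RGB', (40, 24), color=(90, 90, 90))   # width x height
test_mask = np.random.randint(0, 2, size=(24, 40))                # 0 = background, 1 = anomaly (synthetic)

overlay = get_seg_overlay(test_image, test_mask)                  # uint8 array, same height/width
PIL.Image.fromarray(overlay).save('overlay_preview.png')
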
@@ -80,46 +80,45 @@ def sidewalk_palette():

 def predict(classification_mode, image):

-
-
-
-
+    if (classification_mode == 'Binary Classification'):
+        model = model_binary
+    else:
+        model = model_multi

-
-
-
-
+    labels = model.dls.vocab
+    # Classification model prediction
+    #image = PILImage.create(image)
+    pred, pred_idx, probs = model.predict(image)

-
-
-
-
-
+    seg_img = None
+    percentage_affected = '0%'
+    if (pred.upper() != 'NO-ANOMALY'):
+        addChannel = Grayscale(num_output_channels=3)
+        image = addChannel(image)

-
-
-
+        inputs = seg_feature_extractor(images=image, return_tensors="pt")
+        outputs = seg_model(**inputs)
+        logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)

-
-
-
-
-
-
-)
+        # First, rescale logits to original image size
+        upsampled_logits = nn.functional.interpolate(
+            logits,
+            size=image.size[::-1],  # (height, width)
+            mode='bilinear',
+            align_corners=False)

-
-
+        # Second, apply argmax on the class dimension
+        pred_seg = upsampled_logits.argmax(dim=1)[0]

-
+        seg_img = get_seg_overlay(image, pred_seg)

-
-
-
-
-
+        classified_pixels = np.unique(pred_seg.numpy(), return_counts=True)
+        pixels_count = dict({classified_pixels[0][0]: classified_pixels[1][0],
+                             classified_pixels[0][1]: classified_pixels[1][1]})
+        percentage_affected = round((pixels_count[1]/960)*100, 1)
+        percentage_affected = str(percentage_affected) + '%'

-
+        seg_img = PIL.Image.fromarray(seg_img)

     return ({labels[i]: float(probs[i]) for i in range(len(labels))}, seg_img, percentage_affected)

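The affected-area figure divides the anomaly-pixel count by a hard-coded 960, which presumably corresponds to the 40 x 24 resolution of these infrared module images, and the pixels_count dict assumes both class 0 and class 1 appear in the mask. For illustration only, a hedged alternative sketch (a hypothetical helper, not part of this commit) that derives the total from the mask itself and tolerates single-class masks:

import numpy as np

def affected_percentage(pred_seg, anomaly_class=1):
    # pred_seg: 2-D class-index mask (torch tensor or numpy array)
    mask = pred_seg.numpy() if hasattr(pred_seg, 'numpy') else np.asarray(pred_seg)
    total = mask.size                              # e.g. 40 * 24 = 960 for these IR images
    anomaly_pixels = int((mask == anomaly_class).sum())
    return str(round(anomaly_pixels / total * 100, 1)) + '%'
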
@@ -133,7 +132,7 @@ description = """
 gr.Interface(fn=predict,
              inputs= [gr.Dropdown(choices=['Binary Classification', 'Multiclass Classification'], label='Classification Mode:',
                                   info='Choose to classify between anomaly and no-anomaly OR between 12 different types of anomalies.'),
-                      gr.Image(label='Input infrared image: ')],
+                      gr.Image(type='pil', label='Input infrared image: ')],
              outputs=[gr.outputs.Label(num_top_classes=3, label='Detected:').style(container=False),
                       gr.Image(type='pil', label=' ').style(height=240, width=144),
                       gr.Textbox(label='Affected area:').style(container=False)],
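
The one-line change here sets type='pil' on the input image. Gradio's gr.Image hands the function a numpy array by default, while predict treats image as a PIL image (image.size[::-1], presumably torchvision's Grayscale), so requesting a PIL input keeps the segmentation branch working. For context, a minimal sketch of how an interface like this is typically assembled and launched, using the current gr.Label / gr.Image components rather than the deprecated gr.outputs namespace seen above; the surrounding arguments in app.py may differ:

import gradio as gr

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(choices=['Binary Classification', 'Multiclass Classification'],
                    label='Classification Mode:'),
        gr.Image(type='pil', label='Input infrared image: '),  # predict() receives a PIL.Image
    ],
    outputs=[
        gr.Label(num_top_classes=3, label='Detected:'),
        gr.Image(type='pil', label=' '),
        gr.Textbox(label='Affected area:'),
    ],
)
demo.launch()   # on Spaces this serves the app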