Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,7 @@ from fastai.vision.all import *
|
|
| 10 |
model_multi = load_learner('vit_tiny_patch16.pkl')
|
| 11 |
|
| 12 |
def binary_label(path):
|
| 13 |
-
|
| 14 |
|
| 15 |
model_binary = load_learner('vit_tiny_patch16_binary.pkl')
|
| 16 |
|
|
@@ -22,17 +22,17 @@ seg_feature_extractor = SegformerFeatureExtractor.from_pretrained('zklee98/segfo
|
|
| 22 |
seg_model = SegformerForSemanticSegmentation.from_pretrained('zklee98/segformer-b1-solarModuleAnomaly-v0.1')
|
| 23 |
|
| 24 |
def get_seg_overlay(image, seg):
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
|
| 35 |
-
|
| 36 |
|
| 37 |
#@title `def sidewalk_palette()`
|
| 38 |
|
|
@@ -80,46 +80,45 @@ def sidewalk_palette():
|
|
| 80 |
|
| 81 |
def predict(classification_mode, image):
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
)
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
| 113 |
|
| 114 |
-
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
|
| 122 |
-
|
| 123 |
|
| 124 |
return ({labels[i]: float(probs[i]) for i in range(len(labels))}, seg_img, percentage_affected)
|
| 125 |
|
|
@@ -133,7 +132,7 @@ description = """
|
|
| 133 |
gr.Interface(fn=predict,
|
| 134 |
inputs= [gr.Dropdown(choices=['Binary Classification', 'Multiclass Classification'], label='Classification Mode:',
|
| 135 |
info='Choose to classify between anomaly and no-anomaly OR between 12 different types of anomalies.'),
|
| 136 |
-
gr.Image(label='Input infrared image: ')],
|
| 137 |
outputs=[gr.outputs.Label(num_top_classes=3, label='Detected:').style(container=False),
|
| 138 |
gr.Image(type='pil', label=' ').style(height=240, width=144),
|
| 139 |
gr.Textbox(label='Affected area:').style(container=False)],
|
|
|
|
| 10 |
model_multi = load_learner('vit_tiny_patch16.pkl')
|
| 11 |
|
| 12 |
def binary_label(path):
|
| 13 |
+
return 'No-anomaly' if (parent_label(path) == 'No-Anomaly') else 'Anomaly'
|
| 14 |
|
| 15 |
model_binary = load_learner('vit_tiny_patch16_binary.pkl')
|
| 16 |
|
|
|
|
| 22 |
seg_model = SegformerForSemanticSegmentation.from_pretrained('zklee98/segformer-b1-solarModuleAnomaly-v0.1')
|
| 23 |
|
| 24 |
def get_seg_overlay(image, seg):
|
| 25 |
+
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
|
| 26 |
+
palette = np.array(sidewalk_palette())
|
| 27 |
+
for label, color in enumerate(palette):
|
| 28 |
+
color_seg[seg == label, :] = color
|
| 29 |
|
| 30 |
+
# Show image + mask
|
| 31 |
+
img = np.array(image) * 0.5 + color_seg * 0.5
|
| 32 |
+
img = img.astype(np.uint8)
|
| 33 |
+
#img = PIL.Image.open(img)
|
| 34 |
|
| 35 |
+
return img
|
| 36 |
|
| 37 |
#@title `def sidewalk_palette()`
|
| 38 |
|
|
|
|
| 80 |
|
| 81 |
def predict(classification_mode, image):
|
| 82 |
|
| 83 |
+
if (classification_mode == 'Binary Classification'):
|
| 84 |
+
model = model_binary
|
| 85 |
+
else:
|
| 86 |
+
model = model_multi
|
| 87 |
|
| 88 |
+
labels = model.dls.vocab
|
| 89 |
+
# Classification model prediction
|
| 90 |
+
#image = PILImage.create(image)
|
| 91 |
+
pred, pred_idx, probs = model.predict(image)
|
| 92 |
|
| 93 |
+
seg_img = None
|
| 94 |
+
percentage_affected = '0%'
|
| 95 |
+
if (pred.upper() != 'NO-ANOMALY'):
|
| 96 |
+
addChannel = Grayscale(num_output_channels=3)
|
| 97 |
+
image = addChannel(image)
|
| 98 |
|
| 99 |
+
inputs = seg_feature_extractor(images=image, return_tensors="pt")
|
| 100 |
+
outputs = seg_model(**inputs)
|
| 101 |
+
logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
|
| 102 |
|
| 103 |
+
# First, rescale logits to original image size
|
| 104 |
+
upsampled_logits = nn.functional.interpolate(
|
| 105 |
+
logits,
|
| 106 |
+
size=image.size[::-1], # (height, width)
|
| 107 |
+
mode='bilinear',
|
| 108 |
+
align_corners=False)
|
|
|
|
| 109 |
|
| 110 |
+
# Second, apply argmax on the class dimension
|
| 111 |
+
pred_seg = upsampled_logits.argmax(dim=1)[0]
|
| 112 |
|
| 113 |
+
seg_img = get_seg_overlay(image, pred_seg)
|
| 114 |
|
| 115 |
+
classified_pixels = np.unique(pred_seg.numpy(), return_counts=True)
|
| 116 |
+
pixels_count = dict({classified_pixels[0][0]: classified_pixels[1][0],
|
| 117 |
+
classified_pixels[0][1]: classified_pixels[1][1]})
|
| 118 |
+
percentage_affected = round((pixels_count[1]/960)*100, 1)
|
| 119 |
+
percentage_affected = str(percentage_affected) + '%'
|
| 120 |
|
| 121 |
+
seg_img = PIL.Image.fromarray(seg_img)
|
| 122 |
|
| 123 |
return ({labels[i]: float(probs[i]) for i in range(len(labels))}, seg_img, percentage_affected)
|
| 124 |
|
|
|
|
| 132 |
gr.Interface(fn=predict,
|
| 133 |
inputs= [gr.Dropdown(choices=['Binary Classification', 'Multiclass Classification'], label='Classification Mode:',
|
| 134 |
info='Choose to classify between anomaly and no-anomaly OR between 12 different types of anomalies.'),
|
| 135 |
+
gr.Image(type='pil', label='Input infrared image: ')],
|
| 136 |
outputs=[gr.outputs.Label(num_top_classes=3, label='Detected:').style(container=False),
|
| 137 |
gr.Image(type='pil', label=' ').style(height=240, width=144),
|
| 138 |
gr.Textbox(label='Affected area:').style(container=False)],
|