zklee98 commited on
Commit
97e4a2e
·
verified ·
1 Parent(s): b532cad

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import transformers
3
+ from torch import nn
4
+ import numpy as np
5
+ import gradio as gr
6
+
7
# Instantiate classification model
# NOTE: star-import brings in load_learner and parent_label (used below).
from fastai.vision.all import *
# Multiclass ViT learner (12 anomaly types) exported with fastai.
model_multi = load_learner('vit_tiny_patch16.pkl')
11
def binary_label(path):
    """Collapse a sample's parent-folder name into a two-class label.

    Only the 'No-Anomaly' folder maps to 'No-anomaly'; every anomaly
    subtype folder maps to 'Anomaly'. This function must exist at module
    level so the binary learner pickle can resolve it when deserialized.
    """
    folder = parent_label(path)
    if folder == 'No-Anomaly':
        return 'No-anomaly'
    return 'Anomaly'
13
+
14
# Binary (anomaly vs. no-anomaly) learner; its pickle references
# binary_label above, which therefore must be importable at load time.
model_binary = load_learner('vit_tiny_patch16_binary.pkl')
15
+
16
# Instantiate segmentation model
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from torchvision.transforms import Grayscale

# SegFormer-B1 fine-tuned for solar-module anomaly segmentation,
# pulled from the Hugging Face Hub (downloads on first run).
seg_feature_extractor = SegformerFeatureExtractor.from_pretrained('zklee98/segformer-b1-solarModuleAnomaly-v0.1')
seg_model = SegformerForSemanticSegmentation.from_pretrained('zklee98/segformer-b1-solarModuleAnomaly-v0.1')
22
+
23
def get_seg_overlay(image, seg):
    """Blend a color-coded segmentation mask onto the input image.

    image: PIL image (or array-like) in RGB.
    seg:   2-D array of per-pixel class ids.
    Returns the 50/50 blend as a uint8 numpy array (height, width, 3).
    """
    height, width = seg.shape[0], seg.shape[1]
    color_seg = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint every pixel of each class with that class's palette color.
    for class_id, rgb in enumerate(np.array(sidewalk_palette())):
        color_seg[seg == class_id, :] = rgb

    # Show image + mask: equal-weight blend, then back to uint8.
    blended = np.array(image) * 0.5 + color_seg * 0.5
    return blended.astype(np.uint8)
35
+
36
+ #@title `def sidewalk_palette()`
37
+
38
def sidewalk_palette():
    """Sidewalk palette that maps each class index to its [R, G, B] value."""
    flat = (
        0, 0, 0,        216, 82, 24,    255, 255, 0,
        125, 46, 141,   118, 171, 47,   161, 19, 46,
        255, 0, 0,      0, 128, 128,    190, 190, 0,
        0, 255, 0,      0, 0, 255,      170, 0, 255,
        84, 84, 0,      84, 170, 0,     84, 255, 0,
        170, 84, 0,     170, 170, 0,    170, 255, 0,
        255, 84, 0,     255, 170, 0,    255, 255, 0,
        33, 138, 200,   0, 170, 127,    0, 255, 127,
        84, 0, 127,     84, 84, 127,    84, 170, 127,
        84, 255, 127,   170, 0, 127,    170, 84, 127,
        170, 170, 127,  170, 255, 127,  255, 0, 127,
        255, 84, 127,   255, 170, 127,
    )
    # Regroup the flat channel sequence into [R, G, B] triplets.
    return [list(flat[i:i + 3]) for i in range(0, len(flat), 3)]
77
+
78
+
79
+
80
def predict(classification_mode, image):
    """Classify a solar-module infrared image and, if anomalous, segment it.

    Parameters
    ----------
    classification_mode : str
        'Binary Classification' selects model_binary; anything else
        selects model_multi.
    image : PIL.Image
        Input infrared image from the Gradio widget.

    Returns
    -------
    tuple
        (label -> probability dict, overlay image or None,
         affected-area percentage string).
    """
    # Pick the learner matching the requested mode.
    if classification_mode == 'Binary Classification':
        model = model_binary
    else:
        model = model_multi

    labels = model.dls.vocab
    # Classification model prediction
    pred, pred_idx, probs = model.predict(image)

    seg_img = None
    percentage_affected = '0%'
    # Only run segmentation when the classifier flagged an anomaly.
    if pred.upper() != 'NO-ANOMALY':
        # SegFormer expects 3-channel input; replicate the IR channel.
        addChannel = Grayscale(num_output_channels=3)
        image = addChannel(image)

        inputs = seg_feature_extractor(images=image, return_tensors="pt")
        outputs = seg_model(**inputs)
        logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)

        # First, rescale logits to original image size
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # (height, width)
            mode='bilinear',
            align_corners=False
        )

        # Second, apply argmax on the class dimension
        pred_seg = upsampled_logits.argmax(dim=1)[0]

        seg_img = get_seg_overlay(image, pred_seg)

        # Count pixels per predicted class. Building the dict with zip
        # (instead of indexing positions [0] and [1] of the unique arrays)
        # avoids an IndexError when the mask contains only one class, and
        # .get(1, 0) avoids a KeyError when no pixel is class 1 (anomaly).
        class_ids, counts = np.unique(pred_seg.numpy(), return_counts=True)
        pixels_count = dict(zip(class_ids.tolist(), counts.tolist()))
        # 960 is presumably the total pixel count of a 24x40
        # InfraredSolarModules image — TODO confirm dataset resolution.
        percentage_affected = round((pixels_count.get(1, 0) / 960) * 100, 1)
        percentage_affected = str(percentage_affected) + '%'

    return ({labels[i]: float(probs[i]) for i in range(len(labels))}, seg_img, percentage_affected)
123
+
124
+
125
# HTML blurb rendered above the Gradio interface (passed to gr.Interface
# below). Kept byte-for-byte: it is user-facing markup, not code.
description = """
<center><img src="https://i0.wp.com/mapperx.com/wp-content/uploads/2023/01/Termal-Drone-Ile-Pv-Panel-Inceleme.jpg?w=1600&ssl=1" width=270px></center><br><br><br><br>
<center>This program identifies the type of anomaly found in solar panel using an image classification model and the percentage of the affected area using an image segmentation model.</center>
<center><i>(Models are trained on <a href="https://ai4earthscience.github.io/iclr-2020-workshop/papers/ai4earth22.pdf">InfraredSolarModules</a> dataset, and hence expect infrared image as input)</center></i>
"""
130
+
131
# Build and launch the Gradio app: dropdown (classification mode) + image in;
# label probabilities, segmentation overlay, and affected-area text out.
# NOTE(review): gr.outputs.Label and the .style(...) chaining are Gradio 3.x
# APIs removed in Gradio 4 — confirm the pinned gradio version before upgrading.
# NOTE(review): examples=[[], []] passes two empty example rows — presumably a
# placeholder; verify whether real example images were intended here.
gr.Interface(fn=predict,
             inputs= [gr.Dropdown(choices=['Binary Classification', 'Multiclass Classification'], label='Classification Mode:',
                                  info='Choose to classify between anomaly and no-anomaly OR between 12 different types of anomalies.'),
                      gr.Image(type='pil', label='Input infrared image: ')],
             outputs=[gr.outputs.Label(num_top_classes=3, label='Detected:').style(container=False),
                      gr.Image(type='pil', label=' ').style(height=240, width=144),
                      gr.Textbox(label='Affected area:').style(container=False)],
             title='Solar Panel Anomaly Detector',
             description=description,
             examples=[[], []],
             article= '<center>by <a href="https://www.linkedin.com/in/lzk/">Lee Zhe Kaai</a></center>').launch()