zklee98 commited on
Commit
d758068
·
verified ·
1 Parent(s): fedae66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -43
app.py CHANGED
@@ -10,7 +10,7 @@ from fastai.vision.all import *
10
  model_multi = load_learner('vit_tiny_patch16.pkl')
11
 
12
  def binary_label(path):
13
- return 'No-anomaly' if (parent_label(path) == 'No-Anomaly') else 'Anomaly'
14
 
15
  model_binary = load_learner('vit_tiny_patch16_binary.pkl')
16
 
@@ -22,17 +22,17 @@ seg_feature_extractor = SegformerFeatureExtractor.from_pretrained('zklee98/segfo
22
  seg_model = SegformerForSemanticSegmentation.from_pretrained('zklee98/segformer-b1-solarModuleAnomaly-v0.1')
23
 
24
  def get_seg_overlay(image, seg):
25
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
26
- palette = np.array(sidewalk_palette())
27
- for label, color in enumerate(palette):
28
- color_seg[seg == label, :] = color
29
 
30
- # Show image + mask
31
- img = np.array(image) * 0.5 + color_seg * 0.5
32
- img = img.astype(np.uint8)
33
- #img = PIL.Image.open(img)
34
 
35
- return img
36
 
37
  #@title `def sidewalk_palette()`
38
 
@@ -80,46 +80,45 @@ def sidewalk_palette():
80
 
81
  def predict(classification_mode, image):
82
 
83
- if (classification_mode == 'Binary Classification'):
84
- model = model_binary
85
- else:
86
- model = model_multi
87
 
88
- labels = model.dls.vocab
89
- # Classification model prediction
90
- image = PILImage.create(image)
91
- pred, pred_idx, probs = model.predict(image)
92
 
93
- seg_img = None
94
- percentage_affected = '0%'
95
- if (pred.upper() != 'NO-ANOMALY'):
96
- addChannel = Grayscale(num_output_channels=3)
97
- image = addChannel(image)
98
 
99
- inputs = seg_feature_extractor(images=image, return_tensors="pt")
100
- outputs = seg_model(**inputs)
101
- logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
102
 
103
- # First, rescale logits to original image size
104
- upsampled_logits = nn.functional.interpolate(
105
- logits,
106
- size=image.size[::-1], # (height, width)
107
- mode='bilinear',
108
- align_corners=False
109
- )
110
 
111
- # Second, apply argmax on the class dimension
112
- pred_seg = upsampled_logits.argmax(dim=1)[0]
113
 
114
- seg_img = get_seg_overlay(image, pred_seg)
115
 
116
- classified_pixels = np.unique(pred_seg.numpy(), return_counts=True)
117
- pixels_count = dict({classified_pixels[0][0]: classified_pixels[1][0],
118
- classified_pixels[0][1]: classified_pixels[1][1]})
119
- percentage_affected = round((pixels_count[1]/960)*100, 1)
120
- percentage_affected = str(percentage_affected) + '%'
121
 
122
- seg_img = PIL.Image.fromarray(seg_img)
123
 
124
  return ({labels[i]: float(probs[i]) for i in range(len(labels))}, seg_img, percentage_affected)
125
 
@@ -133,7 +132,7 @@ description = """
133
  gr.Interface(fn=predict,
134
  inputs= [gr.Dropdown(choices=['Binary Classification', 'Multiclass Classification'], label='Classification Mode:',
135
  info='Choose to classify between anomaly and no-anomaly OR between 12 different types of anomalies.'),
136
- gr.Image(label='Input infrared image: ')],
137
  outputs=[gr.outputs.Label(num_top_classes=3, label='Detected:').style(container=False),
138
  gr.Image(type='pil', label=' ').style(height=240, width=144),
139
  gr.Textbox(label='Affected area:').style(container=False)],
 
10
  model_multi = load_learner('vit_tiny_patch16.pkl')
11
 
12
  def binary_label(path):
13
+ return 'No-anomaly' if (parent_label(path) == 'No-Anomaly') else 'Anomaly'
14
 
15
  model_binary = load_learner('vit_tiny_patch16_binary.pkl')
16
 
 
22
  seg_model = SegformerForSemanticSegmentation.from_pretrained('zklee98/segformer-b1-solarModuleAnomaly-v0.1')
23
 
24
  def get_seg_overlay(image, seg):
25
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
26
+ palette = np.array(sidewalk_palette())
27
+ for label, color in enumerate(palette):
28
+ color_seg[seg == label, :] = color
29
 
30
+ # Show image + mask
31
+ img = np.array(image) * 0.5 + color_seg * 0.5
32
+ img = img.astype(np.uint8)
33
+ #img = PIL.Image.open(img)
34
 
35
+ return img
36
 
37
  #@title `def sidewalk_palette()`
38
 
 
80
 
81
  def predict(classification_mode, image):
82
 
83
+ if (classification_mode == 'Binary Classification'):
84
+ model = model_binary
85
+ else:
86
+ model = model_multi
87
 
88
+ labels = model.dls.vocab
89
+ # Classification model prediction
90
+ #image = PILImage.create(image)
91
+ pred, pred_idx, probs = model.predict(image)
92
 
93
+ seg_img = None
94
+ percentage_affected = '0%'
95
+ if (pred.upper() != 'NO-ANOMALY'):
96
+ addChannel = Grayscale(num_output_channels=3)
97
+ image = addChannel(image)
98
 
99
+ inputs = seg_feature_extractor(images=image, return_tensors="pt")
100
+ outputs = seg_model(**inputs)
101
+ logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
102
 
103
+ # First, rescale logits to original image size
104
+ upsampled_logits = nn.functional.interpolate(
105
+ logits,
106
+ size=image.size[::-1], # (height, width)
107
+ mode='bilinear',
108
+ align_corners=False)
 
109
 
110
+ # Second, apply argmax on the class dimension
111
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
112
 
113
+ seg_img = get_seg_overlay(image, pred_seg)
114
 
115
+ classified_pixels = np.unique(pred_seg.numpy(), return_counts=True)
116
+ pixels_count = dict({classified_pixels[0][0]: classified_pixels[1][0],
117
+ classified_pixels[0][1]: classified_pixels[1][1]})
118
+ percentage_affected = round((pixels_count[1]/960)*100, 1)
119
+ percentage_affected = str(percentage_affected) + '%'
120
 
121
+ seg_img = PIL.Image.fromarray(seg_img)
122
 
123
  return ({labels[i]: float(probs[i]) for i in range(len(labels))}, seg_img, percentage_affected)
124
 
 
132
  gr.Interface(fn=predict,
133
  inputs= [gr.Dropdown(choices=['Binary Classification', 'Multiclass Classification'], label='Classification Mode:',
134
  info='Choose to classify between anomaly and no-anomaly OR between 12 different types of anomalies.'),
135
+ gr.Image(type='pil', label='Input infrared image: ')],
136
  outputs=[gr.outputs.Label(num_top_classes=3, label='Detected:').style(container=False),
137
  gr.Image(type='pil', label=' ').style(height=240, width=144),
138
  gr.Textbox(label='Affected area:').style(container=False)],