s3nh commited on
Commit
54d8be6
·
1 Parent(s): da936cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -26,13 +26,13 @@ def inference(image, chosen_model):
26
  logits = outputs.logits
27
 
28
  output = torch.sigmoid(logits).detach().numpy()[0]
29
- output = np.transpose(output, (1,2,0))
30
- upsampled_logits = nn.functional.interpolate(logits,
31
- size=image.size[::-1], # (height, width)
32
- mode='bilinear',
33
- align_corners=False)
34
 
35
- seg = upsampled_logits.argmax(dim=1)[0]
36
  color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
37
  palette = np.array([[0, 0, 0],[255, 255, 255]])
38
  for label, color in enumerate(palette):
 
26
  logits = outputs.logits
27
 
28
  output = torch.sigmoid(logits).detach().numpy()[0]
29
+ # output = np.transpose(output, (1,2,0))
30
+ # upsampled_logits = nn.functional.interpolate(logits,
31
+ # size=image.size[::-1], # (height, width)
32
+ # mode='bilinear',
33
+ # align_corners=False)
34
 
35
+ seg = output
36
  color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
37
  palette = np.array([[0, 0, 0],[255, 255, 255]])
38
  for label, color in enumerate(palette):