will33am commited on
Commit
01b0a68
1 Parent(s): f25cf91
Files changed (2) hide show
  1. app.py +76 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[ ]:
5
+
6
+
7
+ import albumentations as A
8
+ from albumentations.pytorch.transforms import ToTensorV2
9
+ from timm import create_model
10
+ import torch
11
+ import gradio as gr
12
+
13
+
14
+ # In[ ]:
15
+
16
+
17
+ class TestDataset(torch.utils.data.Dataset):
18
+ def __init__(self,image,transforms = None):
19
+ self.image = [image]
20
+ self.transforms = transforms
21
+
22
+ def __getitem__(self,idx):
23
+ image = self.image[idx]
24
+ if self.transforms:
25
+ augmented = self.transforms(image=image)
26
+ image = augmented["image"]
27
+ return {'image':image}
28
+ def __len__(self):
29
+ return len(self.image)
30
+
31
+ def get_test_transform():
32
+ MEAN = [0.5176, 0.4169, 0.3637]
33
+ STD = [0.3010, 0.2723, 0.2672]
34
+ return A.Compose([
35
+ #A.resize((256,256)),
36
+ A.Normalize(MEAN,STD),
37
+ ToTensorV2(transpose_mask=False,p=1.0)
38
+ ])
39
+
40
+
41
+ # In[ ]:
42
+
43
+
44
+ def predict_image(image):
45
+ test_dataset = TestDataset(image,transforms = get_test_transform())
46
+ test_loader = torch.utils.data.DataLoader(test_dataset,
47
+ batch_size = 1,
48
+ pin_memory = False,
49
+ num_workers = 8,
50
+ shuffle = False)
51
+ # Loading weights
52
+ for data in test_loader:
53
+ for key,value in data.items():
54
+ data[key] = value.to('cpu')
55
+ # Appending Output and Targets:
56
+ output = torch.sigmoid(model(data['image'])).cpu().detach().numpy()
57
+ dict_ = {'Down':float(1-output[0][0]),'Upside':float(output[0][0])}
58
+ return dict_
59
+
60
+
61
+ # In[ ]:
62
+
63
+
64
+ model = create_model('resnet18',pretrained = False,num_classes = 1)
65
+ checkpoint = torch.load('model.pt',map_location = 'cpu')
66
+ model.load_state_dict(checkpoint,strict = False)
67
+
68
+
69
+ # In[ ]:
70
+
71
+
72
+ title = "Upside-Down Detector"
73
+ interpretation='default'
74
+ enable_queue=True
75
+ gr.Interface(fn=predict_image,inputs=gr.inputs.Image(shape=(256, 256)),outputs=gr.outputs.Label(num_top_classes=2),title=title,interpretation=interpretation).launch(share = True)
76
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ albumentations
2
+ timm