etweedy commited on
Commit
fac144d
·
1 Parent(s): 1aab986

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +41 -0
  2. mnist2.pkl +3 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import gradio as gr
4
+
5
+ class CNN(nn.Module):
6
+ def __init__(self):
7
+ super(CNN,self).__init__()
8
+
9
+ self.conv1 = nn.Sequential(
10
+ nn.Conv2d(1,16,5,stride=1,padding=2),
11
+ nn.ReLU(),
12
+ nn.MaxPool2d(kernel_size=2),
13
+ )
14
+ self.conv2 = nn.Sequential(
15
+ nn.Conv2d(16,32,5,1,2),
16
+ nn.ReLU(),
17
+ nn.MaxPool2d(2),
18
+ )
19
+ self.out = nn.Linear(32*7*7,10)
20
+
21
+ def forward(self,x):
22
+ x=self.conv1(x)
23
+ x=self.conv2(x)
24
+ x = x.view(-1,32*7*7)
25
+ return self.out(x)
26
+
27
+ model = CNN()
28
+ model.load_state_dict(torch.load('mnist2.pkl',map_location=torch.device('cpu')))
29
+ model.eval()
30
+
31
+ def predict(img):
32
+ x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
33
+ with torch.no_grad():
34
+ pred = model(x)[0]
35
+ return int(pred.argmax())
36
+
37
+
38
+ gr.Interface(fn=predict,
39
+ inputs="sketchpad",
40
+ outputs="label",
41
+ ).launch()
mnist2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1449c400c6ae2b041b707b6685f5c71dcf81e4ca9fb511547fa5a23c0552f2d0
3
+ size 117783