abhirajeshbhai commited on
Commit
03856d4
·
1 Parent(s): 084800e

loaded new colorizer weights

Browse files
Files changed (3) hide show
  1. app.py +12 -13
  2. model.py +12 -4
  3. unet_colorizer_flickr_5_93_Ploss_10_14K.pth +3 -0
app.py CHANGED
@@ -4,27 +4,26 @@ import torch
4
  import torch.nn as nn
5
 
6
  from PIL import Image
7
- from model import model, image_transforms
8
 
9
 
10
-
11
- def col_select(value):
12
- print(value)
13
-
14
-
15
- st.title("Banan Image Colorizer")
16
 
17
  upload_file = st.file_uploader("Upload Image")
18
 
19
  if upload_file:
 
20
  image = upload_file
 
 
21
  image = Image.open(image)
22
- image_gs = image_transforms(image)
23
- image_gs_prev = image_gs.permute(1, 2, 0).detach().cpu().numpy()
24
-
25
- image_color = model(image_gs.unsqueeze(0)).squeeze().permute(1, 2, 0).detach().cpu().numpy()
26
-
27
 
 
 
28
  col1, col2 = st.columns(2)
29
- col1.image(image_gs_prev)
30
  col2.image(image_color, clamp=True, channels='RGB')
 
4
  import torch.nn as nn
5
 
6
  from PIL import Image
7
+ from model import model, image_transforms_gs, image_transforms_rgb
8
 
9
 
10
+ st.title("UNET Image Colorizer")
 
 
 
 
 
11
 
12
  upload_file = st.file_uploader("Upload Image")
13
 
14
  if upload_file:
15
+
16
  image = upload_file
17
+ image_gs = upload_file
18
+
19
  image = Image.open(image)
20
+ if len(np.array(image).shape) < 3:
21
+ image = image_transforms_gs(image)
22
+ else:
23
+ image = image_transforms_rgb(image)
 
24
 
25
+ image_color = model(image.unsqueeze(0)).squeeze().permute(1, 2, 0).detach().cpu().numpy()
26
+
27
  col1, col2 = st.columns(2)
28
+ col1.image(image_gs)
29
  col2.image(image_color, clamp=True, channels='RGB')
model.py CHANGED
@@ -8,14 +8,21 @@ import torch.nn.functional as F
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
 
 
 
 
 
 
11
 
12
- image_transforms = torchvision.transforms.Compose([
13
  torchvision.transforms.Resize((256, 256)),
14
- torchvision.transforms.Grayscale(),
15
  torchvision.transforms.ToTensor(),
16
- torchvision.transforms.Normalize(mean=[0.0], std=[1.0])
17
  ])
18
 
 
 
19
  class ConvBlock(nn.Module):
20
  def __init__(self, in_channel, out_channel):
21
  super(ConvBlock, self).__init__()
@@ -99,5 +106,6 @@ class UNETFruitColor(nn.Module):
99
 
100
 
101
  model = UNETFruitColor()
102
- model.load_state_dict(torch.load("banana_colorizer_unet.pth", map_location=device),strict=True)
 
103
  model.eval()
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ image_transforms_rgb = torchvision.transforms.Compose([
12
+ torchvision.transforms.Resize((256, 256)),
13
+ torchvision.transforms.ToTensor(),
14
+ torchvision.transforms.Normalize(mean=[0.0,0.0,0.0], std=[1.0,1.0,1.0]),
15
+ torchvision.transforms.Grayscale()
16
+ ])
17
 
18
+ image_transforms_gs = torchvision.transforms.Compose([
19
  torchvision.transforms.Resize((256, 256)),
 
20
  torchvision.transforms.ToTensor(),
21
+ torchvision.transforms.Normalize(mean=[0.0], std=[1.0]),
22
  ])
23
 
24
+
25
+
26
  class ConvBlock(nn.Module):
27
  def __init__(self, in_channel, out_channel):
28
  super(ConvBlock, self).__init__()
 
106
 
107
 
108
  model = UNETFruitColor()
109
+ model = nn.DataParallel(model).to(device)
110
+ model.load_state_dict(torch.load("unet_colorizer_flickr_5_93_Ploss_10_14K.pth", map_location=device),strict=True)
111
  model.eval()
unet_colorizer_flickr_5_93_Ploss_10_14K.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d03e381cf54c86474608ea5420ad769072fb306257032211261e9a7e3475174
3
+ size 124269794