add first files
- app.py +53 -0
- focusondepth/__init__.py +0 -0
- focusondepth/fusion.py +41 -0
- focusondepth/head.py +50 -0
- focusondepth/model_config.py +45 -0
- focusondepth/model_definition.py +68 -0
- focusondepth/reassemble.py +115 -0
app.py
ADDED
@@ -0,0 +1,53 @@
import torch
import gradio as gr
import numpy as np

import requests
from PIL import Image
from io import BytesIO
from torchvision import transforms

from transformers import AutoConfig, AutoModel

from focusondepth.model_config import FocusOnDepthConfig
from focusondepth.model_definition import FocusOnDepth

# Register the custom architecture so AutoModel can resolve "focusondepth".
AutoConfig.register("focusondepth", FocusOnDepthConfig)
AutoModel.register(FocusOnDepthConfig, FocusOnDepth)

original_image_cache = {}
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
model = AutoModel.from_pretrained('ybelkada/focusondepth', trust_remote_code=True)
model.load_state_dict(torch.load('./focusondepth/FocusOnDepth_vit_base_patch16_384.p', map_location=torch.device('cpu'))['model_state_dict'])

@torch.no_grad()
def inference(input_image):
    global model, transform

    model.eval()
    input_image = Image.fromarray(input_image)
    original_size = input_image.size
    tensor_image = transform(input_image)

    depth, segmentation = model(tensor_image.unsqueeze(0))
    depth = 1 - depth  # invert the predicted map before display

    depth = transforms.ToPILImage()(depth[0, :])
    segmentation = transforms.ToPILImage()(segmentation.argmax(dim=1).float())

    return [depth.resize(original_size, resample=Image.BICUBIC), segmentation.resize(original_size, resample=Image.NEAREST)]

# Legacy Gradio (< 4.0) input/output components.
iface = gr.Interface(
    fn=inference,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=[
        gr.outputs.Image(label="Depth Map"),
        gr.outputs.Image(label="Segmentation Map"),
    ],
)
iface.launch()
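For quick local testing, the inference function above can also be called directly, without launching the Gradio UI. A minimal sketch, reusing app.py's imports and assuming the checkpoint file is present; "example.jpg" is a placeholder path:

# Sketch only: run in the same Python session after app.py's definitions.
img = np.array(Image.open("example.jpg").convert("RGB"))   # HxWx3 uint8, as Gradio passes it
depth_map, seg_map = inference(img)                         # two PIL images (mode "F") at the original size

depth = np.array(depth_map).clip(0, 1)                      # float values in [0, 1]
Image.fromarray((depth * 255).astype(np.uint8)).save("depth.png")
Image.fromarray((np.array(seg_map) * 255).astype(np.uint8)).save("segmentation.png")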
focusondepth/__init__.py
ADDED
File without changes
focusondepth/fusion.py
ADDED
@@ -0,0 +1,41 @@
import numpy as np
import torch
import torch.nn as nn

class ResidualConvUnit(nn.Module):
    def __init__(self, features):
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        return out + x

class Fusion(nn.Module):
    def __init__(self, resample_dim):
        super(Fusion, self).__init__()
        self.res_conv1 = ResidualConvUnit(resample_dim)
        self.res_conv2 = ResidualConvUnit(resample_dim)

    def forward(self, x, previous_stage=None):
        # The deepest stage has no predecessor; use a zero tensor instead.
        if previous_stage is None:
            previous_stage = torch.zeros_like(x)
        output_stage1 = self.res_conv1(x)
        output_stage1 += previous_stage
        output_stage2 = self.res_conv2(output_stage1)
        # Upsample by 2 so the next (shallower) stage can be added.
        output_stage2 = nn.functional.interpolate(output_stage2, scale_factor=2, mode="bilinear", align_corners=True)
        return output_stage2
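As a shape sketch (assuming the default resample_dim=256 from the config below and a 24x24 feature map, i.e. 384/16 patches per side), a Fusion block adds the previous decoder stage to the refined features and upsamples by a factor of 2:

import torch
from focusondepth.fusion import Fusion

fusion = Fusion(resample_dim=256)
x = torch.randn(1, 256, 24, 24)       # reassembled features for one hook
prev = torch.randn(1, 256, 24, 24)    # output of the previous (deeper) fusion stage
out = fusion(x, prev)
print(out.shape)                      # torch.Size([1, 256, 48, 48])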
focusondepth/head.py
ADDED
@@ -0,0 +1,50 @@
import numpy as np
import torch
import torch.nn as nn

class Interpolate(nn.Module):
    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        x = self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners)
        return x

class HeadDepth(nn.Module):
    def __init__(self, features):
        super(HeadDepth, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            # nn.ReLU()
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.head(x)
        # x = (x - x.min())/(x.max()-x.min() + 1e-15)
        return x

class HeadSeg(nn.Module):
    def __init__(self, features, nclasses=2):
        super(HeadSeg, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, nclasses, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x):
        x = self.head(x)
        return x
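A shape sketch for the two heads, assuming resample_dim=256 features at 192x192 (the last fusion stage for a 384x384 input): HeadDepth returns a single sigmoid channel and HeadSeg returns per-class logits, both upsampled by 2 to the input resolution.

import torch
from focusondepth.head import HeadDepth, HeadSeg

feats = torch.randn(1, 256, 192, 192)      # last fusion output for a 384x384 input
depth = HeadDepth(256)(feats)              # (1, 1, 384, 384), values in (0, 1) from the sigmoid
seg = HeadSeg(256, nclasses=2)(feats)      # (1, 2, 384, 384), raw logits
print(depth.shape, seg.shape)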
focusondepth/model_config.py
ADDED
@@ -0,0 +1,45 @@
from transformers import PretrainedConfig
from typing import List


class FocusOnDepthConfig(PretrainedConfig):
    model_type = "focusondepth"

    def __init__(
        self,
        image_size=(3, 384, 384),
        patch_size=16,
        emb_dim=768,
        resample_dim=256,
        read='projection',
        num_layers_encoder=24,
        hooks=[2, 5, 8, 11],
        reassemble_s=[4, 8, 16, 32],
        transformer_dropout=0,
        nclasses=2,
        type_="full",
        model_timm="vit_base_patch16_384",
        **kwargs,
    ):
        if type_ not in ["full", "depth", "segmentation"]:
            raise ValueError(f"`type_` must be 'full', 'depth' or 'segmentation', got {type_}.")
        if read not in ["ignore", "add", "projection"]:
            raise ValueError(f"`read` must be 'ignore', 'add' or 'projection', got {read}.")
        if image_size[0] != 3 or image_size[1] != 384 or image_size[2] != 384:
            raise ValueError(f"`image_size` must be (3, 384, 384), got {image_size}.")
        if patch_size != 16:
            raise ValueError(f"`patch_size` must be 16, got {patch_size}.")

        self.image_size = image_size
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.resample_dim = resample_dim
        self.read = read
        self.num_layers_encoder = num_layers_encoder
        self.hooks = hooks
        self.reassemble_s = reassemble_s
        self.transformer_dropout = transformer_dropout
        self.nclasses = nclasses
        self.type_ = type_
        self.model_timm = model_timm
        super().__init__(**kwargs)
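A minimal sketch of instantiating and serializing the config through the standard PretrainedConfig API (the directory name is just an example):

from focusondepth.model_config import FocusOnDepthConfig

config = FocusOnDepthConfig(type_="depth", nclasses=2)
config.save_pretrained("focusondepth-config")            # writes config.json with model_type "focusondepth"

reloaded = FocusOnDepthConfig.from_pretrained("focusondepth-config")
print(reloaded.type_, reloaded.hooks)                    # depth [2, 5, 8, 11]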
focusondepth/model_definition.py
ADDED
@@ -0,0 +1,68 @@
from transformers import PreTrainedModel
import timm
import torch.nn as nn
import numpy as np

from .model_config import FocusOnDepthConfig
from .reassemble import Reassemble
from .fusion import Fusion
from .head import HeadDepth, HeadSeg


class FocusOnDepth(PreTrainedModel):
    config_class = FocusOnDepthConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformer_encoders = timm.create_model(config.model_timm, pretrained=True)
        self.type_ = config.type_

        # Register hooks on the selected transformer blocks.
        self.activation = {}
        self.hooks = config.hooks
        self._get_layers_from_hooks(self.hooks)

        # Reassemble + Fusion blocks, one pair per hook.
        self.reassembles = []
        self.fusions = []
        for s in config.reassemble_s:
            self.reassembles.append(Reassemble(config.image_size, config.read, config.patch_size, s, config.emb_dim, config.resample_dim))
            self.fusions.append(Fusion(config.resample_dim))
        self.reassembles = nn.ModuleList(self.reassembles)
        self.fusions = nn.ModuleList(self.fusions)

        # Heads
        if self.type_ == "full":
            self.head_depth = HeadDepth(config.resample_dim)
            self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)
        elif self.type_ == "depth":
            self.head_depth = HeadDepth(config.resample_dim)
            self.head_segmentation = None
        else:
            self.head_depth = None
            self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)

    def forward(self, img):
        # Run the backbone only for its side effect: the hooks store block activations.
        _ = self.transformer_encoders(img)
        previous_stage = None
        # Decode from the deepest hook to the shallowest, fusing stage by stage.
        for i in np.arange(len(self.fusions)-1, -1, -1):
            hook_to_take = 't' + str(self.hooks[i])
            activation_result = self.activation[hook_to_take]
            reassemble_result = self.reassembles[i](activation_result)
            fusion_result = self.fusions[i](reassemble_result, previous_stage)
            previous_stage = fusion_result
        out_depth = None
        out_segmentation = None
        if self.head_depth is not None:
            out_depth = self.head_depth(previous_stage)
        if self.head_segmentation is not None:
            out_segmentation = self.head_segmentation(previous_stage)
        return out_depth, out_segmentation

    def _get_layers_from_hooks(self, hooks):
        def get_activation(name):
            def hook(model, input, output):
                self.activation[name] = output
            return hook
        for h in hooks:
            self.transformer_encoders.blocks[h].register_forward_hook(get_activation('t' + str(h)))
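A sketch of a dummy forward pass through the assembled model; note that timm.create_model(..., pretrained=True) downloads the ViT-Base backbone weights on first use:

import torch
from focusondepth.model_config import FocusOnDepthConfig
from focusondepth.model_definition import FocusOnDepth

config = FocusOnDepthConfig()                 # type_="full" by default
model = FocusOnDepth(config).eval()

with torch.no_grad():
    depth, seg = model(torch.randn(1, 3, 384, 384))
print(depth.shape, seg.shape)                 # (1, 1, 384, 384) and (1, 2, 384, 384)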
focusondepth/reassemble.py
ADDED
@@ -0,0 +1,115 @@
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class Read_ignore(nn.Module):
    def __init__(self, start_index=1):
        super(Read_ignore, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        # Drop the readout ([CLS]) token.
        return x[:, self.start_index:]


class Read_add(nn.Module):
    def __init__(self, start_index=1):
        super(Read_add, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        # Add the readout token to every patch token.
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index:] + readout.unsqueeze(1)


class Read_projection(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(Read_projection, self).__init__()
        self.start_index = start_index
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        # Concatenate the readout token to each patch token, then project back to in_features.
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
        features = torch.cat((x[:, self.start_index:], readout), -1)
        return self.project(features)

class MyConvTranspose2d(nn.Module):
    def __init__(self, conv, output_size):
        super(MyConvTranspose2d, self).__init__()
        self.output_size = output_size
        self.conv = conv

    def forward(self, x):
        x = self.conv(x, output_size=self.output_size)
        return x

class Resample(nn.Module):
    def __init__(self, p, s, h, emb_dim, resample_dim):
        super(Resample, self).__init__()
        assert (s in [4, 8, 16, 32]), "s must be in [4, 8, 16, 32]"
        self.conv1 = nn.Conv2d(emb_dim, resample_dim, kernel_size=1, stride=1, padding=0)
        if s == 4:
            self.conv2 = nn.ConvTranspose2d(resample_dim,
                                            resample_dim,
                                            kernel_size=4,
                                            stride=4,
                                            padding=0,
                                            bias=True,
                                            dilation=1,
                                            groups=1)
        elif s == 8:
            self.conv2 = nn.ConvTranspose2d(resample_dim,
                                            resample_dim,
                                            kernel_size=2,
                                            stride=2,
                                            padding=0,
                                            bias=True,
                                            dilation=1,
                                            groups=1)
        elif s == 16:
            self.conv2 = nn.Identity()
        else:
            self.conv2 = nn.Conv2d(resample_dim, resample_dim, kernel_size=2, stride=2, padding=0, bias=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class Reassemble(nn.Module):
    def __init__(self, image_size, read, p, s, emb_dim, resample_dim):
        """
        p = patch size
        s = resample coefficient
        emb_dim <=> D (in the paper)
        resample_dim <=> ^D (in the paper)
        read : {"ignore", "add", "projection"}
        """
        super(Reassemble, self).__init__()
        channels, image_height, image_width = image_size

        # Read
        self.read = Read_ignore()
        if read == 'add':
            self.read = Read_add()
        elif read == 'projection':
            self.read = Read_projection(emb_dim)

        # Concatenate after read: tokens back to a 2D feature map.
        self.concat = Rearrange('b (h w) c -> b c h w',
                                c=emb_dim,
                                h=(image_height // p),
                                w=(image_width // p))

        # Projection + Resample
        self.resample = Resample(p, s, image_height, emb_dim, resample_dim)

    def forward(self, x):
        x = self.read(x)
        x = self.concat(x)
        x = self.resample(x)
        return x
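A shape trace through Reassemble for the deepest hook, as a sketch: ViT-Base/16 at 384x384 yields 1 + 24*24 = 577 tokens of dimension 768; the readout token is folded in by Read_projection, the 576 patch tokens are reshaped to a 24x24 map, projected to resample_dim channels, and resampled according to s:

import torch
from focusondepth.reassemble import Reassemble

tokens = torch.randn(1, 577, 768)     # [CLS] + 24*24 patch tokens
reassemble = Reassemble(image_size=(3, 384, 384), read="projection",
                        p=16, s=32, emb_dim=768, resample_dim=256)
out = reassemble(tokens)
print(out.shape)                      # torch.Size([1, 256, 12, 12]); s=4 would give (1, 256, 96, 96)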