Spaces:
Runtime error
Runtime error
TheProjectsGuy
commited on
Commit
·
476e478
1
Parent(s):
ea57705
Pushed the second version of HF App
Browse files- app.py +339 -36
- cache/gem_cache/result_dino_v2.gz +3 -0
- cache/gem_cache/result_dino_v2_tsne.gz +3 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -23,6 +23,34 @@
|
|
23 |
- https://www.gradio.app/guides/blocks-and-event-listeners
|
24 |
"""
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
# %%
|
27 |
import os
|
28 |
import gradio as gr
|
@@ -35,12 +63,15 @@ from torchvision import transforms as tvf
|
|
35 |
from torchvision.transforms import functional as T
|
36 |
from PIL import Image
|
37 |
import matplotlib.pyplot as plt
|
|
|
38 |
import distinctipy as dipy
|
|
|
39 |
from typing import Literal, List
|
40 |
import gradio as gr
|
41 |
import time
|
42 |
import glob
|
43 |
import shutil
|
|
|
44 |
from copy import deepcopy
|
45 |
# DINOv2 imports
|
46 |
from utilities import DinoV2ExtractFeatures
|
@@ -62,22 +93,33 @@ desc_facet: T1 = "value"
|
|
62 |
num_c: int = 8
|
63 |
cache_dir: str = _ex("./cache") # Directory containing program cache
|
64 |
max_img_size: int = 1024 # Image resolution (max dim/size)
|
65 |
-
max_num_imgs: int =
|
66 |
share: bool = False # Share application using .gradio link
|
67 |
|
68 |
# Verify inputs
|
69 |
assert os.path.isdir(cache_dir), "Cache directory not found"
|
70 |
|
|
|
71 |
# %%
|
72 |
# Model and transforms
|
73 |
print("Loading DINO model")
|
74 |
-
# extractor = None
|
75 |
extractor = DinoV2ExtractFeatures(dino_model, desc_layer, desc_facet,
|
76 |
device=device)
|
77 |
print("DINO model loaded")
|
78 |
# VLAD path (directory)
|
79 |
ext_s = f"{dino_model}/l{desc_layer}_{desc_facet}_c{num_c}"
|
80 |
vc_dir = os.path.join(cache_dir, "vocabulary", ext_s)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
# Base image transformations
|
82 |
base_tf = tvf.Compose([
|
83 |
tvf.ToTensor(),
|
@@ -115,6 +157,9 @@ def get_descs(imgs_batch, pr = gr.Progress()):
|
|
115 |
pr(0, desc="Extracting descriptors")
|
116 |
patch_descs = []
|
117 |
for i, img in enumerate(imgs_batch):
|
|
|
|
|
|
|
118 |
# Convert to PIL image
|
119 |
pil_img = Image.fromarray(img)
|
120 |
img_pt = base_tf(pil_img).to(device)
|
@@ -139,6 +184,7 @@ def get_descs(imgs_batch, pr = gr.Progress()):
|
|
139 |
ret = extractor(img_pt).cpu() # [1, n_p, d]
|
140 |
patch_descs.append({"img": pil_img, "descs": ret})
|
141 |
pr((i+1) / len(imgs_batch))
|
|
|
142 |
return patch_descs, \
|
143 |
f"Descriptors extracted for {len(imgs_batch)} images"
|
144 |
|
@@ -173,7 +219,10 @@ def assign_vlad(patch_descs, vlad, pr = gr.Progress()):
|
|
173 |
def get_ca_images(desc_assignments, patch_descs, alpha,
|
174 |
pr = gr.Progress()):
|
175 |
if desc_assignments is None or len(desc_assignments) == 0:
|
176 |
-
|
|
|
|
|
|
|
177 |
c_colors = dipy.get_colors(num_c, rng=928,
|
178 |
colorblind_type="Deuteranomaly")
|
179 |
np_colors = (np.array(c_colors) * 255).astype(np.uint8)
|
@@ -202,43 +251,177 @@ def get_ca_images(desc_assignments, patch_descs, alpha,
|
|
202 |
return res_imgs, "Cluster assignment images generated"
|
203 |
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
# %%
|
206 |
print("Interface build started")
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
def var_num_img(s):
|
212 |
-
n = int(s) # Slider value as int
|
213 |
-
return [gr.Image.update(label=f"Image {i+1}", visible=True) \
|
214 |
-
for i in range(n)] + [gr.Image.update(visible=False) \
|
215 |
-
for _ in range(max_num_imgs - n)]
|
216 |
-
|
217 |
-
# ---- State declarations ----
|
218 |
-
vlad = gr.State() # VLAD object
|
219 |
-
desc_assignments = gr.State() # Cluster assignments
|
220 |
-
imgs_batch = gr.State() # Images as batch
|
221 |
-
patch_descs = gr.State() # Patch descriptors
|
222 |
-
|
223 |
-
# ---- All UI elements ----
|
224 |
d_vals = [k.title() for k in DOMAINS]
|
225 |
-
domain = gr.Radio(d_vals, value=d_vals[0]
|
226 |
-
|
227 |
-
|
|
|
|
|
228 |
with gr.Row(): # Dynamic row (images in columns)
|
229 |
imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
|
230 |
-
for i in range(nimg_s.value)] + \
|
231 |
[gr.Image(visible=False) \
|
232 |
-
for _ in range(max_num_imgs - nimg_s.value)]
|
233 |
for i, img in enumerate(imgs): # Set image as "input"
|
234 |
img.change(lambda _: None, img)
|
235 |
with gr.Row(): # Dynamic row of output (cluster) images
|
236 |
imgs2 = [gr.Image(label=f"VLAD Clusters {i+1}",
|
237 |
visible=False) for i in range(max_num_imgs)]
|
238 |
-
nimg_s.
|
239 |
-
blend_alpha = gr.
|
240 |
-
|
|
|
|
|
|
|
|
|
241 |
bttn1 = gr.Button("Click Me!") # Cluster assignment
|
|
|
242 |
out_msg1 = gr.Markdown("Select domain and upload images")
|
243 |
out_msg2 = gr.Markdown("For descriptor extraction")
|
244 |
out_msg3 = gr.Markdown("Followed by VLAD assignment")
|
@@ -247,21 +430,40 @@ with gr.Blocks() as demo:
|
|
247 |
# ---- Utility functions ----
|
248 |
# A wrapper to batch the images
|
249 |
def batch_images(data):
|
250 |
-
sv = data[nimg_s]
|
251 |
images: List[np.ndarray] = [data[imgs[k]] \
|
252 |
for k in range(sv)]
|
253 |
return images
|
254 |
# A wrapper to unbatch images (and pad to max)
|
255 |
-
def unbatch_images(imgs_batch):
|
256 |
ret = [gr.Image.update(visible=False) \
|
257 |
for _ in range(max_num_imgs)]
|
258 |
if imgs_batch is None or len(imgs_batch) == 0:
|
259 |
return ret
|
260 |
-
for i
|
261 |
-
|
|
|
|
|
|
|
262 |
ret[i] = gr.Image.update(img_np, visible=True)
|
263 |
return ret
|
264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
# ---- Main pipeline ----
|
266 |
# Get the VLAD cluster assignment images on click
|
267 |
bttn1.click(get_vlad_clusters, domain, [out_msg1, vlad])\
|
@@ -272,12 +474,113 @@ with gr.Blocks() as demo:
|
|
272 |
.then(get_ca_images,
|
273 |
[desc_assignments, patch_descs, blend_alpha],
|
274 |
[imgs_batch, out_msg4])\
|
275 |
-
.then(unbatch_images, imgs_batch, imgs2)
|
276 |
-
# If the blending changes now, update the cluster images
|
277 |
-
blend_alpha.
|
278 |
[desc_assignments, patch_descs, blend_alpha],
|
279 |
[imgs_batch, out_msg4])\
|
280 |
-
.then(unbatch_images, imgs_batch, imgs2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
|
282 |
print("Interface build completed")
|
283 |
|
|
|
23 |
- https://www.gradio.app/guides/blocks-and-event-listeners
|
24 |
"""
|
25 |
|
26 |
+
# A markdown string shown at the top of the app
|
27 |
+
header_markdown = """
|
28 |
+
# AnyLoc Demo
|
29 |
+
|
30 |
+
\| [Website](https://anyloc.github.io/) \| \
|
31 |
+
[GitHub](https://github.com/AnyLoc/AnyLoc) \| \
|
32 |
+
[YouTube](https://youtu.be/ITo8rMInatk) \|
|
33 |
+
|
34 |
+
|
35 |
+
This space contains a collection of demos for AnyLoc. Each demo is a \
|
36 |
+
self-contained application in the tabs below. The following \
|
37 |
+
applications are included
|
38 |
+
|
39 |
+
1. **GeM t-SNE Projection**: Upload a set of images and see where \
|
40 |
+
they land on a t-SNE projection of GeM descriptors from many \
|
41 |
+
domains. This can be used to guide domain selection (from a few \
|
42 |
+
representative images).
|
43 |
+
2. **Cluster Visualization**: This visualizes the VLAD cluster \
|
44 |
+
assignments for the patch descriptors. You need to select the \
|
45 |
+
domain for loading VLAD cluster centers (vocabulary).
|
46 |
+
|
47 |
+
We do **not** save any images uploaded to the demo. Some errors may \
|
48 |
+
leave a log. We do not collect any information about the user.
|
49 |
+
|
50 |
+
🥳 Thanks to HuggingFace for providing a free GPU for this demo.
|
51 |
+
|
52 |
+
"""
|
53 |
+
|
54 |
# %%
|
55 |
import os
|
56 |
import gradio as gr
|
|
|
63 |
from torchvision.transforms import functional as T
|
64 |
from PIL import Image
|
65 |
import matplotlib.pyplot as plt
|
66 |
+
from sklearn.manifold import TSNE
|
67 |
import distinctipy as dipy
|
68 |
+
import joblib
|
69 |
from typing import Literal, List
|
70 |
import gradio as gr
|
71 |
import time
|
72 |
import glob
|
73 |
import shutil
|
74 |
+
import matplotlib.pyplot as plt
|
75 |
from copy import deepcopy
|
76 |
# DINOv2 imports
|
77 |
from utilities import DinoV2ExtractFeatures
|
|
|
93 |
num_c: int = 8
|
94 |
cache_dir: str = _ex("./cache") # Directory containing program cache
|
95 |
max_img_size: int = 1024 # Image resolution (max dim/size)
|
96 |
+
max_num_imgs: int = 16 # Max number of images to upload
|
97 |
share: bool = False # Share application using .gradio link
|
98 |
|
99 |
# Verify inputs
|
100 |
assert os.path.isdir(cache_dir), "Cache directory not found"
|
101 |
|
102 |
+
|
103 |
# %%
|
104 |
# Model and transforms
|
105 |
print("Loading DINO model")
|
106 |
+
# extractor = None # FIXME: For quick testing only
|
107 |
extractor = DinoV2ExtractFeatures(dino_model, desc_layer, desc_facet,
|
108 |
device=device)
|
109 |
print("DINO model loaded")
|
110 |
# VLAD path (directory)
|
111 |
ext_s = f"{dino_model}/l{desc_layer}_{desc_facet}_c{num_c}"
|
112 |
vc_dir = os.path.join(cache_dir, "vocabulary", ext_s)
|
113 |
+
assert os.path.isdir(vc_dir), f"VLAD directory: {vc_dir} not found"
|
114 |
+
# GeM path (cache)
|
115 |
+
gem_cf = os.path.join(cache_dir, "gem_cache", "result_dino_v2.gz")
|
116 |
+
assert os.path.isfile(gem_cf), f"GeM cache: {gem_cf} not found"
|
117 |
+
gem_cache = joblib.load(gem_cf)
|
118 |
+
assert gem_cache["model"]["type"] == dino_model
|
119 |
+
assert gem_cache["model"]["layer"] == desc_layer
|
120 |
+
assert gem_cache["model"]["facet"] == desc_facet
|
121 |
+
fig = plt.figure() # Main figure
|
122 |
+
fig.clear()
|
123 |
# Base image transformations
|
124 |
base_tf = tvf.Compose([
|
125 |
tvf.ToTensor(),
|
|
|
157 |
pr(0, desc="Extracting descriptors")
|
158 |
patch_descs = []
|
159 |
for i, img in enumerate(imgs_batch):
|
160 |
+
if img is None:
|
161 |
+
print(f"Image {i+1} is None")
|
162 |
+
continue
|
163 |
# Convert to PIL image
|
164 |
pil_img = Image.fromarray(img)
|
165 |
img_pt = base_tf(pil_img).to(device)
|
|
|
184 |
ret = extractor(img_pt).cpu() # [1, n_p, d]
|
185 |
patch_descs.append({"img": pil_img, "descs": ret})
|
186 |
pr((i+1) / len(imgs_batch))
|
187 |
+
pr(1.0)
|
188 |
return patch_descs, \
|
189 |
f"Descriptors extracted for {len(imgs_batch)} images"
|
190 |
|
|
|
219 |
def get_ca_images(desc_assignments, patch_descs, alpha,
|
220 |
pr = gr.Progress()):
|
221 |
if desc_assignments is None or len(desc_assignments) == 0:
|
222 |
+
if not 0 <= alpha <= 1:
|
223 |
+
return None, f"Invalid alpha value: {alpha} (should be "\
|
224 |
+
"between 0 and 1)"
|
225 |
+
return None, "First load the images"
|
226 |
c_colors = dipy.get_colors(num_c, rng=928,
|
227 |
colorblind_type="Deuteranomaly")
|
228 |
np_colors = (np.array(c_colors) * 255).astype(np.uint8)
|
|
|
251 |
return res_imgs, "Cluster assignment images generated"
|
252 |
|
253 |
|
254 |
+
# %%
|
255 |
+
# Get GeM descriptors from cache
|
256 |
+
def get_gem_descs_cache(use_d, pr = gr.Progress()):
    """
        Collect the cached GeM descriptors of the selected domains.

        Parameters:
        - use_d:    Domain names selected in the UI. They are compared
                    case-insensitively against "indoor", "urban" and
                    "aerial".
        - pr:       Gradio progress tracker.

        Returns:
        - msg:          Status string for the UI
        - gem_descs:    None if nothing was selected or matched, else
                        a dictionary with
                        - "labels": one domain label per descriptor row
                        - "descs":  descriptors of all the selected
                                    datasets, concatenated on axis 0
    """
    if not use_d:   # None or an empty selection
        return "Select at least one domain", None
    selected = {d.lower() for d in use_d}
    # Dataset name (key in the GeM cache) -> domain label
    ds_domain = {
        "baidu_datasets": "indoor",
        "gardens": "indoor",
        "17places": "indoor",
        "pitts30k": "urban",
        "st_lucia": "urban",
        "Oxford": "urban",
        "Tartan_GNSS_test_rotated": "aerial",
        "Tartan_GNSS_test_notrotated": "aerial",
        "VPAir": "aerial",
    }
    pr(0, desc="Loading GeM descriptors from cache")
    labels: List[str] = []
    descs: List[np.ndarray] = []
    num_ds = len(gem_cache["data"])
    for i, ds in enumerate(gem_cache["data"]):
        dom = ds_domain.get(ds)
        if dom is None or dom not in selected:
            continue    # Unknown dataset or domain not requested
        # GeM descriptors from data: n_desc, desc_dim
        d: np.ndarray = gem_cache["data"][ds]["descriptors"]
        labels.extend([dom] * d.shape[0])
        descs.append(d)
        pr((i+1) / num_ds)
    pr(1.0)
    if len(descs) == 0:
        # np.concatenate raises on an empty list; report it instead
        return "No cached GeM descriptors for the selected domains", \
                None
    gem_descs = {
        "labels": labels,
        "descs": np.concatenate(descs, axis=0),
    }
    return "GeM descriptors loaded from cache", gem_descs
|
287 |
+
|
288 |
+
|
289 |
+
# %%
|
290 |
+
# Get GeM pooled features of the uploaded images
|
291 |
+
def get_add_gem_descs(imgs_batch, gem_descs, pr = gr.Progress()):
    """
        Extract a GeM pooled descriptor for every uploaded image and
        append it (with an "Image<k>" label) to `gem_descs`.

        Parameters:
        - imgs_batch:   List of images (numpy arrays). None entries
                        are skipped. None (no batch) is treated as an
                        empty batch.
        - gem_descs:    Dictionary with "labels" and "descs" (updated
                        in place), or None if nothing was loaded.
        - pr:           Gradio progress tracker.

        Returns:
        - gem_descs:    The updated dictionary, with "num_uimgs" set
                        to the number of images that were added (None
                        if `gem_descs` was None).
        - msg:          Status string for the UI
    """
    if gem_descs is None:
        # An earlier pipeline step did not produce descriptors
        return None, "No GeM descriptors loaded, select domains first"
    msg = "GeM descriptors extracted"
    if imgs_batch is None:
        # Batching failed upstream; continue with the cache only
        imgs_batch = []
        msg = "No valid images, using only cached GeM descriptors"
    pr(0, desc="Extracting GeM descriptors")
    new_descs: List[np.ndarray] = []
    for i, img in enumerate(imgs_batch):
        if img is None:
            print(f"Image {i+1} is None")
            continue
        # Convert to PIL image, then to a tensor on the device
        img_pt = base_tf(Image.fromarray(img)).to(device)
        if max(img_pt.shape[-2:]) > max_img_size:
            print(f"Image {i+1}: {img_pt.shape[-2:]}, outside")
            c, h, w = img_pt.shape
            # Maintain aspect ratio (longer side -> max_img_size)
            if h == max(img_pt.shape[-2:]):
                w = int(w * max_img_size / h)
                h = max_img_size
            else:
                h = int(h * max_img_size / w)
                w = max_img_size
            img_pt = T.resize(img_pt, (h, w),
                    interpolation=T.InterpolationMode.BICUBIC)
        # Make image patchable (sides cropped to multiples of 14)
        c, h, w = img_pt.shape
        h_new, w_new = (h // 14) * 14, (w // 14) * 14
        img_pt = tvf.CenterCrop((h_new, w_new))(img_pt)[None, ...]
        # Extract descriptors
        ret = extractor(img_pt).cpu()   # [1, n_p, d]
        # Get the GeM pooled descriptor (power of 3). The cube root
        # goes through complex numbers so that negative means do not
        # give NaN; abs * sign brings back the real, signed root.
        x = torch.mean(ret**3, dim=-2)
        g_res = x.to(torch.complex64) ** (1/3)
        g_res = torch.abs(g_res) * torch.sign(x)    # [1, d]
        # Add to state
        gem_descs["labels"].append(f"Image{i+1}")
        new_descs.append(g_res.numpy())
        pr((i+1) / len(imgs_batch))
    if len(new_descs) > 0:
        # Single concatenation instead of one per image
        gem_descs["descs"] = np.concatenate([gem_descs["descs"],
                *new_descs])
    pr(1.0)
    gem_descs["num_uimgs"] = len(new_descs)
    return gem_descs, msg
|
336 |
+
|
337 |
+
|
338 |
+
# %%
|
339 |
+
# Apply tSNE to the GeM descriptors
|
340 |
+
def get_tsne_fm_gem(gem_descs, pr = gr.Progress()):
    """
        Project the GeM descriptors to 2D using tSNE.

        Parameters:
        - gem_descs:    Dictionary with "descs" (one descriptor per
                        row), "labels" (one per row) and, optionally,
                        "num_uimgs". None if nothing was loaded.
        - pr:           Gradio progress tracker.

        Returns:
        - tsne_pts:     None if `gem_descs` is None, else a dictionary
                        with "labels", "pts" (the 2D projection) and
                        "num_uimgs" (number of user images, 0 if the
                        key was missing).
        - msg:          Status string for the UI
    """
    if gem_descs is None:
        return None, "No GeM descriptors to project"
    pr(0, desc="Applying tSNE to GeM descriptors")
    desc_all: np.ndarray = gem_descs["descs"]       # [n, d_dim]
    labels_all: List[str] = gem_descs["labels"]     # [n]
    # tSNE needs perplexity < number of samples; 50 is kept whenever
    # there are enough descriptors
    perplexity = min(50, max(1, desc_all.shape[0] - 1))
    # tSNE projection
    tsne = TSNE(n_components=2, random_state=30,
            perplexity=perplexity, learning_rate=200, init='random')
    desc_2d = tsne.fit_transform(desc_all)
    # Result
    tsne_pts = {
        "labels": labels_all,
        "pts": desc_2d,
        # Number of user imgs
        "num_uimgs": gem_descs.get("num_uimgs", 0),
    }
    pr(1.0)
    return tsne_pts, "tSNE projection done"
|
356 |
+
|
357 |
+
|
358 |
+
# %%
|
359 |
+
# Plot tSNE to matplotlib figure
|
360 |
+
def plot_tsne(tsne_pts):
    """
        Draw the tSNE points on the module level figure `fig`.

        Domain points use fixed colors and "o" markers. Each user
        image ("Image<k>" label) gets its own generated color and an
        "x" marker.

        Parameters:
        - tsne_pts:     Dictionary with "pts" (2D points), "labels"
                        (one per point) and "num_uimgs" (number of
                        user images). None if there is nothing to plot.

        Returns:
        - fig:      The (cleared and redrawn) figure, None if
                    `tsne_pts` is None
        - msg:      Status string for the UI
    """
    if tsne_pts is None:
        return None, "No tSNE points to plot"
    colors = {
        "aerial": (80/255, 0/255, 80/255),
        "indoor": ( 0/255, 76/255, 204/255),
        "urban": ( 0/255, 204/255, 0/255),
    }
    ni = int(tsne_pts["num_uimgs"])
    # Custom colors for user images: keep them away from the domain
    # colors, black and white. The list is built with "+" because
    # list.extend returns None (which would drop the exclusions).
    taken = list(colors.values()) + [(0, 0, 0), (1, 1, 1)]
    ucs = dipy.get_colors(ni, exclude_colors=taken,
            colorblind_type="Deuteranomaly")
    for i in range(ni):
        colors[f"Image{i+1}"] = ucs[i]
    fig.clear()
    gs = fig.add_gridspec(1, 1)
    ax = fig.add_subplot(gs[0, 0])
    ax.set_title("tSNE Projection")
    labels = np.array(tsne_pts["labels"])   # Built once for masking
    for domain, color in colors.items():
        pts = tsne_pts["pts"][labels == domain]
        m = "x" if domain.startswith("Image") else "o"
        ax.scatter(pts[:, 0], pts[:, 1], label=domain, marker=m,
                color=color)
    ax.legend()
    ax.set_xticks([])
    ax.set_yticks([])
    fig.set_tight_layout(True)
    return fig, "tSNE plot created"
|
392 |
+
|
393 |
+
|
394 |
# %%
|
395 |
print("Interface build started")
|
396 |
+
|
397 |
+
|
398 |
+
# Tab for VLAD cluster assignment visualization
|
399 |
+
def tab_cluster_viz():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
d_vals = [k.title() for k in DOMAINS]
|
401 |
+
domain = gr.Radio(d_vals, value=d_vals[0], label="Domain",
|
402 |
+
info="The domain of images (for loading VLAD vocabulary)")
|
403 |
+
nimg_s = gr.Number(2, label="How many images?", precision=0,
|
404 |
+
info=f"Between '1' and '{max_num_imgs}' images. Press "\
|
405 |
+
"enter/return to register")
|
406 |
with gr.Row(): # Dynamic row (images in columns)
|
407 |
imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
|
408 |
+
for i in range(int(nimg_s.value))] + \
|
409 |
[gr.Image(visible=False) \
|
410 |
+
for _ in range(max_num_imgs - int(nimg_s.value))]
|
411 |
for i, img in enumerate(imgs): # Set image as "input"
|
412 |
img.change(lambda _: None, img)
|
413 |
with gr.Row(): # Dynamic row of output (cluster) images
|
414 |
imgs2 = [gr.Image(label=f"VLAD Clusters {i+1}",
|
415 |
visible=False) for i in range(max_num_imgs)]
|
416 |
+
nimg_s.submit(var_num_img, nimg_s, imgs)
|
417 |
+
blend_alpha = gr.Number(0.4, label="Blending alpha",
|
418 |
+
info="Weight for cluster centers (between 0 and 1). "\
|
419 |
+
"Higher (close to 1) means greater emphasis on cluster "\
|
420 |
+
"visibility. Lower (closer to 0) will show the "\
|
421 |
+
"underlying image more. "\
|
422 |
+
"Press enter/return to register")
|
423 |
bttn1 = gr.Button("Click Me!") # Cluster assignment
|
424 |
+
gr.Markdown("### Status strings")
|
425 |
out_msg1 = gr.Markdown("Select domain and upload images")
|
426 |
out_msg2 = gr.Markdown("For descriptor extraction")
|
427 |
out_msg3 = gr.Markdown("Followed by VLAD assignment")
|
|
|
430 |
# ---- Utility functions ----
|
431 |
# A wrapper to batch the images
|
432 |
def batch_images(data):
|
433 |
+
sv = int(data[nimg_s])
|
434 |
images: List[np.ndarray] = [data[imgs[k]] \
|
435 |
for k in range(sv)]
|
436 |
return images
|
437 |
# A wrapper to unbatch images (and pad to max)
|
438 |
+
def unbatch_images(imgs_batch, nimg):
|
439 |
ret = [gr.Image.update(visible=False) \
|
440 |
for _ in range(max_num_imgs)]
|
441 |
if imgs_batch is None or len(imgs_batch) == 0:
|
442 |
return ret
|
443 |
+
for i in range(nimg): # nimg only to match input layout
|
444 |
+
if i < len(imgs_batch):
|
445 |
+
img_np = np.array(imgs_batch[i])
|
446 |
+
else:
|
447 |
+
img_np = None
|
448 |
ret[i] = gr.Image.update(img_np, visible=True)
|
449 |
return ret
|
450 |
|
451 |
+
# ---- Examples ----
|
452 |
+
# Two images from each domain
|
453 |
+
gr.Examples(
|
454 |
+
[
|
455 |
+
["Aerial", 2,
|
456 |
+
"ex_aerial_nardo-air_db-42.png",
|
457 |
+
"ex_aerial_nardo-air_qu-42.png",],
|
458 |
+
["Indoor", 2,
|
459 |
+
"ex_indoor_17places_db-75.jpg",
|
460 |
+
"ex_indoor_17places_qu-75.jpg"],
|
461 |
+
["Urban", 2,
|
462 |
+
"ex_urban_oxford_db-75.png",
|
463 |
+
"ex_urban_oxford_qu-75.png"],],
|
464 |
+
[domain, nimg_s, *imgs],
|
465 |
+
)
|
466 |
+
|
467 |
# ---- Main pipeline ----
|
468 |
# Get the VLAD cluster assignment images on click
|
469 |
bttn1.click(get_vlad_clusters, domain, [out_msg1, vlad])\
|
|
|
474 |
.then(get_ca_images,
|
475 |
[desc_assignments, patch_descs, blend_alpha],
|
476 |
[imgs_batch, out_msg4])\
|
477 |
+
.then(unbatch_images, [imgs_batch, nimg_s], imgs2)
|
478 |
+
# If the blending changes now, update the cluster images only
|
479 |
+
blend_alpha.submit(get_ca_images,
|
480 |
[desc_assignments, patch_descs, blend_alpha],
|
481 |
[imgs_batch, out_msg4])\
|
482 |
+
.then(unbatch_images, [imgs_batch, nimg_s], imgs2)
|
483 |
+
|
484 |
+
|
485 |
+
# Tab for GeM t-SNE projection plot
|
486 |
+
def tab_gem_tsne():
    """
        Build the UI elements and the event pipeline of the GeM t-SNE
        projection tab. It should be called inside a Gradio container
        (the components are created on call).

        The names `var_num_img`, `imgs_batch`, `gem_descs` and
        `tsne_pts` are not defined here; they are looked up in the
        enclosing module scope when this function runs.
    """
    d_vals = [k.title() for k in DOMAINS]
    dms = gr.CheckboxGroup(d_vals, value=d_vals, label="Domains",
            info="The domains to use for the t-SNE projection")
    nimg_s = gr.Number(2, label="How many images?", precision=0,
            info=f"Between '1' and '{max_num_imgs}' images. Press "\
                "enter/return to register")
    with gr.Row():  # Dynamic row (images in columns)
        n_vis = int(nimg_s.value)   # Images visible at startup
        imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
                for i in range(n_vis)] + \
                [gr.Image(visible=False) \
                for _ in range(max_num_imgs - n_vis)]
        for i, img in enumerate(imgs):  # Set image as "input"
            img.change(lambda _: None, img)
    nimg_s.submit(var_num_img, nimg_s, imgs)
    tsne_plot = gr.Plot(None, label="tSNE Plot")
    out_msg1 = gr.Markdown("Select domains")
    out_msg2 = gr.Markdown("Upload images")
    out_msg3 = gr.Markdown("Wait for tSNE plots")

    # A wrapper to batch the images
    def batch_images(data):
        """
            Gather the first `nimg_s` images from the event data.
            Returns (images, status); images is None if the count is
            invalid or any of the images is missing.
        """
        try:
            sv = int(data[nimg_s])
        except (TypeError, ValueError):     # Empty number field
            return None, "Enter the number of images"
        if not 1 <= sv <= max_num_imgs:
            # Prevents indexing beyond the image components
            return None, f"Invalid num of images: {sv}!"
        images: List[np.ndarray] = []
        for k in range(sv):
            img = data[imgs[k]]
            if img is None:
                return None, f"Image {k+1} is None!"
            images.append(img)
        return images, "Images batched"

    bttn1 = gr.Button("Click Me!")

    # ---- Main pipeline ----
    # Get the tSNE plot
    bttn1.click(get_gem_descs_cache, dms, [out_msg1, gem_descs])\
        .then(batch_images, {nimg_s, *imgs, imgs_batch},
                [imgs_batch, out_msg2])\
        .then(get_add_gem_descs, [imgs_batch, gem_descs],
                [gem_descs, out_msg2])\
        .then(get_tsne_fm_gem, gem_descs, [tsne_pts, out_msg3])\
        .then(plot_tsne, tsne_pts, [tsne_plot, out_msg3])
|
530 |
+
|
531 |
+
|
532 |
+
# Build the interface
|
533 |
+
with gr.Blocks() as demo:
|
534 |
+
# Main header
|
535 |
+
gr.Markdown(header_markdown)
|
536 |
+
|
537 |
+
# ---- Helper functions ----
|
538 |
+
# Variable number of input images (show/hide UI image array)
|
539 |
+
def var_num_img(s):
    """
        Show the first `s` image components and hide the rest.

        Parameters:
        - s:    Requested number of images (anything `int` accepts).

        Returns a list of `max_num_imgs` image updates: the first `s`
        are visible (labelled "Image <k>"), the others are hidden.

        Raises `gr.Error` if `s` is not an integer in the range
        [1, max_num_imgs]. An exception is raised (instead of using
        assert) so that the check also runs under `python -O`.
    """
    try:
        n = int(s)  # Number field value as int
    except (TypeError, ValueError):
        raise gr.Error(f"Invalid num of images: {s}!") from None
    if not 1 <= n <= max_num_imgs:
        raise gr.Error(f"Invalid num of images: {n}!")
    shown = [gr.Image.update(label=f"Image {i+1}", visible=True) \
            for i in range(n)]
    hidden = [gr.Image.update(visible=False) \
            for _ in range(max_num_imgs - n)]
    return shown + hidden
|
546 |
+
|
547 |
+
# ---- State declarations ----
|
548 |
+
vlad = gr.State() # VLAD object
|
549 |
+
desc_assignments = gr.State() # Cluster assignments
|
550 |
+
imgs_batch = gr.State() # Images as batch
|
551 |
+
patch_descs = gr.State() # Patch descriptors
|
552 |
+
gem_descs = gr.State() # GeM descriptors (of each state)
|
553 |
+
tsne_pts = gr.State() # tSNE points
|
554 |
+
|
555 |
+
# ---- All UI elements ----
|
556 |
+
with gr.Tab("GeM t-SNE Projection"):
|
557 |
+
gr.Markdown(
|
558 |
+
"""
|
559 |
+
## GeM t-SNE Projection
|
560 |
+
|
561 |
+
Select the domains (toggle visibility) for t-SNE plot. \
|
562 |
+
Enter the number of images to upload and upload images. \
|
563 |
+
Then click the button to get the t-SNE plot.
|
564 |
+
|
565 |
+
""")
|
566 |
+
tab_gem_tsne()
|
567 |
+
|
568 |
+
with gr.Tab("Cluster Visualization"):
|
569 |
+
gr.Markdown(
|
570 |
+
"""
|
571 |
+
## Cluster Visualizations
|
572 |
+
|
573 |
+
Select the domain for the images (all should be from the \
|
574 |
+
same domain). Enter the number of images to upload. \
|
575 |
+
Upload the images. Then click the button to get the \
|
576 |
+
cluster assignment images.
|
577 |
+
|
578 |
+
You can also directly click on one of the examples (at \
|
579 |
+
the bottom) to load the data and then click the button \
|
580 |
+
to get the cluster assignment images.
|
581 |
+
|
582 |
+
""")
|
583 |
+
tab_cluster_viz()
|
584 |
|
585 |
print("Interface build completed")
|
586 |
|
cache/gem_cache/result_dino_v2.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5d27de9e8552eeeca6fd3f99fca135429169adcf926388d65eb44bb1ba9391f5
|
3 |
+
size 1990740
|
cache/gem_cache/result_dino_v2_tsne.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:77d978ec8fcb5f10075db999d037925846f2cbad522e5d0266ddcbfb50cd99de
|
3 |
+
size 3192
|
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ matplotlib
|
|
8 |
distinctipy
|
9 |
einops
|
10 |
fast_pytorch_kmeans
|
11 |
-
|
|
|
|
8 |
distinctipy
|
9 |
einops
|
10 |
fast_pytorch_kmeans
|
11 |
+
joblib==1.2.0
|
12 |
+
scikit-learn==1.0.2
|