# app_ijepa.py
# Gradio UI for interactive I-JEPA patch cosine similarity
# Fixed for modern Gradio .select API (evt passed as first arg)

import io
import math
import urllib.request
from functools import lru_cache
from typing import Optional

import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
from transformers import AutoModel
from matplotlib import colormaps as cm

# ---------------- Models ----------------
IJ_EPA_MODEL_IDS = [
    "facebook/ijepa_vith14_1k",
    "facebook/ijepa_vith16_1k",
    "facebook/ijepa_vitg16_22k",
]
SHORT_NAMES = {
    "facebook/ijepa_vith14_1k": "vith14_1k",
    "facebook/ijepa_vith16_1k": "vith16_1k",
    "facebook/ijepa_vitg16_22k": "vitg16_22k",
}
REVERSE_MAP = {v: k for k, v in SHORT_NAMES.items()}

DEFAULT_MODEL = "vith14_1k"
DEFAULT_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
DEFAULT_OVERLAY_ALPHA = 0.55
DEFAULT_SHOW_GRID = True

IJ_EPA_MEAN = [0.5, 0.5, 0.5]
IJ_EPA_STD = [0.5, 0.5, 0.5]


# ---------------- Utilities ----------------
def load_image_from_any(src: Optional[Image.Image], url: Optional[str]) -> Optional[Image.Image]:
    """Return an RGB PIL image from a URL if given, else from an uploaded image, else None."""
    if url and str(url).lower().startswith(("http://", "https://")):
        with urllib.request.urlopen(url) as resp:
            data = resp.read()
        return Image.open(io.BytesIO(data)).convert("RGB")
    if isinstance(src, Image.Image):
        return src.convert("RGB")
    return None


def pad_to_multiple(pil_img: Image.Image, multiple: int = 16) -> Image.Image:
    """Zero-pad the image on the right/bottom so both sides are multiples of `multiple`."""
    W, H = pil_img.size
    H_pad = int(math.ceil(H / multiple) * multiple)
    W_pad = int(math.ceil(W / multiple) * multiple)
    if (H_pad, W_pad) == (H, W):
        return pil_img
    canvas = Image.new("RGB", (W_pad, H_pad), (0, 0, 0))
    canvas.paste(pil_img, (0, 0))
    return canvas


def preprocess_no_resize(pil_img: Image.Image, multiple: int = 16):
    """Pad (no resize), normalize, and return ({"pixel_values": 1xCxHxW tensor}, uint8 display array)."""
    img_padded = pad_to_multiple(pil_img, multiple=multiple)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IJ_EPA_MEAN, std=IJ_EPA_STD),
    ])
    pixel_tensor = transform(img_padded).unsqueeze(0)
    disp_np = np.array(img_padded, dtype=np.uint8)
    return {"pixel_values": pixel_tensor}, disp_np


def upsample_nearest(arr: np.ndarray, H: int, W: int, ps: int) -> np.ndarray:
    """Nearest-neighbor upsample a (rows, cols[, C]) patch map to (H, W[, C]) pixels."""
    if arr.ndim == 2:
        return arr.repeat(ps, 0).repeat(ps, 1)
    if arr.ndim == 3:
        return arr.repeat(ps, 0).repeat(ps, 1).reshape(H, W, -1)
    raise ValueError(f"Expected a 2D or 3D array, got ndim={arr.ndim}")


def blend_overlay(base_uint8: np.ndarray, overlay_rgb_float: np.ndarray, alpha: float) -> np.ndarray:
    """Alpha-blend a float RGB overlay (values in 0..1) onto a uint8 base image."""
    base = base_uint8.astype(np.float32)
    over = (overlay_rgb_float * 255.0).astype(np.float32)
    out = (1.0 - alpha) * base + alpha * over
    return np.clip(out, 0, 255).astype(np.uint8)


def draw_grid(img: Image.Image, rows: int, cols: int, ps: int) -> None:
    """Draw white 1-px patch-grid lines onto the image in place."""
    d = ImageDraw.Draw(img)
    for r in range(1, rows):
        d.line([(0, r * ps), (img.width, r * ps)], fill=(255, 255, 255), width=1)
    for c in range(1, cols):
        d.line([(c * ps, 0), (c * ps, img.height)], fill=(255, 255, 255), width=1)


def rc_to_idx(r: int, c: int, cols: int) -> int:
    return r * cols + c


def idx_to_rc(i: int, cols: int):
    return divmod(i, cols)


# ---------------- Model cache ----------------
@lru_cache(maxsize=3)
def load_model_cached(full_model_id: str, device_str: str):
    model = AutoModel.from_pretrained(full_model_id, attn_implementation="sdpa").to(device_str)
    model.eval()
    return model


def infer_patch_size(model, default: int = 16) -> int:
    if hasattr(model, "config") and hasattr(model.config, "patch_size"):
        ps = model.config.patch_size
        return int(ps[0]) if isinstance(ps, (tuple, list)) else int(ps)
    if hasattr(model, "patch_size"):
        ps = model.patch_size
        return int(ps[0]) if isinstance(ps, (tuple, list)) else int(ps)
    return default
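
# --- Illustrative sketch (not called anywhere in the app) -------------------
# A hypothetical helper, added only to document the padding/grid geometry that
# PatchImageState relies on below; the default sizes are made-up examples.
def _example_grid_geometry(width: int = 500, height: int = 375, ps: int = 14):
    """Return (rows, cols) of the patch grid after pad_to_multiple.

    For the example defaults, a 500x375 image with patch size 14 pads to
    504x378, giving a 27x36 patch grid.
    """
    img = Image.new("RGB", (width, height))
    padded = pad_to_multiple(img, multiple=ps)
    W_pad, H_pad = padded.size
    return H_pad // ps, W_pad // ps
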
# ---------------- State ----------------
class PatchImageState:
    """Per-image cache: padded display image, patch-grid shape, and L2-normalized patch features."""

    def __init__(self, pil_img, model, device_str, ps):
        self.ps = ps
        inputs, disp_np = preprocess_no_resize(pil_img, multiple=ps)
        self.disp = disp_np
        pv = inputs["pixel_values"].to(device_str)
        _, _, H, W = pv.shape
        self.H, self.W = H, W
        self.rows, self.cols = H // ps, W // ps
        with torch.no_grad():
            out = model(pixel_values=pv, interpolate_pos_encoding=True)
        hs = out.last_hidden_state.squeeze(0).cpu().numpy()
        T, D = hs.shape
        n_patches = self.rows * self.cols
        n_special = T - n_patches  # number of non-patch (special) tokens, if any
        self.X = hs[n_special:, :].reshape(-1, D)
        self.Xn = self.X / (np.linalg.norm(self.X, axis=1, keepdims=True) + 1e-8)


# ---------------- Compute & render ----------------
def render_with_cosmap(st, cos_map, overlay_alpha, show_grid_flag, select_idx=None):
    """Blend a per-patch cosine map over the image; mark the selected patch in red."""
    H, W, ps = st.H, st.W, st.ps
    rows, cols = st.rows, st.cols
    if cos_map is None:
        disp = np.full((rows, cols), 0.5, dtype=np.float32)
    else:
        # Min-max normalize for display (avoid ndarray.ptp(), removed in NumPy 2.0).
        disp = (cos_map - cos_map.min()) / (cos_map.max() - cos_map.min() + 1e-8)
    cmap = cm["magma"]  # lookup in the matplotlib.colormaps registry
    rgb = cmap(disp)[..., :3]
    if select_idx is not None:
        r, c = idx_to_rc(select_idx, cols)
        rgb[r, c, :] = np.array([1.0, 0.0, 0.0])
    over_rgb_up = upsample_nearest(rgb, H, W, ps)
    blended = blend_overlay(st.disp, over_rgb_up, float(overlay_alpha))
    pil = Image.fromarray(blended)
    if show_grid_flag:
        draw_grid(pil, rows, cols, ps)
    return pil


def compute_self_and_cross(src, tgt, q_idx):
    """Cosine similarity of query patch q_idx in `src` against all patches of `src` and (optionally) `tgt`."""
    qn = src.Xn[q_idx]
    cos_self = src.Xn @ qn
    cos_map_self = cos_self.reshape(src.rows, src.cols)
    cos_map_cross, best_idx = None, None
    if tgt is not None:
        cos_cross = tgt.Xn @ qn
        cos_map_cross = cos_cross.reshape(tgt.rows, tgt.cols)
        best_idx = int(np.argmax(cos_cross))
    return cos_map_self, cos_map_cross, best_idx
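
# --- Illustrative sketch (not called anywhere in the app) -------------------
# A hypothetical example mirroring the math in compute_self_and_cross on random
# features, added only to document shapes; every name here is local to this sketch.
def _example_cosine_query(rows: int = 3, cols: int = 4, dim: int = 8, q_idx: int = 5):
    rng = np.random.default_rng(0)
    feats = rng.normal(size=(rows * cols, dim)).astype(np.float32)   # one row per patch
    featsn = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-8)
    cos = featsn @ featsn[q_idx]                                     # cosine vs. the query patch
    assert np.isclose(cos[q_idx], 1.0, atol=1e-5)                    # self-similarity is ~1
    return cos.reshape(rows, cols)                                   # patch-grid heat map
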
# ---------------- Gradio bindings ----------------
def resolve_full_model_id(short_name):
    return REVERSE_MAP.get(short_name)


def init_states(left_img_in, left_url, right_img_in, right_url, short_model, show_grid_flag, overlay_alpha):
    left_img = load_image_from_any(left_img_in, left_url)
    right_img = load_image_from_any(right_img_in, right_url)
    if left_img is None and right_img is None:
        left_img = load_image_from_any(None, DEFAULT_URL)
    full_model_id = resolve_full_model_id(short_model)
    device_str = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model_cached(full_model_id, device_str)
    ps = infer_patch_size(model, 16)
    left_state = PatchImageState(left_img, model, device_str, ps) if left_img else None
    right_state = PatchImageState(right_img, model, device_str, ps) if right_img else None
    active_side = 0 if left_state else 1
    # Render the padded base images right away so the view components are clickable.
    out_left = render_with_cosmap(left_state, None, overlay_alpha, show_grid_flag) if left_state else None
    out_right = render_with_cosmap(right_state, None, overlay_alpha, show_grid_flag) if right_state else None
    status = f"✔ Loaded {full_model_id} | ps={ps}"
    return out_left, out_right, left_state, right_state, active_side, ps, status


def click_on(evt: gr.SelectData, which_side, left_state, right_state, active_side, ps, overlay_alpha, show_grid_flag):
    x, y = evt.index  # pixel coordinates of the click in the displayed (padded) image
    if which_side == "left" and left_state:
        r, c = int(y // ps), int(x // ps)
        q_idx = rc_to_idx(r, c, left_state.cols)
        cos_self, cos_cross, best_idx = compute_self_and_cross(left_state, right_state, q_idx)
        out_left = render_with_cosmap(left_state, cos_self, overlay_alpha, show_grid_flag, q_idx)
        out_right = (
            render_with_cosmap(right_state, cos_cross, overlay_alpha, show_grid_flag, best_idx)
            if right_state else None
        )
        return out_left, out_right
    if which_side == "right" and right_state:
        r, c = int(y // ps), int(x // ps)
        q_idx = rc_to_idx(r, c, right_state.cols)
        cos_self, cos_cross, best_idx = compute_self_and_cross(right_state, left_state, q_idx)
        out_right = render_with_cosmap(right_state, cos_self, overlay_alpha, show_grid_flag, q_idx)
        out_left = (
            render_with_cosmap(left_state, cos_cross, overlay_alpha, show_grid_flag, best_idx)
            if left_state else None
        )
        return out_left, out_right
    return None, None


# ---------------- UI ----------------
with gr.Blocks() as demo:
    gr.Markdown("## I-JEPA Interactive Patch Cosine Similarity")
    with gr.Row():
        with gr.Column(scale=1):
            model_dd = gr.Dropdown(choices=list(REVERSE_MAP.keys()), value=DEFAULT_MODEL, label="Model")
            show_grid = gr.Checkbox(value=DEFAULT_SHOW_GRID, label="Show grid")
            alpha = gr.Slider(0.0, 1.0, value=DEFAULT_OVERLAY_ALPHA, step=0.01, label="Overlay alpha")
            status = gr.Markdown("")
    with gr.Row():
        with gr.Column():
            left_url = gr.Textbox(label="Left image URL", value=DEFAULT_URL)
            left_img = gr.Image(type="pil", label="or upload (left)")
            left_view = gr.Image(type="pil", label="Left view")
        with gr.Column():
            right_url = gr.Textbox(label="Right image URL (optional)")
            right_img = gr.Image(type="pil", label="or upload (right)")
            right_view = gr.Image(type="pil", label="Right view")

    left_state = gr.State()
    right_state = gr.State()
    active_side = gr.State(0)
    ps_st = gr.State(16)

    btn = gr.Button("Load / Refresh")
    btn.click(
        fn=init_states,
        inputs=[left_img, left_url, right_img, right_url, model_dd, show_grid, alpha],
        outputs=[left_view, right_view, left_state, right_state, active_side, ps_st, status],
    )

    def handle_left(evt: gr.SelectData, ls, rs, as_, ps, a, sg):
        return click_on(evt, "left", ls, rs, as_, ps, a, sg)

    def handle_right(evt: gr.SelectData, ls, rs, as_, ps, a, sg):
        return click_on(evt, "right", ls, rs, as_, ps, a, sg)

    left_view.select(
        fn=handle_left,
        inputs=[left_state, right_state, active_side, ps_st, alpha, show_grid],
        outputs=[left_view, right_view],
    )
    right_view.select(
        fn=handle_right,
        inputs=[left_state, right_state, active_side, ps_st, alpha, show_grid],
        outputs=[left_view, right_view],
    )

demo.queue().launch(ssr_mode=False, server_name="0.0.0.0", server_port=7860)
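
# To try the app locally (assuming the dependencies imported above are installed):
#   python app_ijepa.py
# then open http://localhost:7860 in a browser (server_port is set to 7860 above).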