# src/infer_with_centroid.py import argparse from pathlib import Path import torch import torch.nn.functional as F from PIL import Image, ImageDraw, ImageFont import torchvision.transforms as T from config import CFG from prompt import load_kiln_embed from model import build_model from head import SimilarityHead # ----------------------------- # Transforms (match training) # ----------------------------- import matplotlib.pyplot as plt import numpy as np def save_heatmap(sal, out_path="outputs/heatmap.png"): sal_np = sal.squeeze().cpu().numpy() # [H, W] plt.imshow(sal_np, cmap="jet", interpolation="nearest") plt.colorbar() plt.title("Attention Saliency Heatmap") plt.savefig(out_path, bbox_inches="tight") plt.close() def save_heatmap_overlay(pil_img, sal, out_path="outputs/heatmap_overlay.png", alpha=0.45): W0, H0 = pil_img.size # upsample saliency to image size sal_up = F.interpolate(sal, size=(H0, W0), mode="bilinear", align_corners=False) h = sal_up.squeeze().detach().cpu().numpy() # [H0, W0] # normalize to 0..1 h = (h - h.min()) / (h.max() - h.min() + 1e-6) # colormap -> RGBA [H0, W0, 4] cmap = plt.get_cmap("jet") h_rgba = (cmap(h) * 255).astype(np.uint8) heat = Image.fromarray(h_rgba).convert("RGBA") base = pil_img.convert("RGBA") overlay = Image.blend(base, heat, alpha) overlay.save(out_path) def model_transform(): return T.Compose([ T.Resize((CFG.image_size, CFG.image_size)), T.ToTensor(), T.ConvertImageDtype(torch.float32), T.Normalize(mean=CFG.mean, std=CFG.std), ]) def display_transform(): return T.Resize((CFG.image_size, CFG.image_size)) def load_checkpoint(ckpt_path: Path): return torch.load(ckpt_path, map_location="cpu") def build_model_and_head(device, ckpt): args = ckpt.get("args", {}) last_k = args.get("last_k", 2) model_name = args.get("model_name", "ViT-B-16") pretrained = args.get("pretrained", "openai") model = build_model( model_name=model_name, pretrained=pretrained, train_last_k_blocks=last_k, train_final_norm=True, finetune_logit_scale=False, device=device, ) model.load_state_dict(ckpt["model_state"], strict=False) head = SimilarityHead(init_alpha=10.0, init_beta=0.0).to(device) head.load_state_dict(ckpt["head_state"], strict=False) model.eval(); head.eval() return model, head def find_openclip_visual_tower(model: torch.nn.Module): for m in model.modules(): if hasattr(m, "transformer") and hasattr(m.transformer, "resblocks"): if hasattr(m, "class_embedding") and hasattr(m, "positional_embedding"): return m raise AttributeError("Could not find a ViT visual tower (transformer.resblocks) inside the model. Debugger") def _patch_block_attention_capture(blk): if hasattr(blk, "_orig_attention"): return blk._orig_attention = blk.attention def attention_with_capture(q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None, attn_mask: torch.Tensor = None, **kwargs): k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x #will have to keep dtype compatible with MHA here for future ref attn_mask2 = attn_mask.to(q_x.dtype) if attn_mask is not None else None out, w = blk.attn( q_x, k_x, v_x, need_weights=True, average_attn_weights=False, attn_mask=attn_mask2 ) if w.dim() == 4: blk.__last_attn__ = w.detach() elif w.dim() == 3: # normalize to [B, T, S] if w.shape[0] == q_x.shape[0]: w_bts = w elif w.shape[1] == q_x.shape[0]: w_bts = w.permute(1, 0, 2) else: w_bts = w.unsqueeze(0) if w.shape[0] != 0 else w blk.__last_attn__ = w_bts.unsqueeze(1).detach() # -> [B, 1, T, S] implicit single head else: blk.__last_attn__ = None return out blk.attention = attention_with_capture def _enable_attn_capture(visual): for blk in visual.transformer.resblocks: _patch_block_attention_capture(blk) @torch.no_grad() def vit_centroid_from_attention(model, x): B = x.shape[0] visual = find_openclip_visual_tower(model) _enable_attn_capture(visual) _ = model.encode_image(x) # forward passing rn A = [] for blk in visual.transformer.resblocks: w = getattr(blk, "__last_attn__", None) if w is None: continue if w.dim() == 4 and w.shape[0] in (1, B): a = w.mean(1) else: if w.dim() == 3 and w.shape[0] == B: a = w elif w.dim() == 3 and w.shape[1] == B: a = w.permute(1, 0, 2) else: continue a = a / (a.sum(-1, keepdim=True) + 1e-6) # [B, T, S] Ttok = a.size(-1) I = torch.eye(Ttok, device=a.device).unsqueeze(0).expand(a.size(0), Ttok, Ttok) a = 0.1 * I + 0.9 * a A.append(a) if not A: raise RuntimeError("No attention captured. Backend likely changed shapes; print shapes inside attention_with_capture to inspect.") # joint attention will be here R = A[0] for a in A[1:]: R = a @ R cls_to_tokens = R[:, 0, 1:] if hasattr(visual, "patch_size"): ph, pw = (visual.patch_size if isinstance(visual.patch_size, tuple) else (visual.patch_size, visual.patch_size)) else: ph = pw = 16 B, _, H_img, W_img = x.shape H = H_img // ph W = W_img // pw sal = cls_to_tokens.view(B, 1, H, W) sal = F.relu(sal) sal = sal / (sal.sum(dim=[2, 3], keepdim=True) + 1e-6) ys = torch.arange(H, device=sal.device).view(1, 1, H, 1).float() xs = torch.arange(W, device=sal.device).view(1, 1, 1, W).float() cy = (sal * ys).sum(dim=[2, 3]) cx = (sal * xs).sum(dim=[2, 3]) cx_px = (cx + 0.5) * pw cy_px = (cy + 0.5) * ph return cx_px.squeeze(1), cy_px.squeeze(1), sal def annotate_text(pil_img, text): img = pil_img.convert("RGB").copy() draw = ImageDraw.Draw(img) try: font = ImageFont.truetype("arial.ttf", 18) except: font = ImageFont.load_default() tw, th = draw.textbbox((0, 0), text, font=font)[2:] pad = 6 box = [pad, pad, pad + tw + 2 * pad, pad + th + 2 * pad] draw.rectangle(box, fill=(0, 0, 0, 180)) draw.text((pad * 2, pad * 2), text, fill=(255, 255, 255), font=font) return img def draw_centroid(pil_img, cx, cy, r=8, color=(255, 0, 0), w=2): img = pil_img.convert("RGB").copy() d = ImageDraw.Draw(img) d.ellipse([cx - r, cy - r, cx + r, cy + r], outline=color, width=w) d.line([(cx - 2 * r, cy), (cx + 2 * r, cy)], fill=color, width=w) d.line([(cx, cy - 2 * r), (cx, cy + 2 * r)], fill=color, width=w) return img @torch.no_grad() def predict_prob_and_centroid(pil_img, device, t_kiln, model, head, thresh=0.5): x = model_transform()(pil_img).unsqueeze(0).to(device) # classification prob z = model.encode_image(x) z = z / (z.norm(dim=-1, keepdim=True) + 1e-8) s = torch.einsum("bd,d->b", z, t_kiln) logits = head(s) prob = torch.sigmoid(logits)[0].item() cx_r, cy_r, _ = vit_centroid_from_attention(model, x) return prob, float(cx_r[0]), float(cy_r[0]), x.shape[-1], x.shape[-2] def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt", type=str, default="checkpoints/best.pt") ap.add_argument("--image", type=str, required=True) ap.add_argument("--out", type=str, default="outputs/preview_centroid.jpg") ap.add_argument("--threshold", type=float, default=0.5) ap.add_argument("--annotate_on_original", action="store_true") args = ap.parse_args() ckpt_path = Path(args.ckpt) img_path = Path(args.image) out_path = Path(args.out) out_path.parent.mkdir(parents=True, exist_ok=True) assert ckpt_path.exists(), f"Checkpoint not found: {ckpt_path}" assert img_path.exists(), f"Image not found: {img_path}" device = "cuda" if torch.cuda.is_available() else "cpu" t_kiln = load_kiln_embed(to_device=device) t_kiln = t_kiln / (t_kiln.norm() + 1e-8) ckpt = load_checkpoint(ckpt_path) model, head = build_model_and_head(device, ckpt) pil = Image.open(img_path).convert("RGB") prob, cx_r, cy_r, W_r, H_r = predict_prob_and_centroid(pil, device, t_kiln, model, head, args.threshold) _, _, sal = vit_centroid_from_attention(model, model_transform()(pil).unsqueeze(0).to(device)) save_heatmap(sal, "outputs/heatmap.png") save_heatmap_overlay(pil, sal, "outputs/heatmap_overlay.png", alpha=0.45) pred = int(prob >= args.threshold) disp = display_transform()(pil) disp = annotate_text(disp, f"pred: {'KILN' if pred==1 else 'NOT'} p={prob:.3f}") if pred == 1: if args.annotate_on_original: W0, H0 = pil.size sx = W0 / W_r sy = H0 / H_r cx0, cy0 = cx_r * sx, cy_r * sy out_img = draw_centroid(pil, cx0, cy0) print(f"centroid_px(original): ({cx0:.1f}, {cy0:.1f}) prob={prob:.3f}") else: out_img = draw_centroid(disp, cx_r, cy_r) print(f"centroid_px(resized {W_r}x{H_r}): ({cx_r:.1f}, {cy_r:.1f}) prob={prob:.3f}") else: out_img = disp print(f"no kiln (p={prob:.3f})") out_img.save(out_path) print(f"[saved] {out_path}") if __name__ == "__main__": main()