# src/infer_with_centroid.py
import argparse
from pathlib import Path
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as T
from config import CFG
from prompt import load_kiln_embed
from model import build_model
from head import SimilarityHead
import matplotlib.pyplot as plt
import numpy as np

# -----------------------------
# Heatmap / saliency utilities
# -----------------------------
def save_heatmap(sal, out_path="outputs/heatmap.png"):
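    """Save a saliency tensor (squeezable to [H, W]) as a jet-colormapped PNG with a colorbar."""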
sal_np = sal.squeeze().cpu().numpy() # [H, W]
plt.imshow(sal_np, cmap="jet", interpolation="nearest")
plt.colorbar()
plt.title("Attention Saliency Heatmap")
plt.savefig(out_path, bbox_inches="tight")
plt.close()
def save_heatmap_overlay(pil_img, sal, out_path="outputs/heatmap_overlay.png", alpha=0.45):
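    """Bilinearly upsample a [B, 1, h, w] saliency map to the image size and alpha-blend a jet colormap over it."""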
W0, H0 = pil_img.size
# upsample saliency to image size
sal_up = F.interpolate(sal, size=(H0, W0), mode="bilinear", align_corners=False)
h = sal_up.squeeze().detach().cpu().numpy() # [H0, W0]
# normalize to 0..1
h = (h - h.min()) / (h.max() - h.min() + 1e-6)
# colormap -> RGBA [H0, W0, 4]
cmap = plt.get_cmap("jet")
h_rgba = (cmap(h) * 255).astype(np.uint8)
heat = Image.fromarray(h_rgba).convert("RGBA")
base = pil_img.convert("RGBA")
overlay = Image.blend(base, heat, alpha)
overlay.save(out_path)
# -----------------------------
# Transforms (match training)
# -----------------------------
def model_transform():
return T.Compose([
T.Resize((CFG.image_size, CFG.image_size)),
        T.ToTensor(),  # already yields float32 in [0, 1]; ConvertImageDtype below is a harmless no-op
        T.ConvertImageDtype(torch.float32),
T.Normalize(mean=CFG.mean, std=CFG.std),
])
def display_transform():
return T.Resize((CFG.image_size, CFG.image_size))
def load_checkpoint(ckpt_path: Path):
return torch.load(ckpt_path, map_location="cpu")
def build_model_and_head(device, ckpt):
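    """Rebuild the CLIP backbone and SimilarityHead from hyperparameters stored in the checkpoint's `args` dict, then load both state dicts non-strictly."""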
args = ckpt.get("args", {})
last_k = args.get("last_k", 2)
model_name = args.get("model_name", "ViT-B-16")
pretrained = args.get("pretrained", "openai")
model = build_model(
model_name=model_name,
pretrained=pretrained,
train_last_k_blocks=last_k,
train_final_norm=True,
finetune_logit_scale=False,
device=device,
)
model.load_state_dict(ckpt["model_state"], strict=False)
head = SimilarityHead(init_alpha=10.0, init_beta=0.0).to(device)
head.load_state_dict(ckpt["head_state"], strict=False)
    model.eval()
    head.eval()
return model, head
def find_openclip_visual_tower(model: torch.nn.Module):
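    """Duck-type search for the open_clip ViT image tower: a module with transformer.resblocks plus class_embedding and positional_embedding."""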
for m in model.modules():
if hasattr(m, "transformer") and hasattr(m.transformer, "resblocks"):
if hasattr(m, "class_embedding") and hasattr(m, "positional_embedding"):
return m
    raise AttributeError("Could not find a ViT visual tower (transformer.resblocks) inside the model.")
def _patch_block_attention_capture(blk):
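    """Monkey-patch a residual attention block so each forward pass stashes its attention weights in blk.__last_attn__ (the original method is kept in blk._orig_attention)."""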
if hasattr(blk, "_orig_attention"):
return
blk._orig_attention = blk.attention
def attention_with_capture(q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None,
attn_mask: torch.Tensor = None, **kwargs):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
        # keep the attn_mask dtype compatible with nn.MultiheadAttention
        attn_mask2 = attn_mask.to(q_x.dtype) if attn_mask is not None else None
out, w = blk.attn(
q_x, k_x, v_x,
need_weights=True,
average_attn_weights=False,
attn_mask=attn_mask2
)
        if w.dim() == 4:  # [B, n_heads, T, S] from average_attn_weights=False
            blk.__last_attn__ = w.detach()
elif w.dim() == 3:
# normalize to [B, T, S]
if w.shape[0] == q_x.shape[0]:
w_bts = w
elif w.shape[1] == q_x.shape[0]:
w_bts = w.permute(1, 0, 2)
            else:
                # fallback: assume the batch dim is missing
                w_bts = w.unsqueeze(0) if w.shape[0] != 0 else w
            blk.__last_attn__ = w_bts.unsqueeze(1).detach()  # -> [B, 1, T, S], implicit single head
else:
blk.__last_attn__ = None
return out
blk.attention = attention_with_capture
def _enable_attn_capture(visual):
for blk in visual.transformer.resblocks:
_patch_block_attention_capture(blk)
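
# Hypothetical companion helper (not called anywhere above): undoes the
# monkey-patch installed by _patch_block_attention_capture, e.g. if the same
# model object is later reused for plain inference where capturing attention
# weights is unwanted. A minimal sketch under that assumption.
def _disable_attn_capture(visual):
    for blk in visual.transformer.resblocks:
        if hasattr(blk, "_orig_attention"):
            blk.attention = blk._orig_attention
            del blk._orig_attention
        if hasattr(blk, "__last_attn__"):
            del blk.__last_attn__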
@torch.no_grad()
def vit_centroid_from_attention(model, x):
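    """Approximate where the model "looks" via attention rollout, then take a soft-argmax centroid.

    Returns (cx_px, cy_px, sal): per-image centroid coordinates in input-pixel
    units and the [B, 1, H_patches, W_patches] saliency map.
    """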
B = x.shape[0]
visual = find_openclip_visual_tower(model)
_enable_attn_capture(visual)
    _ = model.encode_image(x)  # forward pass populates __last_attn__ on each block
A = []
for blk in visual.transformer.resblocks:
w = getattr(blk, "__last_attn__", None)
if w is None:
continue
        if w.dim() == 4 and w.shape[0] in (1, B):
            a = w.mean(1)  # average over heads -> [B, T, S]
else:
if w.dim() == 3 and w.shape[0] == B:
a = w
elif w.dim() == 3 and w.shape[1] == B:
a = w.permute(1, 0, 2)
else:
continue
        a = a / (a.sum(-1, keepdim=True) + 1e-6)  # renormalize rows -> [B, T, S]
        Ttok = a.size(-1)
        I = torch.eye(Ttok, device=a.device).unsqueeze(0).expand(a.size(0), Ttok, Ttok)
        a = 0.1 * I + 0.9 * a  # mix in the identity to account for residual connections
A.append(a)
if not A:
        raise RuntimeError("No attention captured; the attention backend may have returned unexpected shapes. Inspect w inside attention_with_capture.")
    # attention rollout: chain per-layer attention maps into a joint CLS->token map
    R = A[0]
    for a in A[1:]:
        R = a @ R
    cls_to_tokens = R[:, 0, 1:]  # CLS attention over patch tokens (drop CLS->CLS)
if hasattr(visual, "patch_size"):
ph, pw = (visual.patch_size if isinstance(visual.patch_size, tuple) else (visual.patch_size, visual.patch_size))
    else:
        ph = pw = 16  # fall back to the ViT-B/16 patch size
B, _, H_img, W_img = x.shape
H = H_img // ph
W = W_img // pw
    sal = cls_to_tokens.view(B, 1, H, W)
    sal = F.relu(sal)
    sal = sal / (sal.sum(dim=[2, 3], keepdim=True) + 1e-6)  # normalize to a spatial distribution
    # soft-argmax: expected (x, y) under the saliency distribution, in patch-grid units
    ys = torch.arange(H, device=sal.device).view(1, 1, H, 1).float()
    xs = torch.arange(W, device=sal.device).view(1, 1, 1, W).float()
    cy = (sal * ys).sum(dim=[2, 3])
    cx = (sal * xs).sum(dim=[2, 3])
    cx_px = (cx + 0.5) * pw  # patch-grid coords -> pixel coords (patch centers)
    cy_px = (cy + 0.5) * ph
return cx_px.squeeze(1), cy_px.squeeze(1), sal
def annotate_text(pil_img, text):
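    """Stamp `text` onto the top-left corner of the image over a dark backing box."""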
img = pil_img.convert("RGB").copy()
draw = ImageDraw.Draw(img)
try:
font = ImageFont.truetype("arial.ttf", 18)
    except OSError:  # font file not found
font = ImageFont.load_default()
tw, th = draw.textbbox((0, 0), text, font=font)[2:]
pad = 6
box = [pad, pad, pad + tw + 2 * pad, pad + th + 2 * pad]
    draw.rectangle(box, fill=(0, 0, 0))  # solid black; an alpha channel has no effect on an RGB image
draw.text((pad * 2, pad * 2), text, fill=(255, 255, 255), font=font)
return img
def draw_centroid(pil_img, cx, cy, r=8, color=(255, 0, 0), w=2):
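    """Draw a circle-and-crosshair marker centered at (cx, cy)."""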
img = pil_img.convert("RGB").copy()
d = ImageDraw.Draw(img)
d.ellipse([cx - r, cy - r, cx + r, cy + r], outline=color, width=w)
d.line([(cx - 2 * r, cy), (cx + 2 * r, cy)], fill=color, width=w)
d.line([(cx, cy - 2 * r), (cx, cy + 2 * r)], fill=color, width=w)
return img
@torch.no_grad()
def predict_prob_and_centroid(pil_img, device, t_kiln, model, head):
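    """Classify one image (CLIP similarity -> head -> sigmoid) and localize it via attention rollout; returns (prob, cx, cy, W, H, sal) in resized-image pixels."""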
x = model_transform()(pil_img).unsqueeze(0).to(device)
    # classification: cosine similarity between the image embedding and the kiln text embedding
    z = model.encode_image(x)
    z = z / (z.norm(dim=-1, keepdim=True) + 1e-8)  # L2-normalize
    s = torch.einsum("bd,d->b", z, t_kiln)  # cosine similarity (t_kiln is already normalized)
logits = head(s)
prob = torch.sigmoid(logits)[0].item()
    # localization: attention-rollout centroid on the same preprocessed tensor
    cx_r, cy_r, sal = vit_centroid_from_attention(model, x)
    return prob, float(cx_r[0]), float(cy_r[0]), x.shape[-1], x.shape[-2], sal
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--ckpt", type=str, default="checkpoints/best.pt")
ap.add_argument("--image", type=str, required=True)
ap.add_argument("--out", type=str, default="outputs/preview_centroid.jpg")
ap.add_argument("--threshold", type=float, default=0.5)
ap.add_argument("--annotate_on_original", action="store_true")
args = ap.parse_args()
ckpt_path = Path(args.ckpt)
img_path = Path(args.image)
out_path = Path(args.out)
out_path.parent.mkdir(parents=True, exist_ok=True)
assert ckpt_path.exists(), f"Checkpoint not found: {ckpt_path}"
assert img_path.exists(), f"Image not found: {img_path}"
device = "cuda" if torch.cuda.is_available() else "cpu"
t_kiln = load_kiln_embed(to_device=device)
t_kiln = t_kiln / (t_kiln.norm() + 1e-8)
ckpt = load_checkpoint(ckpt_path)
model, head = build_model_and_head(device, ckpt)
pil = Image.open(img_path).convert("RGB")
    prob, cx_r, cy_r, W_r, H_r, sal = predict_prob_and_centroid(pil, device, t_kiln, model, head)
    save_heatmap(sal, "outputs/heatmap.png")
    save_heatmap_overlay(pil, sal, "outputs/heatmap_overlay.png", alpha=0.45)
pred = int(prob >= args.threshold)
disp = display_transform()(pil)
disp = annotate_text(disp, f"pred: {'KILN' if pred==1 else 'NOT'} p={prob:.3f}")
if pred == 1:
if args.annotate_on_original:
W0, H0 = pil.size
sx = W0 / W_r
sy = H0 / H_r
cx0, cy0 = cx_r * sx, cy_r * sy
out_img = draw_centroid(pil, cx0, cy0)
print(f"centroid_px(original): ({cx0:.1f}, {cy0:.1f}) prob={prob:.3f}")
else:
out_img = draw_centroid(disp, cx_r, cy_r)
print(f"centroid_px(resized {W_r}x{H_r}): ({cx_r:.1f}, {cy_r:.1f}) prob={prob:.3f}")
else:
out_img = disp
print(f"no kiln (p={prob:.3f})")
out_img.save(out_path)
print(f"[saved] {out_path}")
if __name__ == "__main__":
main()
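
# Example invocation (paths are illustrative -- point --ckpt and --image at your
# own checkpoint and input tile; --ckpt defaults to checkpoints/best.pt):
#   python src/infer_with_centroid.py \
#       --image data/tiles/sample.png \
#       --out outputs/preview_centroid.jpg \
#       --annotate_on_original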