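"""Kiln-detection inference script.

Loads a fine-tuned CLIP checkpoint, scores an image against a precomputed
kiln text embedding, derives an attention-rollout saliency map and its
centroid, and saves a heatmap, an overlay, and an annotated preview.
"""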
import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont

from config import CFG
from head import SimilarityHead
from model import build_model
from prompt import load_kiln_embed


def save_heatmap(sal, out_path="outputs/heatmap.png"):
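    """Render the raw saliency map as a standalone matplotlib figure and save it."""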
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)  # ensure the output directory exists
    sal_np = sal.squeeze().detach().cpu().numpy()
    plt.imshow(sal_np, cmap="jet", interpolation="nearest")
    plt.colorbar()
    plt.title("Attention Saliency Heatmap")
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()


def save_heatmap_overlay(pil_img, sal, out_path="outputs/heatmap_overlay.png", alpha=0.45):
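    """Upsample the saliency map to the original image size and save a blended overlay."""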
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    W0, H0 = pil_img.size

    # Upsample the (B, 1, h, w) patch-grid saliency to full image resolution.
    sal_up = F.interpolate(sal, size=(H0, W0), mode="bilinear", align_corners=False)
    h = sal_up.squeeze().detach().cpu().numpy()

    # Min-max normalize so the colormap spans the full range.
    h = (h - h.min()) / (h.max() - h.min() + 1e-6)

    cmap = plt.get_cmap("jet")
    h_rgba = (cmap(h) * 255).astype(np.uint8)
    heat = Image.fromarray(h_rgba).convert("RGBA")

    base = pil_img.convert("RGBA")
    overlay = Image.blend(base, heat, alpha)
    overlay.save(out_path)


def model_transform():
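    """Preprocessing used at inference: resize, tensorize, and normalize with the training stats."""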
    return T.Compose([
        T.Resize((CFG.image_size, CFG.image_size)),
        T.ToTensor(),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=CFG.mean, std=CFG.std),
    ])


def display_transform():
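    """Resize-only transform for the human-viewable annotated preview."""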
    return T.Resize((CFG.image_size, CFG.image_size))


def load_checkpoint(ckpt_path: Path):
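    """Load a checkpoint dict (model/head weights plus training args) onto the CPU."""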
    return torch.load(ckpt_path, map_location="cpu")


def build_model_and_head(device, ckpt):
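    """Rebuild the CLIP backbone and SimilarityHead from the args stored in the checkpoint."""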
    # Fall back to the training defaults if the checkpoint predates the args dict.
    args = ckpt.get("args", {})
    last_k = args.get("last_k", 2)
    model_name = args.get("model_name", "ViT-B-16")
    pretrained = args.get("pretrained", "openai")

    model = build_model(
        model_name=model_name,
        pretrained=pretrained,
        train_last_k_blocks=last_k,
        train_final_norm=True,
        finetune_logit_scale=False,
        device=device,
    )
    model.load_state_dict(ckpt["model_state"], strict=False)

    head = SimilarityHead(init_alpha=10.0, init_beta=0.0).to(device)
    head.load_state_dict(ckpt["head_state"], strict=False)

    model.eval()
    head.eval()
    return model, head


def find_openclip_visual_tower(model: torch.nn.Module):
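    """Find the ViT visual tower inside an open_clip model by duck-typing its attributes."""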
    for m in model.modules():
        # open_clip's VisionTransformer exposes transformer.resblocks along with
        # class_embedding / positional_embedding; use those as a signature.
        if hasattr(m, "transformer") and hasattr(m.transformer, "resblocks"):
            if hasattr(m, "class_embedding") and hasattr(m, "positional_embedding"):
                return m
    raise AttributeError("Could not find a ViT visual tower (transformer.resblocks) inside the model.")


def _patch_block_attention_capture(blk):
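    """Wrap a block's attention() so each forward pass stashes its attention weights on the block."""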
    if hasattr(blk, "_orig_attention"):
        return  # already patched
    blk._orig_attention = blk.attention

    def attention_with_capture(q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None,
                               attn_mask: torch.Tensor = None, **kwargs):
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x

        attn_mask2 = attn_mask.to(q_x.dtype) if attn_mask is not None else None

        # need_weights=True with average_attn_weights=False asks nn.MultiheadAttention
        # for per-head weights of shape (batch, heads, tgt, src).
        out, w = blk.attn(
            q_x, k_x, v_x,
            need_weights=True,
            average_attn_weights=False,
            attn_mask=attn_mask2,
        )

        if w.dim() == 4:
            # (batch, heads, tgt, src): store as-is.
            blk.__last_attn__ = w.detach()
        elif w.dim() == 3:
            # Some backends return head-averaged weights; recover (batch, tgt, src).
            if w.shape[0] == q_x.shape[0]:
                w_bts = w
            elif w.shape[1] == q_x.shape[0]:
                w_bts = w.permute(1, 0, 2)
            else:
                w_bts = w  # fall back to assuming batch-first
            # Add a singleton head dimension: (batch, 1, tgt, src).
            blk.__last_attn__ = w_bts.unsqueeze(1).detach()
        else:
            blk.__last_attn__ = None

        return out

    blk.attention = attention_with_capture


def _enable_attn_capture(visual):
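    """Patch every residual attention block in the visual tower for weight capture."""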
    for blk in visual.transformer.resblocks:
        _patch_block_attention_capture(blk)


@torch.no_grad()
def vit_centroid_from_attention(model, x):
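    """Attention-rollout saliency and its centroid.

    Composes the (residual-mixed) per-layer attention maps, reads off the
    CLS-to-patch row, reshapes it onto the patch grid, and returns the
    saliency-weighted centroid in pixel coordinates plus the saliency map.
    """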
    B = x.shape[0]
    visual = find_openclip_visual_tower(model)
    _enable_attn_capture(visual)
    _ = model.encode_image(x)

    # Collect one (B, T, T) attention map per block, averaged over heads.
    A = []
    for blk in visual.transformer.resblocks:
        w = getattr(blk, "__last_attn__", None)
        if w is None:
            continue
        if w.dim() == 4 and w.shape[0] in (1, B):
            a = w.mean(1)  # average over heads
        else:
            if w.dim() == 3 and w.shape[0] == B:
                a = w
            elif w.dim() == 3 and w.shape[1] == B:
                a = w.permute(1, 0, 2)
            else:
                continue
        # Renormalize rows, then mix in the identity to account for the residual path.
        a = a / (a.sum(-1, keepdim=True) + 1e-6)
        Ttok = a.size(-1)
        I = torch.eye(Ttok, device=a.device).unsqueeze(0).expand(a.size(0), Ttok, Ttok)
        a = 0.1 * I + 0.9 * a
        A.append(a)

    if not A:
        raise RuntimeError("No attention captured. The attention backend likely changed output "
                           "shapes; print shapes inside attention_with_capture to inspect.")

    # Attention rollout: compose the per-layer maps from first to last block.
    R = A[0]
    for a in A[1:]:
        R = a @ R

    # CLS-token attention over patch tokens (drop the CLS column itself).
    cls_to_tokens = R[:, 0, 1:]

    if hasattr(visual, "patch_size"):
        ph, pw = (visual.patch_size if isinstance(visual.patch_size, tuple)
                  else (visual.patch_size, visual.patch_size))
    else:
        ph = pw = 16

    B, _, H_img, W_img = x.shape
    H = H_img // ph
    W = W_img // pw

    # Reshape onto the patch grid and normalize to a probability map.
    sal = cls_to_tokens.view(B, 1, H, W)
    sal = F.relu(sal)
    sal = sal / (sal.sum(dim=[2, 3], keepdim=True) + 1e-6)

    # Saliency-weighted centroid in patch coordinates, converted to pixels.
    ys = torch.arange(H, device=sal.device).view(1, 1, H, 1).float()
    xs = torch.arange(W, device=sal.device).view(1, 1, 1, W).float()
    cy = (sal * ys).sum(dim=[2, 3])
    cx = (sal * xs).sum(dim=[2, 3])
    cx_px = (cx + 0.5) * pw
    cy_px = (cy + 0.5) * ph
    return cx_px.squeeze(1), cy_px.squeeze(1), sal


def annotate_text(pil_img, text):
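    """Stamp a text label on a dark background box in the image's top-left corner."""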
    img = pil_img.convert("RGB").copy()
    # "RGBA" draw mode so the semi-transparent background box actually blends.
    draw = ImageDraw.Draw(img, "RGBA")
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except OSError:
        font = ImageFont.load_default()
    tw, th = draw.textbbox((0, 0), text, font=font)[2:]
    pad = 6
    box = [pad, pad, pad + tw + 2 * pad, pad + th + 2 * pad]
    draw.rectangle(box, fill=(0, 0, 0, 180))
    draw.text((pad * 2, pad * 2), text, fill=(255, 255, 255), font=font)
    return img


def draw_centroid(pil_img, cx, cy, r=8, color=(255, 0, 0), w=2):
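    """Mark (cx, cy) with a circle and crosshair."""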
    img = pil_img.convert("RGB").copy()
    d = ImageDraw.Draw(img)
    d.ellipse([cx - r, cy - r, cx + r, cy + r], outline=color, width=w)
    d.line([(cx - 2 * r, cy), (cx + 2 * r, cy)], fill=color, width=w)
    d.line([(cx, cy - 2 * r), (cx, cy + 2 * r)], fill=color, width=w)
    return img


@torch.no_grad()
def predict_prob_and_centroid(pil_img, device, t_kiln, model, head):
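    """Score one image: kiln probability, attention centroid (in resized-pixel coords), and saliency map."""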
    x = model_transform()(pil_img).unsqueeze(0).to(device)

    # Cosine similarity between the image embedding and the kiln text embedding.
    z = model.encode_image(x)
    z = z / (z.norm(dim=-1, keepdim=True) + 1e-8)
    s = torch.einsum("bd,d->b", z, t_kiln)
    logits = head(s)
    prob = torch.sigmoid(logits)[0].item()

    cx_r, cy_r, sal = vit_centroid_from_attention(model, x)

    # Return the resized width/height so callers can map back to the original image.
    return prob, float(cx_r[0]), float(cy_r[0]), x.shape[-1], x.shape[-2], sal


def main():
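    """CLI entry point: classify an image, save saliency visualizations, and write an annotated preview."""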
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=str, default="checkpoints/best.pt")
    ap.add_argument("--image", type=str, required=True)
    ap.add_argument("--out", type=str, default="outputs/preview_centroid.jpg")
    ap.add_argument("--threshold", type=float, default=0.5)
    ap.add_argument("--annotate_on_original", action="store_true")
    args = ap.parse_args()

    ckpt_path = Path(args.ckpt)
    img_path = Path(args.image)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    assert ckpt_path.exists(), f"Checkpoint not found: {ckpt_path}"
    assert img_path.exists(), f"Image not found: {img_path}"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    t_kiln = load_kiln_embed(to_device=device)
    t_kiln = t_kiln / (t_kiln.norm() + 1e-8)

    ckpt = load_checkpoint(ckpt_path)
    model, head = build_model_and_head(device, ckpt)

    # Single forward pass yields the probability, centroid, and saliency map.
    pil = Image.open(img_path).convert("RGB")
    prob, cx_r, cy_r, W_r, H_r, sal = predict_prob_and_centroid(pil, device, t_kiln, model, head)
    save_heatmap(sal, "outputs/heatmap.png")
    save_heatmap_overlay(pil, sal, "outputs/heatmap_overlay.png", alpha=0.45)
    pred = int(prob >= args.threshold)

    disp = display_transform()(pil)
    disp = annotate_text(disp, f"pred: {'KILN' if pred == 1 else 'NOT'} p={prob:.3f}")

    if pred == 1:
        if args.annotate_on_original:
            # Rescale the centroid from model-input coordinates to the original image.
            W0, H0 = pil.size
            sx = W0 / W_r
            sy = H0 / H_r
            cx0, cy0 = cx_r * sx, cy_r * sy
            out_img = draw_centroid(pil, cx0, cy0)
            print(f"centroid_px(original): ({cx0:.1f}, {cy0:.1f}) prob={prob:.3f}")
        else:
            out_img = draw_centroid(disp, cx_r, cy_r)
            print(f"centroid_px(resized {W_r}x{H_r}): ({cx_r:.1f}, {cy_r:.1f}) prob={prob:.3f}")
    else:
        out_img = disp
        print(f"no kiln (p={prob:.3f})")

    out_img.save(out_path)
    print(f"[saved] {out_path}")


if __name__ == "__main__":
    main()