# src/infer_with_centroid.py
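#
# Single-image inference for the kiln detector: scores an image against the learned
# kiln text embedding, and (when predicted positive) localizes it with an
# attention-rollout centroid, plus a saliency heatmap and overlay.
#
# Example invocation (paths are illustrative):
#   python src/infer_with_centroid.py --image path/to/tile.jpg \
#       --ckpt checkpoints/best.pt --out outputs/preview_centroid.jpg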
import argparse
from pathlib import Path
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np

from config import CFG
from prompt import load_kiln_embed
from model import build_model
from head import SimilarityHead

# -----------------------------
# Heatmap / saliency helpers
# -----------------------------

def save_heatmap(sal, out_path="outputs/heatmap.png"):
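    """Save the raw patch-grid saliency map (any tensor squeezeable to [H, W]) as a jet-colormapped PNG."""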
    sal_np = sal.squeeze().cpu().numpy()  # [H, W]
    plt.imshow(sal_np, cmap="jet", interpolation="nearest")
    plt.colorbar()
    plt.title("Attention Saliency Heatmap")
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()


def save_heatmap_overlay(pil_img, sal, out_path="outputs/heatmap_overlay.png", alpha=0.45):
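    """Upsample the [B, 1, h, w] saliency map to the image size, colorize it, and alpha-blend it over the original image."""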

    W0, H0 = pil_img.size
    # upsample saliency to image size
    sal_up = F.interpolate(sal, size=(H0, W0), mode="bilinear", align_corners=False)
    h = sal_up.squeeze().detach().cpu().numpy()  # [H0, W0]
    # normalize to 0..1
    h = (h - h.min()) / (h.max() - h.min() + 1e-6)
    # colormap -> RGBA [H0, W0, 4]
    cmap = plt.get_cmap("jet")
    h_rgba = (cmap(h) * 255).astype(np.uint8)
    heat = Image.fromarray(h_rgba).convert("RGBA")

    base = pil_img.convert("RGBA")
    overlay = Image.blend(base, heat, alpha)
    overlay.save(out_path)


# -----------------------------
# Transforms (match training)
# -----------------------------

def model_transform():
    return T.Compose([
        T.Resize((CFG.image_size, CFG.image_size)),
        T.ToTensor(),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=CFG.mean, std=CFG.std),
    ])

def display_transform():
    return T.Resize((CFG.image_size, CFG.image_size))

def load_checkpoint(ckpt_path: Path):
    return torch.load(ckpt_path, map_location="cpu")

def build_model_and_head(device, ckpt):
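    """Rebuild the CLIP backbone and similarity head from the hyperparameters stored in the checkpoint, then load their weights."""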
    args = ckpt.get("args", {})
    last_k = args.get("last_k", 2)
    model_name = args.get("model_name", "ViT-B-16")
    pretrained = args.get("pretrained", "openai")

    model = build_model(
        model_name=model_name,
        pretrained=pretrained,
        train_last_k_blocks=last_k,
        train_final_norm=True,
        finetune_logit_scale=False,
        device=device,
    )
    model.load_state_dict(ckpt["model_state"], strict=False)

    head = SimilarityHead(init_alpha=10.0, init_beta=0.0).to(device)
    head.load_state_dict(ckpt["head_state"], strict=False)

    model.eval(); head.eval()
    return model, head

def find_openclip_visual_tower(model: torch.nn.Module):
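    """Heuristically locate the ViT visual tower: the submodule owning transformer.resblocks plus CLS and positional embeddings."""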
    for m in model.modules():
        if hasattr(m, "transformer") and hasattr(m.transformer, "resblocks"):
            if hasattr(m, "class_embedding") and hasattr(m, "positional_embedding"):
                return m
    raise AttributeError("Could not find a ViT visual tower (transformer.resblocks) inside the model.")

def _patch_block_attention_capture(blk):
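    """Monkey-patch a residual attention block so each forward pass also stores its attention weights on blk.__last_attn__."""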
    if hasattr(blk, "_orig_attention"):
        return
    blk._orig_attention = blk.attention

    def attention_with_capture(q_x: torch.Tensor, k_x: torch.Tensor = None, v_x: torch.Tensor = None,
                               attn_mask: torch.Tensor = None, **kwargs):
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x
        # keep attn_mask dtype compatible with nn.MultiheadAttention
        attn_mask2 = attn_mask.to(q_x.dtype) if attn_mask is not None else None

        out, w = blk.attn(
            q_x, k_x, v_x,
            need_weights=True,
            average_attn_weights=False,
            attn_mask=attn_mask2
        )

        if w.dim() == 4:
            # already [B, heads, T, S]
            blk.__last_attn__ = w.detach()
        elif w.dim() == 3:
            # normalize to [B, T, S] regardless of which axis carries the batch
            if w.shape[0] == q_x.shape[0]:
                w_bts = w
            elif w.shape[1] == q_x.shape[0]:
                w_bts = w.permute(1, 0, 2)
            else:
                w_bts = w.unsqueeze(0)
            blk.__last_attn__ = w_bts.unsqueeze(1).detach()  # -> [B, 1, T, S], implicit single head
        else:
            blk.__last_attn__ = None

        return out

    blk.attention = attention_with_capture

def _enable_attn_capture(visual):
    for blk in visual.transformer.resblocks:
        _patch_block_attention_capture(blk)

@torch.no_grad()
def vit_centroid_from_attention(model, x):
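    """Attention-rollout saliency: run the image through the backbone with attention capture
    enabled, multiply the head-averaged (identity-mixed) attention maps across layers, and take
    the CLS-to-patch row as a [B, 1, H, W] saliency grid. Returns the grid's spatial expectation
    as a centroid in input-pixel coordinates, together with the grid itself."""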
    
    B = x.shape[0]
    visual = find_openclip_visual_tower(model)
    _enable_attn_capture(visual)
    _ = model.encode_image(x)  # forward pass; patched blocks stash their attention weights

    A = []
    for blk in visual.transformer.resblocks:
        w = getattr(blk, "__last_attn__", None)
        if w is None:
            continue
        if w.dim() == 4 and w.shape[0] in (1, B):
            a = w.mean(1)
        else:
            if w.dim() == 3 and w.shape[0] == B:
                a = w
            elif w.dim() == 3 and w.shape[1] == B:
                a = w.permute(1, 0, 2)
            else:
                continue 
        a = a / (a.sum(-1, keepdim=True) + 1e-6)  # re-normalize rows -> [B, T, S]
        Ttok = a.size(-1)
        # mix in the identity to account for the residual connection (attention-rollout style)
        I = torch.eye(Ttok, device=a.device).unsqueeze(0).expand(a.size(0), Ttok, Ttok)
        a = 0.1 * I + 0.9 * a
        A.append(a)

    if not A:
        raise RuntimeError("No attention captured. Backend likely changed shapes; print shapes inside attention_with_capture to inspect.")
    # compose per-layer attention into a joint (rollout) map
    R = A[0]
    for a in A[1:]:
        R = a @ R

    cls_to_tokens = R[:, 0, 1:]

    if hasattr(visual, "patch_size"):
        ph, pw = (visual.patch_size if isinstance(visual.patch_size, tuple) else (visual.patch_size, visual.patch_size))
    else:
        ph = pw = 16

    B, _, H_img, W_img = x.shape
    H = H_img // ph
    W = W_img // pw

    sal = cls_to_tokens.view(B, 1, H, W)
    sal = F.relu(sal)
    sal = sal / (sal.sum(dim=[2, 3], keepdim=True) + 1e-6)

    ys = torch.arange(H, device=sal.device).view(1, 1, H, 1).float()
    xs = torch.arange(W, device=sal.device).view(1, 1, 1, W).float()
    cy = (sal * ys).sum(dim=[2, 3]) 
    cx = (sal * xs).sum(dim=[2, 3])
    cx_px = (cx + 0.5) * pw
    cy_px = (cy + 0.5) * ph
    return cx_px.squeeze(1), cy_px.squeeze(1), sal

def annotate_text(pil_img, text):
    img = pil_img.convert("RGB").copy()
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except OSError:
        font = ImageFont.load_default()
    tw, th = draw.textbbox((0, 0), text, font=font)[2:]
    pad = 6
    box = [pad, pad, pad + tw + 2 * pad, pad + th + 2 * pad]
    draw.rectangle(box, fill=(0, 0, 0, 180))
    draw.text((pad * 2, pad * 2), text, fill=(255, 255, 255), font=font)
    return img

def draw_centroid(pil_img, cx, cy, r=8, color=(255, 0, 0), w=2):
    img = pil_img.convert("RGB").copy()
    d = ImageDraw.Draw(img)
    d.ellipse([cx - r, cy - r, cx + r, cy + r], outline=color, width=w)
    d.line([(cx - 2 * r, cy), (cx + 2 * r, cy)], fill=color, width=w)
    d.line([(cx, cy - 2 * r), (cx, cy + 2 * r)], fill=color, width=w)
    return img

@torch.no_grad()
def predict_prob_and_centroid(pil_img, device, t_kiln, model, head):
    """Return the kiln probability plus the attention-rollout centroid (in resized-input pixels) and saliency grid."""
    x = model_transform()(pil_img).unsqueeze(0).to(device)
    # classification probability against the (already normalized) kiln text embedding
    z = model.encode_image(x)
    z = z / (z.norm(dim=-1, keepdim=True) + 1e-8)
    s = torch.einsum("bd,d->b", z, t_kiln)
    logits = head(s)
    prob = torch.sigmoid(logits)[0].item()

    # attention-rollout centroid and saliency, computed on the same preprocessed tensor
    cx_r, cy_r, sal = vit_centroid_from_attention(model, x)

    return prob, float(cx_r[0]), float(cy_r[0]), x.shape[-1], x.shape[-2], sal


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=str, default="checkpoints/best.pt")
    ap.add_argument("--image", type=str, required=True)
    ap.add_argument("--out", type=str, default="outputs/preview_centroid.jpg")
    ap.add_argument("--threshold", type=float, default=0.5)
    ap.add_argument("--annotate_on_original", action="store_true")
    args = ap.parse_args()

    ckpt_path = Path(args.ckpt)
    img_path = Path(args.image)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    assert ckpt_path.exists(), f"Checkpoint not found: {ckpt_path}"
    assert img_path.exists(), f"Image not found: {img_path}"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    t_kiln = load_kiln_embed(to_device=device)
    t_kiln = t_kiln / (t_kiln.norm() + 1e-8)

    ckpt = load_checkpoint(ckpt_path)
    model, head = build_model_and_head(device, ckpt)

    pil = Image.open(img_path).convert("RGB")
    prob, cx_r, cy_r, W_r, H_r, sal = predict_prob_and_centroid(pil, device, t_kiln, model, head)
    save_heatmap(sal, "outputs/heatmap.png")
    save_heatmap_overlay(pil, sal, "outputs/heatmap_overlay.png", alpha=0.45)
    pred = int(prob >= args.threshold)

    disp = display_transform()(pil)
    disp = annotate_text(disp, f"pred: {'KILN' if pred==1 else 'NOT'}  p={prob:.3f}")

    if pred == 1:
        if args.annotate_on_original:
            W0, H0 = pil.size
            sx = W0 / W_r
            sy = H0 / H_r
            cx0, cy0 = cx_r * sx, cy_r * sy
            out_img = draw_centroid(pil, cx0, cy0)
            print(f"centroid_px(original): ({cx0:.1f}, {cy0:.1f})  prob={prob:.3f}")
        else:
            out_img = draw_centroid(disp, cx_r, cy_r)
            print(f"centroid_px(resized {W_r}x{H_r}): ({cx_r:.1f}, {cy_r:.1f})  prob={prob:.3f}")
    else:
        out_img = disp
        print(f"no kiln (p={prob:.3f})")

    out_img.save(out_path)
    print(f"[saved] {out_path}")

if __name__ == "__main__":
    main()