# src/prompts.py
from dataclasses import dataclass

import numpy as np
import torch
import open_clip

from config import CFG


@dataclass
class PromptConfig:
    model_name: str = "ViT-B-16"  # keep this in sync with your vision backbone later
    pretrained: str = "openai"
    # Positive-only prompt variants (edit/extend as you wish)
    kiln_prompts: tuple[str, ...] = (
        "brick kiln, aerial view",
        "zigzag brick kiln, satellite image",
        "bull's trench brick kiln, aerial imagery",
        "fixed chimney brick kiln, aerial photo",
        "industrial brick kiln site with drying yards, satellite view",
    )
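
# Averaging the text embeddings of several phrasings is a light-weight prompt
# ensemble, in the spirit of CLIP's zero-shot classifiers; the mean is
# re-normalized in compute_and_save_kiln_embed so the anchor stays unit-norm.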

PROMPT_CFG = PromptConfig()

EMBED_DIR = CFG.outputs_dir / "embeds"
EMBED_DIR.mkdir(parents=True, exist_ok=True)
EMBED_PATH = EMBED_DIR / "kiln_embed.npz"


def compute_and_save_kiln_embed(device: str | None = None):
    """
    Encodes multiple kiln prompts with CLIP's text encoder, averages them into one
    unit-norm vector t_kiln, and saves the result to an NPZ file.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model just for its text tower; keep everything frozen.
    model, _, _ = open_clip.create_model_and_transforms(
        PROMPT_CFG.model_name, pretrained=PROMPT_CFG.pretrained, device=device
    )
    tokenizer = open_clip.get_tokenizer(PROMPT_CFG.model_name)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False

    with torch.no_grad():
        tokens = tokenizer(list(PROMPT_CFG.kiln_prompts)).to(device)
        t_list = model.encode_text(tokens)                   # [K, D]
        t_list = t_list / t_list.norm(dim=-1, keepdim=True)  # unit vectors
        t_kiln = t_list.mean(dim=0, keepdim=True)            # [1, D]
        t_kiln = t_kiln / t_kiln.norm(dim=-1, keepdim=True)  # re-normalize the mean

    # Save the anchor along with the prompts and model metadata for provenance.
    np.savez_compressed(
        EMBED_PATH,
        t_kiln=t_kiln.float().cpu().numpy(),  # [1, D]
        kiln_prompts=np.array(PROMPT_CFG.kiln_prompts, dtype=object),
        model_name=np.array(PROMPT_CFG.model_name, dtype=object),
        pretrained=np.array(PROMPT_CFG.pretrained, dtype=object),
    )
    print(f"[prompts] Saved kiln anchor → {EMBED_PATH}")


def load_kiln_embed(to_device: str | None = None) -> torch.Tensor:
    """
    Loads the kiln anchor t_kiln as a torch tensor of shape [D], unit-norm.
    """
    data = np.load(EMBED_PATH, allow_pickle=True)
    arr = data["t_kiln"]                          # [1, D]
    t = torch.from_numpy(arr).float().squeeze(0)  # [D]
    if to_device is None:
        to_device = "cuda" if torch.cuda.is_available() else "cpu"
    return t.to(to_device)
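

# A minimal downstream sketch: score a hypothetical [N, D] batch of CLIP image
# embeddings (from the matching ViT-B-16 image tower) against the kiln anchor.
# With both sides unit-norm, a matrix product gives cosine similarities that can
# be ranked or thresholded to surface kiln candidates.
def score_image_embeddings(image_features: torch.Tensor) -> torch.Tensor:
    t_kiln = load_kiln_embed(to_device=str(image_features.device))      # [D]
    feats = image_features / image_features.norm(dim=-1, keepdim=True)  # [N, D] unit rows
    return feats @ t_kiln                                               # [N] cosine similarity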


if __name__ == "__main__":
    compute_and_save_kiln_embed()
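
# Example usage (run alongside config.py, per the import above):
#   $ python prompts.py        # writes kiln_embed.npz under CFG.outputs_dir / "embeds"
#   >>> t = load_kiln_embed("cpu")
#   >>> t.shape                # torch.Size([D]); D = 512 for ViT-B-16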