|
|
|
from pathlib import Path |
|
from dataclasses import dataclass |
|
import numpy as np |
|
import torch |
|
import open_clip |
|
|
|
from config import CFG |
|
|
|
@dataclass
class PromptConfig:
    """Settings for the CLIP text model and the prompt ensemble describing kilns.

    Attributes are plain defaults; instantiate with keyword overrides to try a
    different model or a different set of prompts.
    """

    # Model architecture name and pretrained-weights tag.
    model_name: str = "ViT-B-16"
    pretrained: str = "openai"

    # Text descriptions of brick kilns as seen from above. Stored as a tuple
    # (immutable), so it is safe to use directly as a dataclass default.
    kiln_prompts: tuple[str, ...] = (
        "brick kiln, aerial view",
        "zigzag brick kiln, satellite image",
        "bull's trench brick kiln, aerial imagery",
        "fixed chimney brick kiln, aerial photo",
        "industrial brick kiln site with drying yards, satellite view",
    )
|
|
|
# Module-wide prompt settings (all defaults).
PROMPT_CFG = PromptConfig()

# Location of the saved kiln anchor. The directory is created (if missing) as
# soon as this module is imported.
EMBED_DIR = CFG.outputs_dir / "embeds"
EMBED_DIR.mkdir(parents=True, exist_ok=True)
EMBED_PATH = EMBED_DIR / "kiln_embed.npz"
|
|
|
def compute_and_save_kiln_embed(device: str | None = None):
    """Encode the kiln prompts with CLIP's text encoder and save the anchor.

    Each prompt in ``PROMPT_CFG.kiln_prompts`` is encoded, L2-normalised, and
    the normalised embeddings are averaged. The mean is normalised once more,
    giving a single unit-norm vector ``t_kiln`` that keeps a leading axis of
    size 1. The vector is written to ``EMBED_PATH`` as a compressed NPZ
    together with the prompts, model name and pretrained tag that produced it.

    The metadata is stored as native unicode arrays rather than object
    arrays, so the file can be read back without ``allow_pickle=True``.

    Args:
        device: Torch device string. When ``None``, ``"cuda"`` is used if
            available, otherwise ``"cpu"``.

    Raises:
        ValueError: If ``PROMPT_CFG.kiln_prompts`` is empty (the mean of zero
            embeddings would be NaN).
    """
    prompts = list(PROMPT_CFG.kiln_prompts)
    if not prompts:
        raise ValueError("PROMPT_CFG.kiln_prompts is empty; cannot build a kiln anchor")

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    model, _, _ = open_clip.create_model_and_transforms(
        PROMPT_CFG.model_name, pretrained=PROMPT_CFG.pretrained, device=device
    )
    tokenizer = open_clip.get_tokenizer(PROMPT_CFG.model_name)

    # Inference only: switch to eval mode and freeze all weights.
    model.eval()
    for p in model.parameters():
        p.requires_grad = False

    with torch.no_grad():
        tokens = tokenizer(prompts).to(device)
        t_list = model.encode_text(tokens)
        # Normalise each prompt embedding so every prompt contributes equally
        # to the average, then re-normalise the mean to unit length.
        t_list = t_list / t_list.norm(dim=-1, keepdim=True)
        t_kiln = t_list.mean(dim=0, keepdim=True)
        t_kiln = t_kiln / t_kiln.norm(dim=-1, keepdim=True)

    np.savez_compressed(
        EMBED_PATH,
        t_kiln=t_kiln.float().cpu().numpy(),
        # Plain unicode arrays: no pickling on save, no allow_pickle on load.
        kiln_prompts=np.array(prompts, dtype=str),
        model_name=np.array(PROMPT_CFG.model_name, dtype=str),
        pretrained=np.array(PROMPT_CFG.pretrained, dtype=str),
    )
    print(f"[prompts] Saved kiln anchor → {EMBED_PATH}")
|
|
|
def load_kiln_embed(to_device: str | None = None) -> torch.Tensor: |
|
""" |
|
Loads the kiln anchor t_kiln as a torch tensor of shape [D], unit-norm. |
|
""" |
|
data = np.load(EMBED_PATH, allow_pickle=True) |
|
arr = data["t_kiln"] |
|
t = torch.from_numpy(arr).float().squeeze(0) |
|
if to_device is None: |
|
to_device = "cuda" if torch.cuda.is_available() else "cpu" |
|
return t.to(to_device) |
|
|
|
# Script entry point: running this file directly regenerates the kiln anchor.
if __name__ == "__main__":
    compute_and_save_kiln_embed()
|
|