zer0int committed (verified) · Commit dfad2ea · 1 Parent(s): 8afdb75

Unleash KO-CLIP

Files changed (1): README.md (+139, −3)
---
license: mit
datasets:
- SPRIGHT-T2I/spright_coco
- zer0int/CLIP-KO-Adversarial-Train-Typo-Attack
base_model:
- openai/clip-vit-base-patch32
---
# CLIP-KO: Knocking Out Typographic Attacks in CLIP 💪🤖
### Less vulnerability, much better performance! 🤗
❤️ this CLIP? [Donate](https://ko-fi.com/zer0int) if you can / want. TY!

# 🔥 CLIP-KO ViT-B/32 (vit-base-patch32)
- 📝 Read the [paper](https://github.com/zer0int/CLIP-fine-tune/blob/CLIP-vision/KO-CLIP-teaser/KO-CLIP-paper-final.pdf) (PDF) here.
- 🤓 Wanna fine-tune yourself? Get the [code](https://github.com/zer0int/CLIP-fine-tune) on my GitHub.

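**Quickstart.** A minimal zero-shot sketch for loading this model with 🤗 Transformers (it assumes the Hub repo id used in the benchmark code below; the image path and candidate labels are hypothetical):

```python
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Load the fine-tuned model and its processor from the Hub
model = CLIPModel.from_pretrained("zer0int/CLIP-KO-ViT-B-32-TypoAttack")
processor = CLIPProcessor.from_pretrained("zer0int/CLIP-KO-ViT-B-32-TypoAttack")

image = Image.open("cat_with_dog_label.jpg")  # hypothetical typographic-attack image
texts = ["a photo of a cat", "a photo of a dog"]

inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image: (1, num_texts); softmax over the candidate captions
probs = outputs.logits_per_image.softmax(dim=-1)[0]
print(dict(zip(texts, probs.tolist())))
```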
----
<details>
<summary>👉 CLICK ME to expand example benchmark code ⚡💻</summary>

```python
from datasets import load_dataset
from transformers import CLIPModel, CLIPProcessor
import torch
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

# BLISS / SCAM Typographic Attack Dataset
# https://huggingface.co/datasets/BLISS-e-V/SCAM
ds = load_dataset("BLISS-e-V/SCAM", split="train")

# Benchmark the pre-trained model against my fine-tune
model_variants = [
    ("OpenAI ", "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"),
    ("KO-CLIP", "zer0int/CLIP-KO-ViT-B-32-TypoAttack", "zer0int/CLIP-KO-ViT-B-32-TypoAttack"),
]

models = {}
for name, model_path, processor_path in model_variants:
    model = CLIPModel.from_pretrained(model_path).to(device).float()
    processor = CLIPProcessor.from_pretrained(processor_path)
    models[name] = (model, processor)

for variant in ["NoSCAM", "SCAM", "SynthSCAM"]:
    print(f"\n=== Evaluating var.: {variant} ===")
    idxs = [i for i, v in enumerate(ds['id']) if v.startswith(variant)]
    if not idxs:
        print(f"  No samples for {variant}")
        continue
    subset = [ds[i] for i in idxs]

    for model_name, (model, processor) in models.items():
        results = []
        for entry in tqdm(subset, desc=f"{model_name}", ncols=30, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} |"):
            img = entry['image']
            object_label = entry['object_label']
            attack_word = entry['attack_word']

            # Binary zero-shot choice: true object label vs. the attack word in the image
            texts = [f"a photo of a {object_label}", f"a photo of a {attack_word}"]
            inputs = processor(
                text=texts,
                images=img,
                return_tensors="pt",
                padding=True
            )
            for k in inputs:
                if isinstance(inputs[k], torch.Tensor):
                    inputs[k] = inputs[k].to(device)

            with torch.no_grad():
                outputs = model(**inputs)
                image_features = outputs.image_embeds
                text_features = outputs.text_embeds

            # Embeds are L2-normalized, so the dot product is cosine similarity
            logits = image_features @ text_features.T
            probs = logits.softmax(dim=-1).cpu().numpy().flatten()
            pred_idx = probs.argmax()
            pred_label = [object_label, attack_word][pred_idx]
            is_correct = (pred_label == object_label)

            results.append({
                "id": entry['id'],
                "object_label": object_label,
                "attack_word": attack_word,
                "pred_label": pred_label,
                "is_correct": is_correct,
                "type": entry['type'],
                "model": model_name
            })

        n_total = len(results)
        n_correct = sum(r['is_correct'] for r in results)
        acc = n_correct / n_total if n_total else float('nan')
        print(f"| > > > > Zero-shot accuracy for {variant}, {model_name}: {n_correct}/{n_total} = {acc:.4f}")
```
</details>

----
Better attention heatmaps!

![image/png](https://cdn-uploads.huggingface.co/production/uploads/6490359a877fc29cb1b09451/VW0siiegXZeb_Ox5dQTxY.png)
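How such heatmaps are produced varies; below is a minimal, hypothetical sketch (not the exact script behind the figure above) that averages the last-layer CLS→patch attention of the vision tower and upsamples it to the input resolution:

```python
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("zer0int/CLIP-KO-ViT-B-32-TypoAttack")
processor = CLIPProcessor.from_pretrained("zer0int/CLIP-KO-ViT-B-32-TypoAttack")

image = Image.open("example.jpg")  # hypothetical input image
pixel_values = processor(images=image, return_tensors="pt").pixel_values

with torch.no_grad():
    out = model.vision_model(pixel_values=pixel_values, output_attentions=True)

attn = out.attentions[-1]                      # (batch, heads, tokens, tokens)
cls_to_patch = attn[0, :, 0, 1:].mean(dim=0)   # head-averaged CLS->patch attention
grid = int(cls_to_patch.numel() ** 0.5)        # 7x7 patch grid for ViT-B/32 @ 224px
heatmap = cls_to_patch.reshape(1, 1, grid, grid)
heatmap = F.interpolate(heatmap, size=(224, 224), mode="bilinear", align_corners=False)
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)  # scale to [0, 1]
```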

----

## 📊 Benchmark Results 🚀

| Benchmark / Metric | Pre-Trained | Fine-Tuned |
|------------------------------------|-------------|------------|
| **Typographic Attack** | | |
| RTA-100 zero-shot acc. | 0.5560 | **0.7740** 🎖️ |
| BLISS / SCAM NoSCAM acc. | 0.9682 | **0.9759** |
| BLISS / SCAM SCAM acc. | 0.6627 | **0.7926** 🎖️ |
| BLISS / SCAM SynthSCAM acc. | 0.4320 | **0.6386** 🎖️ |
| **LAION/CLIP_Benchmark** | | |
| VOC-2007 multilabel mAP | 0.7231 | **0.8335** 🎖️ |
| MSCOCO retrieval image recall@5 | 0.1724 | **0.2523** |
| MSCOCO retrieval text recall@5 | 0.2440 | **0.3569** |
| xm3600 retrieval image recall@5 | 0.2867 | **0.3874** |
| xm3600 retrieval text recall@5 | 0.2523 | **0.3783** |
| **ImageNet-1k** | | |
| zero-shot acc1 | 0.2234 | **0.3193** |
| zero-shot acc5 | 0.4169 | **0.5555** |
| mAP | 0.2230 | **0.3185** |
| **MISC** | | |
| ImageNet-1k linear probe Top-1 | **53.14%** | 52.65% |
| ImageNet-1k linear probe Top-5 | 83.41% | **83.48%** |
| MVT ImageNet/ObjectNet acc. | 0.6492 | **0.7506** 🎖️ |
| Flickr8k Modality Gap ↓ | 0.8301 | **0.7902** |
| Flickr8k JSD ↓ | 0.5225 | **0.2983** |
| Flickr8k Wasserstein Dist. ↓ | 0.4573 | **0.4039** |
| Flickr8k Img-Text Cos Sim (mean) ↑ | 0.3164 | **0.3522** |
| Flickr8k Img-Text Cos Sim (std) | 0.0325 | 0.0537 |
| Flickr8k Text-Text Cos Sim (mean) | 0.7737 | 0.7561 |
| Flickr8k Text-Text Cos Sim (std) | 0.1036 | 0.1300 |

↓ = lower is better; ↑ = higher is better.
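For the Flickr8k embedding-space diagnostics, here is a minimal sketch of the modality-gap metric under one common definition (Euclidean distance between the centroids of L2-normalized image and text embeddings, as in Liang et al., "Mind the Gap", NeurIPS 2022); the exact script behind the numbers above may differ:

```python
import torch
import torch.nn.functional as F

def modality_gap(image_embeds: torch.Tensor, text_embeds: torch.Tensor) -> float:
    """Distance between the modality centroids of L2-normalized embeddings."""
    img = F.normalize(image_embeds, dim=-1).mean(dim=0)
    txt = F.normalize(text_embeds, dim=-1).mean(dim=0)
    return (img - txt).norm().item()
```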