Update README.md
Browse files
README.md
CHANGED
@@ -68,27 +68,25 @@ This will install an editable version of repo, allowing you to make changes to t
|
|
68 |
## Image and Text Feature extraction with a Trained Model
|
69 |
```python
|
70 |
import torch
|
71 |
-
from core.vision_encoder.factory import create_model_and_transforms, get_tokenizer
|
72 |
from PIL import Image
|
|
|
|
|
73 |
|
74 |
-
|
75 |
-
|
76 |
|
77 |
-
model
|
78 |
-
model_name,
|
79 |
-
pretrained=pretrained,
|
80 |
-
)
|
81 |
model = model.cuda()
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
84 |
text = tokenizer(["a diagram", "a dog", "a cat"]).cuda()
|
85 |
|
86 |
with torch.no_grad(), torch.autocast("cuda"):
|
87 |
-
image_features = model
|
88 |
-
|
89 |
-
image_features /= image_features.norm(dim=-1, keepdim=True)
|
90 |
-
text_features /= text_features.norm(dim=-1, keepdim=True)
|
91 |
-
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
|
92 |
|
93 |
print("Label probs:", text_probs) # prints: [[0.0, 0.0, 1.0]]
|
94 |
```
|
|
|
68 |
## Image and Text Feature extraction with a Trained Model
|
69 |
```python
|
70 |
import torch
|
|
|
71 |
from PIL import Image
|
72 |
+
import core.vision_encoder.pe as pe
|
73 |
+
import core.vision_encoder.transforms as transforms
|
74 |
|
75 |
+
print("CLIP configs:", pe.CLIP.available_configs())
|
76 |
+
# CLIP configs: ['PE-Core-G14-448', 'PE-Core-L14-336', 'PE-Core-B16-224']
|
77 |
|
78 |
+
model = pe.CLIP.from_config("PE-Core-B16-224", pretrained=True) # Downloads from HF
|
|
|
|
|
|
|
79 |
model = model.cuda()
|
80 |
+
|
81 |
+
preprocess = transforms.get_image_transform(model.image_size)
|
82 |
+
tokenizer = transforms.get_text_tokenizer(model.context_length)
|
83 |
+
|
84 |
+
image = preprocess(Image.open("docs/assets/cat.png")).unsqueeze(0).cuda()
|
85 |
text = tokenizer(["a diagram", "a dog", "a cat"]).cuda()
|
86 |
|
87 |
with torch.no_grad(), torch.autocast("cuda"):
|
88 |
+
image_features, text_features, logit_scale = model(image, text)
|
89 |
+
text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)
|
|
|
|
|
|
|
90 |
|
91 |
print("Label probs:", text_probs) # prints: [[0.0, 0.0, 1.0]]
|
92 |
```
|