cleanup
Browse files
README.md
CHANGED
@@ -34,10 +34,10 @@ import torch
|
|
34 |
|
35 |
def extract_embedding(tensor):
|
36 |
cls_embedding, patch_embeddings = tensor[:, 0, :], tensor[:, 1:, :]
|
37 |
-
return torch.cat(cls_embedding, patch_embeddings.mean(1), dim=-1)
|
38 |
|
39 |
-
|
40 |
-
embedding = extract_embedding(model(
|
41 |
print(f"Embeddings shape: {embedding.shape}")
|
42 |
```
|
43 |
|
@@ -46,26 +46,13 @@ print(f"Embeddings shape: {embedding.shape}")
|
|
46 |
```python
|
47 |
import math
|
48 |
import torch
|
49 |
-
from transformers import modeling_outputs
|
50 |
-
from typing_extensions import override
|
51 |
|
52 |
# for segmentation
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
Args:
|
59 |
-
tensor: The raw embeddings of the model.
|
60 |
-
|
61 |
-
Returns:
|
62 |
-
A tensor (batch_size, hidden_size, n_patches_height, n_patches_width)
|
63 |
-
representing the model output.
|
64 |
-
"""
|
65 |
-
features = tensor[:, 1:, :].permute(0, 2, 1)
|
66 |
-
batch_size, hidden_size, patch_grid = features.shape
|
67 |
-
height = width = int(math.sqrt(patch_grid))
|
68 |
-
return features.view(batch_size, hidden_size, height, width)
|
69 |
|
70 |
extract_embeddings = ExtractPatchFeatures()
|
71 |
emb = extract_embeddings(model(transform(image)[None]).last_hidden_state)
|
|
|
34 |
|
35 |
def extract_embedding(tensor):
|
36 |
cls_embedding, patch_embeddings = tensor[:, 0, :], tensor[:, 1:, :]
|
37 |
+
return torch.cat([cls_embedding, patch_embeddings.mean(1)], dim=-1)
|
38 |
|
39 |
+
batch = transform(image).unsqueeze(dim=0)
|
40 |
+
embedding = extract_embedding(model(batch).last_hidden_state)
|
41 |
print(f"Embeddings shape: {embedding.shape}")
|
42 |
```
|
43 |
|
|
|
46 |
```python
|
47 |
import math
|
48 |
import torch
|
|
|
|
|
49 |
|
50 |
# for segmentation
|
51 |
+
def extract_patch_embeddings(tensor):
|
52 |
+
features = tensor[:, 1:, :].permute(0, 2, 1)
|
53 |
+
batch_size, hidden_size, patch_grid = features.shape
|
54 |
+
height = width = int(math.sqrt(patch_grid))
|
55 |
+
return features.view(batch_size, hidden_size, height, width)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
extract_embeddings = ExtractPatchFeatures()
|
58 |
emb = extract_embeddings(model(transform(image)[None]).last_hidden_state)
|