refactored classification embedding logic
Browse files
README.md
CHANGED
@@ -31,18 +31,14 @@ model = AutoModel.from_pretrained('kaiko-ai/midnight')
|
|
31 |
### Extract embeddings for classification
|
32 |
```python
|
33 |
import torch
|
34 |
-
from transformers import modeling_outputs
|
35 |
-
from typing_extensions import override
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
|
41 |
-
return torch.cat([tensor[:, 0, :], tensor[:, 1:, :].mean(1)], dim=-1)
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
print(f"Embeddings shape: {
|
46 |
```
|
47 |
|
48 |
|
|
|
31 |
### Extract embeddings for classification
|
32 |
```python
|
33 |
import torch
|
|
|
|
|
34 |
|
35 |
+
def extract_embedding(tensor):
|
36 |
+
cls_embedding, patch_embeddings = tensor[:, 0, :], tensor[:, 1:, :]
|
37 |
+
return torch.cat(cls_embedding, patch_embeddings.mean(1), dim=-1)
|
|
|
|
|
38 |
|
39 |
+
image_batch = transform(image).unsqueeze(dim=0)
|
40 |
+
embedding = extract_embedding(model(image_batch).last_hidden_state)
|
41 |
+
print(f"Embeddings shape: {embedding.shape}")
|
42 |
```
|
43 |
|
44 |
|