nkaenzig commited on
Commit
f39f378
·
verified ·
1 Parent(s): fdc9fd2

refactored classification embedding logic

Browse files
Files changed (1) hide show
  1. README.md +6 -10
README.md CHANGED
@@ -31,18 +31,14 @@ model = AutoModel.from_pretrained('kaiko-ai/midnight')
31
  ### Extract embeddings for classification
32
  ```python
33
  import torch
34
- from transformers import modeling_outputs
35
- from typing_extensions import override
36
 
37
- # for classification
38
- class ExtractConcatToken:
39
- """Extracts the CLS with Mean Patch tokens from a model output."""
40
- def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
41
- return torch.cat([tensor[:, 0, :], tensor[:, 1:, :].mean(1)], dim=-1)
42
 
43
- extract_embeddings = ExtractConcatToken()
44
- emb = extract_embeddings(model(transform(image)[None]).last_hidden_state)
45
- print(f"Embeddings shape: {emb.shape}")
46
  ```
47
 
48
 
 
31
  ### Extract embeddings for classification
32
  ```python
33
  import torch
 
 
34
 
35
+ def extract_embedding(tensor):
36
+ cls_embedding, patch_embeddings = tensor[:, 0, :], tensor[:, 1:, :]
37
+ return torch.cat(cls_embedding, patch_embeddings.mean(1), dim=-1)
 
 
38
 
39
+ image_batch = transform(image).unsqueeze(dim=0)
40
+ embedding = extract_embedding(model(image_batch).last_hidden_state)
41
+ print(f"Embeddings shape: {embedding.shape}")
42
  ```
43
 
44