kaiko-ai-user commited on
Commit
4ce0a43
·
verified ·
1 Parent(s): f39f378
Files changed (1) hide show
  1. README.md +8 -21
README.md CHANGED
@@ -34,10 +34,10 @@ import torch
34
 
35
  def extract_embedding(tensor):
36
  cls_embedding, patch_embeddings = tensor[:, 0, :], tensor[:, 1:, :]
37
- return torch.cat(cls_embedding, patch_embeddings.mean(1), dim=-1)
38
 
39
- image_batch = transform(image).unsqueeze(dim=0)
40
- embedding = extract_embedding(model(image_batch).last_hidden_state)
41
  print(f"Embeddings shape: {embedding.shape}")
42
  ```
43
 
@@ -46,26 +46,13 @@ print(f"Embeddings shape: {embedding.shape}")
46
  ```python
47
  import math
48
  import torch
49
- from transformers import modeling_outputs
50
- from typing_extensions import override
51
 
52
  # for segmentation
53
- class ExtractPatchFeatures:
54
- """Extracts the patch features from a model output."""
55
- def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
56
- """Call method for the transformation.
57
-
58
- Args:
59
- tensor: The raw embeddings of the model.
60
-
61
- Returns:
62
- A tensor (batch_size, hidden_size, n_patches_height, n_patches_width)
63
- representing the model output.
64
- """
65
- features = tensor[:, 1:, :].permute(0, 2, 1)
66
- batch_size, hidden_size, patch_grid = features.shape
67
- height = width = int(math.sqrt(patch_grid))
68
- return features.view(batch_size, hidden_size, height, width)
69
 
70
  extract_embeddings = ExtractPatchFeatures()
71
  emb = extract_embeddings(model(transform(image)[None]).last_hidden_state)
 
34
 
35
  def extract_embedding(tensor):
36
  cls_embedding, patch_embeddings = tensor[:, 0, :], tensor[:, 1:, :]
37
+ return torch.cat([cls_embedding, patch_embeddings.mean(1)], dim=-1)
38
 
39
+ batch = transform(image).unsqueeze(dim=0)
40
+ embedding = extract_embedding(model(batch).last_hidden_state)
41
  print(f"Embeddings shape: {embedding.shape}")
42
  ```
43
 
 
46
  ```python
47
  import math
48
  import torch
 
 
49
 
50
  # for segmentation
51
+ def extract_patch_embeddings(tensor):
52
+ features = tensor[:, 1:, :].permute(0, 2, 1)
53
+ batch_size, hidden_size, patch_grid = features.shape
54
+ height = width = int(math.sqrt(patch_grid))
55
+ return features.view(batch_size, hidden_size, height, width)
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  extract_embeddings = ExtractPatchFeatures()
58
  emb = extract_embeddings(model(transform(image)[None]).last_hidden_state)