karasikov committed
Commit 1494ba0 · 1 Parent(s): 32bd926
Files changed (1)
  1. README.md +70 -6
README.md CHANGED
@@ -8,14 +8,78 @@ tags:
  ### Usage

  ```python
- import timm
- from timm.data import resolve_data_config
- from timm.data.transforms_factory import create_transform
+ from transformers import AutoImageProcessor, AutoModel
+ from PIL import Image
+ import requests
  from huggingface_hub import login
+ from torchvision.transforms import v2

  login() # login or use an access token

- model = timm.create_model("hf-hub:kaiko-ai/midnight", pretrained=True)
- transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
- model.eval()
+ # FYI: a natural image is used here instead of a WSI crop, for simplicity
+ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
+ image = Image.open(requests.get(url, stream=True).raw)
+
+ transform = v2.Compose(
+     [
+         v2.Resize(224),
+         v2.CenterCrop(224),
+         v2.ToTensor(),
+         v2.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+     ]
+ )
+ model = AutoModel.from_pretrained('kaiko-ai/midnight')
+ ```
+
+
+ ### Extract embeddings for classification
+ ```python
+ import torch
+ from transformers import modeling_outputs
+ from typing_extensions import override
+
+ # for classification
+ class ExtractConcatToken:
+     """Extracts the CLS with Mean Patch tokens from a model output."""
+     def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+         num_reg_tokens = 0
+         return torch.cat([tensor[:, 0, :], tensor[:, 1 + num_reg_tokens :, :].mean(1)], dim=-1)
+
+ extract_embeddings = ExtractConcatToken()
+ emb = extract_embeddings(model(transform(image)[None]).last_hidden_state)
+ print(f"Embeddings shape: {emb.shape}")
+ ```
+
+
+ ### Extract embeddings for segmentation
+ ```python
+ import math
+ import torch
+ from transformers import modeling_outputs
+ from typing_extensions import override
+
+ # for segmentation
+ class ExtractPatchFeatures:
+     """Extracts the patch features from a model output."""
+     def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+         """Call method for the transformation.
+
+         Args:
+             tensor: The raw embeddings of the model.
+
+         Returns:
+             A tensor (batch_size, hidden_size, n_patches_height, n_patches_width)
+             representing the model output.
+         """
+         num_reg_tokens = 0
+         num_skip = 1 + num_reg_tokens
+         features = tensor[:, num_skip:, :].permute(0, 2, 1)
+         batch_size, hidden_size, patch_grid = features.shape
+         height = width = int(math.sqrt(patch_grid))
+         assert height * width == patch_grid
+         return features.view(batch_size, hidden_size, height, width)
+
+ extract_embeddings = ExtractPatchFeatures()
+ emb = extract_embeddings(model(transform(image)[None]).last_hidden_state)
+ print(f"Embeddings shape for segmentation: {emb[0].shape}")
  ```
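
For reference, the two helper classes in the new README differ only in how they pool the ViT token sequence: `ExtractConcatToken` concatenates the CLS token with the mean of the patch tokens, doubling the embedding width, while `ExtractPatchFeatures` drops the CLS token and reshapes the patch tokens into a spatial grid. Below is a minimal shape check against a dummy tensor, so it runs without login or weight download; it assumes the two classes above have already been defined, and the concrete sizes (hidden size 1536, a 16x16 patch grid) are illustrative assumptions rather than values taken from the model card.

```python
import torch

# Dummy ViT-style output of shape (batch, 1 + num_patches, hidden_size).
# hidden_size=1536 and the 16x16 patch grid are illustrative assumptions.
batch_size, hidden_size = 2, 1536
grid_h = grid_w = 16
dummy = torch.randn(batch_size, 1 + grid_h * grid_w, hidden_size)

cls_plus_mean = ExtractConcatToken()(dummy)
print(cls_plus_mean.shape)  # torch.Size([2, 3072]): CLS concatenated with the mean patch token

patch_features = ExtractPatchFeatures()(dummy)
print(patch_features.shape)  # torch.Size([2, 1536, 16, 16]): patch tokens as a spatial grid
```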
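
The updated snippet imports `AutoImageProcessor` but builds the preprocessing pipeline by hand with torchvision. If the repository also ships an image-processor config, the hub processor could stand in for the manual transform; the sketch below is only a hypothetical alternative under that assumption, not something the diff itself shows, and it reuses `image`, `model`, and `ExtractConcatToken` from the snippets above.

```python
from transformers import AutoImageProcessor

# Hypothetical alternative preprocessing path, assuming kaiko-ai/midnight
# provides an image-processor config on the Hub.
processor = AutoImageProcessor.from_pretrained('kaiko-ai/midnight')
inputs = processor(images=image, return_tensors='pt')  # dict containing 'pixel_values'
emb = ExtractConcatToken()(model(**inputs).last_hidden_state)
```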