update
Browse files
README.md
CHANGED
@@ -8,14 +8,78 @@ tags:
|
|
8 |
### Usage
|
9 |
|
10 |
```python
|
11 |
-
import
|
12 |
-
from
|
13 |
-
|
14 |
from huggingface_hub import login
|
|
|
15 |
|
16 |
login() # login or use an access token
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
```
|
|
|
8 |
### Usage
|
9 |
|
10 |
```python
|
11 |
+
from transformers import AutoImageProcessor, AutoModel
|
12 |
+
from PIL import Image
|
13 |
+
import requests
|
14 |
from huggingface_hub import login
|
15 |
+
from torchvision.transforms import v2
|
16 |
|
17 |
login() # login or use an access token
|
18 |
|
19 |
+
# FYI: here a natural image instead of a crop of a WSI for simplicity
|
20 |
+
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
21 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
22 |
+
|
23 |
+
transform = v2.Compose(
|
24 |
+
[
|
25 |
+
v2.Resize(224),
|
26 |
+
v2.CenterCrop(224),
|
27 |
+
v2.ToTensor(),
|
28 |
+
v2.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
29 |
+
]
|
30 |
+
)
|
31 |
+
model = AutoModel.from_pretrained('kaiko-ai/midnight')
|
32 |
+
```
|
33 |
+
|
34 |
+
|
35 |
+
### Extract embeddings for classification
|
36 |
+
```python
|
37 |
+
import torch
|
38 |
+
from transformers import modeling_outputs
|
39 |
+
from typing_extensions import override
|
40 |
+
|
41 |
+
# for classification
|
42 |
+
class ExtractConcatToken:
|
43 |
+
"""Extracts the CLS with Mean Patch tokens from a model output."""
|
44 |
+
def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
|
45 |
+
num_reg_tokens = 0
|
46 |
+
return torch.cat([tensor[:, 0, :], tensor[:, 1 + num_reg_tokens :, :].mean(1)], dim=-1)
|
47 |
+
|
48 |
+
extract_embeddings = ExtractConcatToken()
|
49 |
+
emb = extract_embeddings(model(transform(image)[None]).last_hidden_state)
|
50 |
+
print(f"Embeddings shape: {emb.shape}")
|
51 |
+
```
|
52 |
+
|
53 |
+
|
54 |
+
### Extract embeddings for segmentation
|
55 |
+
```python
|
56 |
+
import math
|
57 |
+
import torch
|
58 |
+
from transformers import modeling_outputs
|
59 |
+
from typing_extensions import override
|
60 |
+
|
61 |
+
# for segmentation
|
62 |
+
class ExtractPatchFeatures:
|
63 |
+
"""Extracts the patch features from a model output."""
|
64 |
+
def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
|
65 |
+
"""Call method for the transformation.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
tensor: The raw embeddings of the model.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
A tensor (batch_size, hidden_size, n_patches_height, n_patches_width)
|
72 |
+
representing the model output.
|
73 |
+
"""
|
74 |
+
num_reg_tokens = 0
|
75 |
+
num_skip = 1 + num_reg_tokens
|
76 |
+
features = tensor[:, num_skip:, :].permute(0, 2, 1)
|
77 |
+
batch_size, hidden_size, patch_grid = features.shape
|
78 |
+
height = width = int(math.sqrt(patch_grid))
|
79 |
+
assert height * width == patch_grid
|
80 |
+
return features.view(batch_size, hidden_size, height, width)
|
81 |
+
|
82 |
+
extract_embeddings = ExtractPatchFeatures()
|
83 |
+
emb = extract_embeddings(model(transform(image)[None]).last_hidden_state)
|
84 |
+
print(f"Embeddings shape for segmentation: {emb[0].shape}")
|
85 |
```
|