import json
import os

import torch
import torch.nn as nn


class TokenPooling(nn.Module):
    """Sentence-Transformers-style pooling module that uses the embedding of the
    last non-padding token as the sentence embedding."""

    def __init__(self, dimension: int = 4096) -> None:
        super().__init__()
        self.dimension = dimension

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        token_embeddings = features["token_embeddings"]
        attention_mask = features["attention_mask"]
        embeddings = self.pool(
            last_hidden_state=token_embeddings, attention_mask=attention_mask
        )
        features["sentence_embedding"] = embeddings
        return features

    def pool(
        self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Take the embedding of the last non-padding token from the last layer."""
        # If every sequence has a real token in the final position, the batch is
        # left-padded (or unpadded), so the last position holds the last token.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_state[:, -1]
        else:
            # Right padding: index each sequence at its last non-padding position.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            return last_hidden_state[
                torch.arange(batch_size, device=last_hidden_state.device).long(),
                sequence_lengths.long(),
            ]

    def get_sentence_embedding_dimension(self) -> int:
        return self.dimension

    def get_config_dict(self) -> dict[str, int]:
        return {"dimension": self.dimension}

    def save(self, save_dir: str, **kwargs) -> None:
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(self.get_config_dict(), f, indent=4)

    @staticmethod
    def load(load_dir: str, **kwargs) -> "TokenPooling":
        with open(os.path.join(load_dir, "config.json")) as fIn:
            config = json.load(fIn)
        return TokenPooling(**config)
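

# Minimal usage sketch (not part of the original module): the tensor shapes and
# values below are hypothetical and only illustrate how `pool` selects the last
# non-padding token for a right-padded batch.
if __name__ == "__main__":
    pooling = TokenPooling(dimension=8)

    # Batch of 2 sequences, 4 token positions, hidden size 8 (dummy data).
    last_hidden_state = torch.randn(2, 4, 8)
    # First sequence uses all 4 tokens, second uses 2 tokens (right padding).
    attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

    features = pooling(
        {"token_embeddings": last_hidden_state, "attention_mask": attention_mask}
    )
    sentence_embedding = features["sentence_embedding"]

    # Each sentence embedding equals the hidden state of its last real token.
    assert torch.equal(sentence_embedding[0], last_hidden_state[0, 3])
    assert torch.equal(sentence_embedding[1], last_hidden_state[1, 1])
    print(sentence_embedding.shape)  # torch.Size([2, 8])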