import json
import os

import torch
import torch.nn as nn


class TokenPooling(nn.Module):
    """Last-token pooling layer: maps per-token embeddings to a single sentence embedding."""
|
    def __init__(self, dimension: int = 4096) -> None:
        super().__init__()
        self.dimension = dimension
|
    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        token_embeddings = features["token_embeddings"]
        attention_mask = features["attention_mask"]

        # Pool the per-token embeddings into one vector per sequence and expose it
        # under "sentence_embedding" for downstream modules.
        embeddings = self.pool(
            last_hidden_state=token_embeddings, attention_mask=attention_mask
        )
        features["sentence_embedding"] = embeddings
        return features
|
    def pool(
        self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Take the embedding of the last non-padding token from the last layer.
        """
        # If every sequence has a real token in the final position, the batch is
        # left-padded and the last column already holds the last token of each sequence.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_state[:, -1]
        else:
            # Right padding: index each sequence at its own last non-padding position.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            return last_hidden_state[
                torch.arange(batch_size, device=last_hidden_state.device).long(),
                sequence_lengths.long(),
            ]
|
    def get_sentence_embedding_dimension(self) -> int:
        return self.dimension

    def get_config_dict(self) -> dict[str, int]:
        return {"dimension": self.dimension}
|
    def save(self, save_dir: str, **kwargs) -> None:
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(self.get_config_dict(), f, indent=4)
|
    @staticmethod
    def load(load_dir: str, **kwargs) -> "TokenPooling":
        with open(os.path.join(load_dir, "config.json")) as fIn:
            config = json.load(fIn)
        return TokenPooling(**config)
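

# Usage sketch (not part of the original module): TokenPooling exposes the method set
# that sentence_transformers expects from a custom module (forward over a features dict,
# get_sentence_embedding_dimension, get_config_dict, save, load), so it can plug in after
# a Transformer module. The model name below is a placeholder assumption; any decoder-style
# checkpoint whose hidden size matches `dimension` would do.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer, models

    word_embedding_model = models.Transformer("intfloat/e5-mistral-7b-instruct")
    pooling = TokenPooling(dimension=word_embedding_model.get_word_embedding_dimension())
    model = SentenceTransformer(modules=[word_embedding_model, pooling])

    embeddings = model.encode(["An example sentence to embed."])
    print(embeddings.shape)  # e.g. (1, 4096) for a 4096-dim backbone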