import json
import os

import torch
import torch.nn as nn


class TokenPooling(nn.Module):
    """Last-token pooling: reduces per-token embeddings to a single sentence
    embedding by selecting the hidden state of the final non-padding token."""

    def __init__(self, dimension: int = 4096) -> None:
        super().__init__()
        self.dimension = dimension

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Attach the pooled sentence embedding to the feature dictionary."""
token_embeddings = features["token_embeddings"]
attention_mask = features["attention_mask"]
embeddings = self.pool(
last_hidden_state=token_embeddings, attention_mask=attention_mask
)
features["sentence_embedding"] = embeddings
return features

    def pool(
        self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Take the embedding of the last non-padding token from the last layer
        (last-token pooling), for both left- and right-padded batches.
        """
        # If every sequence is attended at the final position, the batch is
        # left-padded and the last column holds each sequence's final token.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_state[:, -1]
        else:
            # Right padding: index the last attended token of each sequence.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
return last_hidden_state[
torch.arange(batch_size, device=last_hidden_state.device).long(),
sequence_lengths.long(),
]

    def get_sentence_embedding_dimension(self) -> int:
        return self.dimension

    def get_config_dict(self) -> dict[str, int]:
        return {"dimension": self.dimension}

    def save(self, save_dir: str, **kwargs) -> None:
        # Persist the pooling configuration so the module can be reloaded.
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(self.get_config_dict(), f, indent=4)

    @staticmethod
    def load(load_dir: str, **kwargs) -> "TokenPooling":
        # Rebuild the module from the config written by `save`.
        with open(os.path.join(load_dir, "config.json")) as fIn:
            config = json.load(fIn)
        return TokenPooling(**config)
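

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): exercises `pool`
    # with dummy tensors to show last-token pooling under both padding
    # conventions, then round-trips the config through save/load. All shapes,
    # values, and the temporary directory are illustrative assumptions.
    import tempfile

    pooling = TokenPooling(dimension=8)

    batch_size, seq_len, dim = 2, 5, 8
    hidden = torch.randn(batch_size, seq_len, dim)

    # Right padding: real tokens first, padding at the end of the mask.
    right_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]])
    right_pooled = pooling.pool(hidden, right_mask)
    assert torch.equal(right_pooled[0], hidden[0, 2])  # last attended token
    assert torch.equal(right_pooled[1], hidden[1, 3])

    # Left padding: padding first, so the final position is always a real token.
    left_mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]])
    left_pooled = pooling.pool(hidden, left_mask)
    assert torch.equal(left_pooled, hidden[:, -1])

    # The forward pass stores the pooled vector under "sentence_embedding".
    features = {"token_embeddings": hidden, "attention_mask": left_mask}
    out = pooling.forward(features)
    print(out["sentence_embedding"].shape)  # torch.Size([2, 8])

    # Save/load round-trip of the pooling config.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pooling.save(tmp_dir)
        reloaded = TokenPooling.load(tmp_dir)
        assert reloaded.get_sentence_embedding_dimension() == 8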