import torch
from torch import nn


class DummyPooling(nn.Module):
    """Pass-through "pooling" module that performs no pooling.

    ``forward`` re-labels the incoming ``token_embeddings`` tensor as
    ``sentence_embedding`` without modifying it. The module defines no
    parameters or state, so ``save`` writes nothing and ``load`` simply
    constructs a fresh instance.

    NOTE(review): the ``save``/``load`` signatures presumably mirror a module
    persistence convention (e.g. sentence-transformers) — confirm against
    the caller.
    """

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Return ``features['token_embeddings']`` under ``'sentence_embedding'``.

        Args:
            features: Mapping that must contain the key ``'token_embeddings'``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A new dict whose only key is ``'sentence_embedding'``, bound to the
            very same tensor object (no copy, no reduction). Other keys of
            ``features`` are not carried over.

        Raises:
            KeyError: If ``'token_embeddings'`` is missing from ``features``.
        """
        return {"sentence_embedding": features["token_embeddings"]}

    def save(self, save_dir: str, **kwargs) -> None:
        """No-op: this module holds no state to write to ``save_dir``."""

    # Without @staticmethod, calling ``instance.load(path)`` would bind the
    # instance to ``load_dir`` and fail with a TypeError; the decorator makes
    # both ``DummyPooling.load(path)`` and ``instance.load(path)`` work.
    @staticmethod
    def load(load_dir: str, **kwargs) -> "DummyPooling":
        """Return a new ``DummyPooling``; ``load_dir`` and ``kwargs`` are ignored."""
        return DummyPooling()