import torch | |
from torch import nn | |
class DummyPooling(nn.Module):
    """Pass-through "pooling" module that applies no reduction.

    ``forward`` exposes the incoming ``'token_embeddings'`` tensor unchanged
    under the ``'sentence_embedding'`` key. The module defines no parameters
    or configuration, so saving is a no-op and loading just builds a fresh
    instance.
    """

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Return the token embeddings as the sentence embedding.

        Args:
            features: Mapping that must contain ``'token_embeddings'``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A new dict containing only ``'sentence_embedding'``. Its value is
            the very same tensor object found under ``'token_embeddings'``;
            no pooling, copying, or reshaping is performed. Other entries of
            ``features`` are not carried over, and ``features`` is not mutated.

        Raises:
            KeyError: If ``'token_embeddings'`` is missing from ``features``.
        """
        return {'sentence_embedding': features['token_embeddings']}

    def save(self, save_dir: str, **kwargs) -> None:
        """No-op: this module has no state to persist.

        Args:
            save_dir: Target directory; ignored.
            **kwargs: Accepted for interface compatibility; ignored.
        """

    @staticmethod
    def load(load_dir: str, **kwargs) -> "DummyPooling":
        """Create a new ``DummyPooling``; nothing is read from disk.

        Declared as a ``staticmethod`` so it works both as
        ``DummyPooling.load(path)`` and when invoked through an instance
        (without the decorator the instance would be bound to ``load_dir``).

        Args:
            load_dir: Source directory; ignored.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A freshly constructed ``DummyPooling``.
        """
        return DummyPooling()