Giga-Retrieval-instruct / dummy_pooling.py
ekolodin's picture
Upload dummy_pooling.py
7cabed8 verified
import torch
from torch import nn
class DummyPooling(nn.Module):
    """No-op pooling module that exposes ``token_embeddings`` as ``sentence_embedding``.

    No reduction is applied: the tensor stored under ``'token_embeddings'`` is
    passed through unchanged under the ``'sentence_embedding'`` key. The module
    has no parameters or configuration, so saving writes nothing and loading
    simply constructs a fresh instance.

    NOTE(review): this is only meaningful if the preceding module already emits
    one vector per input under ``'token_embeddings'`` — confirm against the
    upstream model.
    """

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Return the features with ``'sentence_embedding'`` set to ``'token_embeddings'``.

        Args:
            features: Feature dict; must contain ``'token_embeddings'``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A new dict holding every entry of ``features`` plus
            ``'sentence_embedding'`` (the same tensor object as
            ``features['token_embeddings']``). The input dict is not modified.

        Raises:
            KeyError: If ``'token_embeddings'`` is missing from ``features``.
        """
        # Keep the other features (e.g. attention mask, token embeddings) so
        # downstream consumers that read them still work, and build a new dict
        # rather than mutating the caller's.
        return {**features, 'sentence_embedding': features['token_embeddings']}

    def save(self, save_dir: str, **kwargs) -> None:
        """Do nothing: the module has no state or config to persist."""

    @staticmethod
    def load(load_dir: str, **kwargs) -> "DummyPooling":
        """Return a new instance; ``load_dir`` is ignored because nothing is saved.

        Declared as a static method so it works whether it is invoked on the
        class or on an instance (without it, an instance call would bind the
        instance to ``load_dir``).
        """
        return DummyPooling()