ekolodin commited on
Commit
7cabed8
·
verified ·
1 Parent(s): 828bc41

Upload dummy_pooling.py

Browse files
Files changed (1) hide show
  1. dummy_pooling.py +14 -0
dummy_pooling.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch import nn
4
+
5
+
6
+ class DummyPooling(nn.Module):
7
+ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
8
+ return {'sentence_embedding': features['token_embeddings']}
9
+
10
+ def save(self, save_dir: str, **kwargs) -> None:
11
+ pass
12
+
13
+ def load(load_dir: str, **kwargs) -> "DummyPooling":
14
+ return DummyPooling()