Upload model
modeling.py  (+5 -5)  CHANGED

@@ -155,12 +155,12 @@ class EntityFusionLayer(nn.Module):
 
 
 class KPRMixin:
-    def _forward(self, **inputs:
+    def _forward(self, **inputs: dict[str, Tensor]) -> tuple[Tensor] | tuple[Tensor, Tensor] | ModelOutput:
         return_dict = inputs.pop("return_dict", True)
 
         if self.training:
-            query_embeddings = self.encode(inputs["queries"])
-            passage_embeddings = self.encode(inputs["passages"])
+            query_embeddings = self.encode(**inputs["queries"])
+            passage_embeddings = self.encode(**inputs["passages"])
 
             query_embeddings = self._dist_gather_tensor(query_embeddings)
             passage_embeddings = self._dist_gather_tensor(passage_embeddings)
@@ -179,13 +179,13 @@ class KPRMixin:
             return (loss, scores)
 
         else:
-            sentence_embeddings = self.encode(inputs).unsqueeze(1)
+            sentence_embeddings = self.encode(**inputs).unsqueeze(1)
             if return_dict:
                 return ModelOutput(sentence_embeddings=sentence_embeddings)
             else:
                 return (sentence_embeddings,)
 
-    def encode(self, inputs: dict[str, Tensor]) -> Tensor:
+    def encode(self, **inputs: dict[str, Tensor]) -> Tensor:
        entity_ids = inputs.pop("entity_ids", None)
        entity_position_ids = inputs.pop("entity_position_ids", None)
        entity_embeds = inputs.pop("entity_embeds", None)