Enable saving in SentenceTransformers by adding get_config_dict (#32)
Browse files- Enable saving in SentenceTransformers by adding get_config_dict (7e1ab6246d86fe02a992b477256aa9f12deb4aea)
Co-authored-by: Edward Ross <[email protected]>
- custom_st.py +3 -0
custom_st.py
CHANGED
|
@@ -160,6 +160,9 @@ class Transformer(nn.Module):
|
|
| 160 |
)
|
| 161 |
return output
|
| 162 |
|
|
|
|
|
|
|
|
|
|
| 163 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
| 164 |
self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
|
| 165 |
self.tokenizer.save_pretrained(output_path)
|
|
|
|
| 160 |
)
|
| 161 |
return output
|
| 162 |
|
| 163 |
+
def get_config_dict(self) -> dict[str, Any]:
|
| 164 |
+
return {key: self.__dict__[key] for key in self.config_keys}
|
| 165 |
+
|
| 166 |
def save(self, output_path: str, safe_serialization: bool = True) -> None:
|
| 167 |
self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
|
| 168 |
self.tokenizer.save_pretrained(output_path)
|