bwang0911 committed
Commit 8741cc8
Parent: 3a2c2be

support sentence-transformers

Files changed (3)
  1. config_sentence_transformers.json +10 -0
  2. custom_st.py +189 -0
  3. modules.json +21 -0
config_sentence_transformers.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "__version__": {
+     "sentence_transformers": "3.1.0.dev0",
+     "transformers": "4.41.2",
+     "pytorch": "2.3.1+cu121"
+   },
+   "prompts": {},
+   "default_prompt_name": null,
+   "similarity_fn_name": "cosine"
+ }
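
The "__version__" block records the library versions used to export the model, and "similarity_fn_name" tells sentence-transformers which scoring function model.similarity should apply to the embeddings. A minimal usage sketch, assuming a placeholder model id (the actual repository name is not part of this diff):

from sentence_transformers import SentenceTransformer

# "org/model-name" is a placeholder; trust_remote_code lets the library import custom_st.py
model = SentenceTransformer("org/model-name", trust_remote_code=True)
print(model.similarity_fn_name)    # "cosine", read from config_sentence_transformers.json

emb = model.encode(["a first sentence", "a second sentence"])
print(model.similarity(emb, emb))  # 2x2 cosine-similarity matrix
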
custom_st.py ADDED
@@ -0,0 +1,189 @@
+ import base64
+ import json
+ import os
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import requests
+ import torch
+ from PIL import Image
+ from torch import nn
+ from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer
+
+
+ class Transformer(nn.Module):
+     """Huggingface AutoModel to generate token embeddings.
+     Loads the correct class, e.g. BERT / RoBERTa etc.
+
+     Args:
+         model_name_or_path: Huggingface models name
+             (https://huggingface.co/models)
+         max_seq_length: Truncate any inputs longer than max_seq_length
+         model_args: Keyword arguments passed to the Huggingface
+             Transformers model
+         tokenizer_args: Keyword arguments passed to the Huggingface
+             Transformers tokenizer
+         config_args: Keyword arguments passed to the Huggingface
+             Transformers config
+         cache_dir: Cache dir for Huggingface Transformers to store/load
+             models
+         do_lower_case: If true, lowercases the input (independent if the
+             model is cased or not)
+         tokenizer_name_or_path: Name or path of the tokenizer. When
+             None, then model_name_or_path is used
+     """
+
+     def __init__(
+         self,
+         model_name_or_path: str,
+         max_seq_length: int | None = None,
+         model_args: dict[str, Any] | None = None,
+         tokenizer_args: dict[str, Any] | None = None,
+         config_args: dict[str, Any] | None = None,
+         cache_dir: str | None = None,
+         do_lower_case: bool = False,
+         tokenizer_name_or_path: str | None = None,
+     ) -> None:
+         super().__init__()
+         self.config_keys = ["max_seq_length", "do_lower_case"]
+         self.do_lower_case = do_lower_case
+         if model_args is None:
+             model_args = {}
+         if tokenizer_args is None:
+             tokenizer_args = {}
+         if config_args is None:
+             config_args = {}
+
+         config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+         self._load_model(model_name_or_path, config, cache_dir, **model_args)
+
+         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
+             tokenizer_args["model_max_length"] = max_seq_length
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
+             cache_dir=cache_dir,
+             **tokenizer_args,
+         )
+
+         # No max_seq_length set. Try to infer from model
+         if max_seq_length is None:
+             if (
+                 hasattr(self.auto_model, "config")
+                 and hasattr(self.auto_model.config, "max_position_embeddings")
+                 and hasattr(self.tokenizer, "model_max_length")
+             ):
+                 max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length)
+
+         self.max_seq_length = max_seq_length
+
+         if tokenizer_name_or_path is not None:
+             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
+
+     def forward(
+         self, features: Dict[str, torch.Tensor], task_type: Optional[str] = None
+     ) -> Dict[str, torch.Tensor]:
+         """Returns token_embeddings, cls_token"""
+         if task_type and task_type not in self._lora_adaptations:
+             raise ValueError(
+                 f"Unsupported task '{task_type}'. "
+                 f"Supported tasks are: {', '.join(self._lora_adaptations)}. "
+                 f"Alternatively, don't pass the `task_type` argument to disable LoRA."
+             )
+
+         adapter_mask = None
+         if task_type:
+             task_id = self._adaptation_map[task_type]
+             num_examples = 1
+             if isinstance(features["input_ids"], torch.Tensor):
+                 # tokenize() returns batched tensors, so the batch size is the first dimension
+                 num_examples = features["input_ids"].shape[0]
+
+             adapter_mask = torch.full(
+                 (num_examples,), task_id, dtype=torch.int32, device=features["input_ids"].device
+             )
+
+         lora_arguments = (
+             {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
+         )
+         output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
+         output_tokens = output_states[0]
+         features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
+         return features
+
+     def get_word_embedding_dimension(self) -> int:
+         return self.auto_model.config.hidden_size
+
+     def tokenize(
+         self, texts: list[str] | list[dict] | list[tuple[str, str]], padding: str | bool = True
+     ) -> dict[str, torch.Tensor]:
+         """Tokenizes a text and maps tokens to token-ids"""
+         output = {}
+         if isinstance(texts[0], str):
+             to_tokenize = [texts]
+         elif isinstance(texts[0], dict):
+             to_tokenize = []
+             output["text_keys"] = []
+             for lookup in texts:
+                 text_key, text = next(iter(lookup.items()))
+                 to_tokenize.append(text)
+                 output["text_keys"].append(text_key)
+             to_tokenize = [to_tokenize]
+         else:
+             batch1, batch2 = [], []
+             for text_tuple in texts:
+                 batch1.append(text_tuple[0])
+                 batch2.append(text_tuple[1])
+             to_tokenize = [batch1, batch2]
+
+         # strip
+         to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]
+
+         # Lowercase
+         if self.do_lower_case:
+             to_tokenize = [[s.lower() for s in col] for col in to_tokenize]
+
+         output.update(
+             self.tokenizer(
+                 *to_tokenize,
+                 padding=padding,
+                 truncation="longest_first",
+                 return_tensors="pt",
+                 max_length=self.max_seq_length,
+             )
+         )
+         return output
+
+     def save(self, output_path: str, safe_serialization: bool = True) -> None:
+         self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
+         self.tokenizer.save_pretrained(output_path)
+
+         with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut:
+             json.dump(self.get_config_dict(), fOut, indent=2)
+
+
+     @classmethod
+     def load(cls, input_path: str) -> "Transformer":
+         # Old classes used other config names than 'sentence_bert_config.json'
+         for config_name in [
+             "sentence_bert_config.json",
+             "sentence_roberta_config.json",
+             "sentence_distilbert_config.json",
+             "sentence_camembert_config.json",
+             "sentence_albert_config.json",
+             "sentence_xlm-roberta_config.json",
+             "sentence_xlnet_config.json",
+         ]:
+             sbert_config_path = os.path.join(input_path, config_name)
+             if os.path.exists(sbert_config_path):
+                 break
+
+         with open(sbert_config_path) as fIn:
+             config = json.load(fIn)
+         # Don't allow configs to set trust_remote_code
+         if "model_args" in config and "trust_remote_code" in config["model_args"]:
+             config["model_args"].pop("trust_remote_code")
+         if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
+             config["tokenizer_args"].pop("trust_remote_code")
+         if "config_args" in config and "trust_remote_code" in config["config_args"]:
+             config["config_args"].pop("trust_remote_code")
+         return cls(model_name_or_path=input_path, **config)
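
custom_st.py replaces the stock sentence_transformers.models.Transformer so that an optional task_type argument can select a LoRA adapter at encoding time; tokenize() accepts plain strings, single-key dicts, or (text_a, text_b) pairs and feeds them to the Hugging Face tokenizer as parallel columns. A small illustration of that pair-tokenization pattern, using a stand-in checkpoint rather than anything named in this commit:

from transformers import AutoTokenizer

# "bert-base-uncased" is only a stand-in to show the call shape tokenize() builds up to.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch1 = ["how to sort a list"]           # first element of each (text_a, text_b) tuple
batch2 = ["use sorted() or list.sort()"]  # second element of each tuple
features = tokenizer(
    batch1,
    batch2,
    padding=True,
    truncation="longest_first",
    return_tensors="pt",
    max_length=128,
)
print(features["input_ids"].shape)  # one row per pair: [CLS] text_a [SEP] text_b [SEP]
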
modules.json ADDED
@@ -0,0 +1,21 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": "",
+     "type": "custom_st.Transformer",
+     "kwargs": ["task_type"]
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_Pooling",
+     "type": "sentence_transformers.models.Pooling"
+   },
+   {
+     "idx": 2,
+     "name": "2",
+     "path": "2_Normalize",
+     "type": "sentence_transformers.models.Normalize"
+   }
+ ]
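
modules.json wires the pipeline together: the custom Transformer module, the pooling module loaded from 1_Pooling, then L2 normalization from 2_Normalize. The "kwargs": ["task_type"] entry is meant to let sentence-transformers 3.1's custom-module mechanism forward an extra task_type argument from encode() down to custom_st.Transformer.forward(). A hedged end-to-end sketch; the model id and the task name below are placeholders, not values taken from this commit:

from sentence_transformers import SentenceTransformer

# Placeholder repository id; trust_remote_code is needed so custom_st.py is importable.
model = SentenceTransformer("org/model-name", trust_remote_code=True)

# Extra kwargs listed in modules.json ("task_type") are routed to the custom module.
embeddings = model.encode(
    ["How do I sort a list in Python?", "sorted(xs) returns a new sorted list"],
    task_type="retrieval.query",  # placeholder LoRA adapter name
)
print(embeddings.shape)  # pooled and L2-normalized vectors, per 1_Pooling and 2_Normalize
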