Spaces:
Running
Running
Major refactoring and added test cases
Browse files- .gitignore +1 -0
- encoder_models.py +108 -0
- semf1.py +74 -69
- tests.py +179 -17
- type_aliases.py +10 -0
- utils.py +78 -10
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__/
|
encoder_models.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import List, Union
|
3 |
+
|
4 |
+
from numpy.typing import NDArray
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
|
7 |
+
from type_aliases import ENCODER_DEVICE_TYPE
|
8 |
+
|
9 |
+
|
10 |
+
class Encoder(abc.ABC):
|
11 |
+
@abc.abstractmethod
|
12 |
+
def encode(self, prediction: List[str]) -> NDArray:
|
13 |
+
"""
|
14 |
+
Abstract method to encode a list of sentences into sentence embeddings.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
prediction (List[str]): List of sentences to encode.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
|
21 |
+
|
22 |
+
Raises:
|
23 |
+
NotImplementedError: If the method is not implemented in the subclass.
|
24 |
+
"""
|
25 |
+
raise NotImplementedError("Method 'encode' must be implemented in subclass.")
|
26 |
+
|
27 |
+
|
28 |
+
class USE(Encoder):
|
29 |
+
def __init__(self):
|
30 |
+
pass
|
31 |
+
|
32 |
+
def encode(self, prediction: List[str]) -> NDArray:
|
33 |
+
pass
|
34 |
+
|
35 |
+
|
36 |
+
class SBertEncoder(Encoder):
|
37 |
+
def __init__(self, model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
|
38 |
+
"""
|
39 |
+
Initialize SBertEncoder instance.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
model_name (str): Name or path of the Sentence Transformer model.
|
43 |
+
device (Union[str, int, List[Union[str, int]]]): Device specification for encoding
|
44 |
+
batch_size (int): Batch size for encoding.
|
45 |
+
verbose (bool): Whether to print verbose information during encoding.
|
46 |
+
"""
|
47 |
+
self.model = SentenceTransformer(model_name)
|
48 |
+
self.device = device
|
49 |
+
self.batch_size = batch_size
|
50 |
+
self.verbose = verbose
|
51 |
+
|
52 |
+
def encode(self, prediction: List[str]) -> NDArray:
|
53 |
+
"""
|
54 |
+
Encode a list of sentences into sentence embeddings.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
prediction (List[str]): List of sentences to encode.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
|
61 |
+
"""
|
62 |
+
|
63 |
+
# SBert output is always Batch x Dim
|
64 |
+
if isinstance(self.device, list):
|
65 |
+
# Use multiprocess encoding for list of devices
|
66 |
+
pool = self.model.start_multi_process_pool(target_devices=self.device)
|
67 |
+
embeddings = self.model.encode_multi_process(prediction, pool=pool, batch_size=self.batch_size)
|
68 |
+
self.model.stop_multi_process_pool(pool)
|
69 |
+
else:
|
70 |
+
# Single device encoding
|
71 |
+
embeddings = self.model.encode(
|
72 |
+
prediction,
|
73 |
+
device=self.device,
|
74 |
+
batch_size=self.batch_size,
|
75 |
+
show_progress_bar=self.verbose,
|
76 |
+
)
|
77 |
+
|
78 |
+
return embeddings
|
79 |
+
|
80 |
+
|
81 |
+
def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool) -> Encoder:
|
82 |
+
"""
|
83 |
+
Get the encoder instance based on the specified model name.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
model_name (str): Name of the model to instantiate
|
87 |
+
Options: [pv1, stsb, use]
|
88 |
+
pv1 - paraphrase-distilroberta-base-v1 (Default)
|
89 |
+
stsb - stsb-roberta-large
|
90 |
+
use - Universal Sentence Encoder
|
91 |
+
device (Union[str, int, List[Union[str, int]]): Device specification for the encoder
|
92 |
+
(e.g., "cuda", 0 for GPU, "cpu").
|
93 |
+
batch_size (int): Batch size for encoding.
|
94 |
+
verbose (bool): Whether to print verbose information during encoder initialization.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
Encoder: Instance of the selected encoder based on the model_name.
|
98 |
+
|
99 |
+
Raises:
|
100 |
+
ValueError: If an unsupported model_name is provided.
|
101 |
+
"""
|
102 |
+
|
103 |
+
# TODO: chnage this when changing the TF model
|
104 |
+
if model_name == "use":
|
105 |
+
return SBertEncoder("sentence-transformers/use-cmlm-multilingual", device, batch_size, verbose)
|
106 |
+
# return USE()
|
107 |
+
else:
|
108 |
+
return SBertEncoder(model_name, device, batch_size, verbose)
|
semf1.py
CHANGED
@@ -14,21 +14,19 @@
|
|
14 |
# TODO: Add test cases, Remove tokenize_sentences flag since it can be determined from the input itself.
|
15 |
"""Sem-F1 metric"""
|
16 |
|
17 |
-
import
|
18 |
-
import
|
19 |
-
from typing import List, Optional, Tuple, Union
|
20 |
|
21 |
import datasets
|
22 |
import evaluate
|
23 |
import nltk
|
24 |
import numpy as np
|
25 |
from numpy.typing import NDArray
|
26 |
-
from sentence_transformers import SentenceTransformer
|
27 |
from sklearn.metrics.pairwise import cosine_similarity
|
28 |
-
import torch
|
29 |
-
from tqdm import tqdm
|
30 |
|
31 |
-
from
|
|
|
|
|
32 |
|
33 |
_CITATION = """\
|
34 |
@inproceedings{bansal-etal-2022-sem,
|
@@ -123,80 +121,80 @@ Examples:
|
|
123 |
[0.77, 0.56]
|
124 |
"""
|
125 |
|
126 |
-
_PREDICTION_TYPE = Union[List[str], List[List[str]]]
|
127 |
-
_REFERENCE_TYPE = Union[List[str], List[List[str]], List[List[List[str]]]]
|
128 |
|
|
|
|
|
|
|
129 |
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
pass
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
-
|
137 |
-
|
138 |
-
pass
|
139 |
|
140 |
-
|
141 |
-
|
142 |
|
|
|
|
|
|
|
143 |
|
144 |
-
|
145 |
-
def __init__(self, model_name: str, device: Union[str, int], batch_size: int):
|
146 |
-
self.model = SentenceTransformer(model_name)
|
147 |
-
self.device = device
|
148 |
-
self.batch_size = batch_size
|
149 |
|
150 |
-
def encode(self, prediction: List[str]) -> NDArray:
|
151 |
-
"""Returns sentence embeddings of dim: Batch x Dim"""
|
152 |
-
# SBert output is always Batch x Dim
|
153 |
-
return self.model.encode(prediction, device=self.device, batch_size=self.batch_size)
|
154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
return SBertEncoder(model_name, device, batch_size)
|
162 |
|
|
|
|
|
163 |
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
|
|
|
|
|
|
|
170 |
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
if gpu_available:
|
175 |
-
gpu_count = torch.cuda.device_count()
|
176 |
-
if isinstance(gpu, int) and gpu >= gpu_count:
|
177 |
-
raise ValueError(
|
178 |
-
f"There are {gpu_count} gpus available. Provide the correct gpu index. You provided: {gpu}"
|
179 |
-
)
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
elif gpu is True and gpu_available:
|
185 |
-
device = 0 # by default run on device 0
|
186 |
-
elif isinstance(gpu, int):
|
187 |
-
device = gpu
|
188 |
-
else: # This will never happen
|
189 |
-
raise ValueError(f"gpu must be bool or int. Provided value: {gpu}")
|
190 |
|
191 |
-
|
|
|
192 |
|
|
|
|
|
|
|
193 |
|
194 |
-
|
195 |
-
tokenize_sentences: bool,
|
196 |
-
multi_references: bool,
|
197 |
-
predictions: _PREDICTION_TYPE,
|
198 |
-
references: _REFERENCE_TYPE,
|
199 |
-
):
|
200 |
if tokenize_sentences and multi_references:
|
201 |
condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 2)
|
202 |
elif not tokenize_sentences and multi_references:
|
@@ -215,7 +213,7 @@ class SemF1(evaluate.Metric):
|
|
215 |
_MODEL_TYPE_TO_NAME = {
|
216 |
"pv1": "paraphrase-distilroberta-base-v1",
|
217 |
"stsb": "stsb-roberta-large",
|
218 |
-
"use": "sentence-transformers/use-cmlm-multilingual", # TODO: check PyTorch USE VS TF USE
|
219 |
}
|
220 |
|
221 |
def _info(self):
|
@@ -275,7 +273,7 @@ class SemF1(evaluate.Metric):
|
|
275 |
|
276 |
def _get_model_name(self, model_type: Optional[str] = None) -> str:
|
277 |
if model_type is None:
|
278 |
-
model_type = "
|
279 |
|
280 |
if model_type not in self._MODEL_TYPE_TO_NAME.keys():
|
281 |
raise ValueError(f"Provide a correct model_type.\n"
|
@@ -291,7 +289,6 @@ class SemF1(evaluate.Metric):
|
|
291 |
# if not nltk.data.find("tokenizers/punkt"): # TODO: check why it is not working
|
292 |
# pass
|
293 |
|
294 |
-
|
295 |
def _compute(
|
296 |
self,
|
297 |
predictions,
|
@@ -299,8 +296,9 @@ class SemF1(evaluate.Metric):
|
|
299 |
model_type: Optional[str] = None,
|
300 |
tokenize_sentences: bool = True,
|
301 |
multi_references: bool = False,
|
302 |
-
gpu:
|
303 |
batch_size: int = 32,
|
|
|
304 |
) -> List[Scores]:
|
305 |
"""
|
306 |
Compute precision, recall, and F1 scores for given predictions and references.
|
@@ -308,10 +306,15 @@ class SemF1(evaluate.Metric):
|
|
308 |
:param predictions
|
309 |
:param references
|
310 |
:param model_type: Type of model to use for encoding.
|
|
|
|
|
|
|
|
|
311 |
:param tokenize_sentences: Flag to sentence tokenize the document.
|
312 |
:param multi_references: Flag to indicate multiple references.
|
313 |
:param gpu: GPU device to use.
|
314 |
:param batch_size: Batch size for encoding.
|
|
|
315 |
|
316 |
:return: List of Scores dataclass with precision, recall, and F1 scores.
|
317 |
"""
|
@@ -320,11 +323,13 @@ class SemF1(evaluate.Metric):
|
|
320 |
_validate_input_format(tokenize_sentences, multi_references, predictions, references)
|
321 |
|
322 |
# Get GPU
|
323 |
-
device =
|
|
|
|
|
324 |
|
325 |
# Get the encoder model
|
326 |
model_name = self._get_model_name(model_type)
|
327 |
-
encoder =
|
328 |
|
329 |
# We'll handle the single reference and multi-reference case same way. So change the data format accordingly
|
330 |
if not multi_references:
|
|
|
14 |
# TODO: Add test cases, Remove tokenize_sentences flag since it can be determined from the input itself.
|
15 |
"""Sem-F1 metric"""
|
16 |
|
17 |
+
from functools import partial
|
18 |
+
from typing import List, Optional, Tuple
|
|
|
19 |
|
20 |
import datasets
|
21 |
import evaluate
|
22 |
import nltk
|
23 |
import numpy as np
|
24 |
from numpy.typing import NDArray
|
|
|
25 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
|
26 |
|
27 |
+
from encoder_models import get_encoder
|
28 |
+
from type_aliases import DEVICE_TYPE, PREDICTION_TYPE, REFERENCE_TYPE
|
29 |
+
from utils import is_nested_list_of_type, Scores, slice_embeddings, flatten_list, get_gpu
|
30 |
|
31 |
_CITATION = """\
|
32 |
@inproceedings{bansal-etal-2022-sem,
|
|
|
121 |
[0.77, 0.56]
|
122 |
"""
|
123 |
|
|
|
|
|
124 |
|
125 |
+
def _compute_cosine_similarity(pred_embeds: NDArray, ref_embeds: NDArray) -> Tuple[float, float]:
|
126 |
+
"""
|
127 |
+
Compute precision and recall based on cosine similarity between predicted and reference embeddings.
|
128 |
|
129 |
+
Args:
|
130 |
+
pred_embeds (NDArray): Predicted embeddings (shape: [num_pred, embedding_dim]).
|
131 |
+
ref_embeds (NDArray): Reference embeddings (shape: [num_ref, embedding_dim]).
|
|
|
132 |
|
133 |
+
Returns:
|
134 |
+
Tuple[float, float]: Precision and recall based on cosine similarity scores.
|
135 |
+
Precision: Average maximum cosine similarity score per predicted embedding.
|
136 |
+
Recall: Average maximum cosine similarity score per reference embedding.
|
137 |
+
"""
|
138 |
+
# Compute cosine similarity between predicted and reference embeddings
|
139 |
+
cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
|
140 |
|
141 |
+
# Compute precision per predicted embedding
|
142 |
+
precision_per_sentence_sim = np.max(cosine_scores, axis=-1)
|
|
|
143 |
|
144 |
+
# Compute recall per reference embedding
|
145 |
+
recall_per_sentence_sim = np.max(cosine_scores, axis=0)
|
146 |
|
147 |
+
# Calculate mean precision and recall scores
|
148 |
+
precision = np.mean(precision_per_sentence_sim).item()
|
149 |
+
recall = np.mean(recall_per_sentence_sim).item()
|
150 |
|
151 |
+
return precision, recall
|
|
|
|
|
|
|
|
|
152 |
|
|
|
|
|
|
|
|
|
153 |
|
154 |
+
def _validate_input_format(
|
155 |
+
tokenize_sentences: bool,
|
156 |
+
multi_references: bool,
|
157 |
+
predictions: PREDICTION_TYPE,
|
158 |
+
references: REFERENCE_TYPE,
|
159 |
+
):
|
160 |
+
"""
|
161 |
+
Validate the format of predictions and references based on specified criteria.
|
162 |
|
163 |
+
Args:
|
164 |
+
- tokenize_sentences (bool): Flag indicating whether sentences should be tokenized.
|
165 |
+
- multi_references (bool): Flag indicating whether multiple references are provided.
|
166 |
+
- predictions (PREDICTION_TYPE): Predictions to validate.
|
167 |
+
- references (REFERENCE_TYPE): References to validate.
|
|
|
168 |
|
169 |
+
Raises:
|
170 |
+
- ValueError: If the format of predictions or references does not meet the specified criteria.
|
171 |
|
172 |
+
Validation Criteria:
|
173 |
+
The function validates predictions and references based on the following conditions:
|
174 |
+
1. If `tokenize_sentences` is True and `multi_references` is True:
|
175 |
+
- Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
|
176 |
+
- References must be a list of list of strings (`is_list_of_strings_at_depth(references, 2)`).
|
177 |
|
178 |
+
2. If `tokenize_sentences` is False and `multi_references` is True:
|
179 |
+
- Predictions must be a list of list of strings (`is_list_of_strings_at_depth(predictions, 2)`).
|
180 |
+
- References must be a list of list of list of strings (`is_list_of_strings_at_depth(references, 3)`).
|
181 |
|
182 |
+
3. If `tokenize_sentences` is True and `multi_references` is False:
|
183 |
+
- Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
|
184 |
+
- References must be a list of strings (`is_list_of_strings_at_depth(references, 1)`).
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
+
4. If `tokenize_sentences` is False and `multi_references` is False:
|
187 |
+
- Predictions must be a list of list of strings (`is_list_of_strings_at_depth(predictions, 2)`).
|
188 |
+
- References must be a list of list of strings (`is_list_of_strings_at_depth(references, 2)`).
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
+
The function checks these conditions and raises a ValueError if any condition is not met,
|
191 |
+
indicating that predictions or references are not in the valid input format.
|
192 |
|
193 |
+
Note:
|
194 |
+
- `PREDICTION_TYPE` and `REFERENCE_TYPE` are defined at the top of the file
|
195 |
+
"""
|
196 |
|
197 |
+
is_list_of_strings_at_depth = partial(is_nested_list_of_type, element_type=str)
|
|
|
|
|
|
|
|
|
|
|
198 |
if tokenize_sentences and multi_references:
|
199 |
condition = is_list_of_strings_at_depth(predictions, 1) and is_list_of_strings_at_depth(references, 2)
|
200 |
elif not tokenize_sentences and multi_references:
|
|
|
213 |
_MODEL_TYPE_TO_NAME = {
|
214 |
"pv1": "paraphrase-distilroberta-base-v1",
|
215 |
"stsb": "stsb-roberta-large",
|
216 |
+
"use": "use", # "sentence-transformers/use-cmlm-multilingual", # TODO: check PyTorch USE VS TF USE
|
217 |
}
|
218 |
|
219 |
def _info(self):
|
|
|
273 |
|
274 |
def _get_model_name(self, model_type: Optional[str] = None) -> str:
|
275 |
if model_type is None:
|
276 |
+
model_type = "use"
|
277 |
|
278 |
if model_type not in self._MODEL_TYPE_TO_NAME.keys():
|
279 |
raise ValueError(f"Provide a correct model_type.\n"
|
|
|
289 |
# if not nltk.data.find("tokenizers/punkt"): # TODO: check why it is not working
|
290 |
# pass
|
291 |
|
|
|
292 |
def _compute(
|
293 |
self,
|
294 |
predictions,
|
|
|
296 |
model_type: Optional[str] = None,
|
297 |
tokenize_sentences: bool = True,
|
298 |
multi_references: bool = False,
|
299 |
+
gpu: DEVICE_TYPE = False,
|
300 |
batch_size: int = 32,
|
301 |
+
verbose: bool = False,
|
302 |
) -> List[Scores]:
|
303 |
"""
|
304 |
Compute precision, recall, and F1 scores for given predictions and references.
|
|
|
306 |
:param predictions
|
307 |
:param references
|
308 |
:param model_type: Type of model to use for encoding.
|
309 |
+
Options: [pv1, stsb, use]
|
310 |
+
pv1 - paraphrase-distilroberta-base-v1 (Default)
|
311 |
+
stsb - stsb-roberta-large
|
312 |
+
use - Universal Sentence Encoder
|
313 |
:param tokenize_sentences: Flag to sentence tokenize the document.
|
314 |
:param multi_references: Flag to indicate multiple references.
|
315 |
:param gpu: GPU device to use.
|
316 |
:param batch_size: Batch size for encoding.
|
317 |
+
:param verbose: Flag to indicate verbose output.
|
318 |
|
319 |
:return: List of Scores dataclass with precision, recall, and F1 scores.
|
320 |
"""
|
|
|
323 |
_validate_input_format(tokenize_sentences, multi_references, predictions, references)
|
324 |
|
325 |
# Get GPU
|
326 |
+
device = get_gpu(gpu)
|
327 |
+
if verbose:
|
328 |
+
print(f"Using devices: {device}")
|
329 |
|
330 |
# Get the encoder model
|
331 |
model_name = self._get_model_name(model_type)
|
332 |
+
encoder = get_encoder(model_name, device=device, batch_size=batch_size, verbose=verbose)
|
333 |
|
334 |
# We'll handle the single reference and multi-reference case same way. So change the data format accordingly
|
335 |
if not multi_references:
|
tests.py
CHANGED
@@ -1,17 +1,179 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import statistics
|
2 |
+
import unittest
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
|
8 |
+
from encoder_models import SBertEncoder, get_encoder
|
9 |
+
from utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores
|
10 |
+
|
11 |
+
|
12 |
+
class TestUtils(unittest.TestCase):
|
13 |
+
def test_get_gpu(self):
|
14 |
+
gpu_count = torch.cuda.device_count()
|
15 |
+
gpu_available = torch.cuda.is_available()
|
16 |
+
|
17 |
+
# Test single boolean input
|
18 |
+
self.assertEqual(get_gpu(True), 0 if gpu_available else "cpu")
|
19 |
+
self.assertEqual(get_gpu(False), "cpu")
|
20 |
+
|
21 |
+
# Test single string input
|
22 |
+
self.assertEqual(get_gpu("cpu"), "cpu")
|
23 |
+
self.assertEqual(get_gpu("gpu"), 0 if gpu_available else "cpu")
|
24 |
+
self.assertEqual(get_gpu("cuda"), 0 if gpu_available else "cpu")
|
25 |
+
|
26 |
+
# Test single integer input
|
27 |
+
self.assertEqual(get_gpu(0), 0 if gpu_available else "cpu")
|
28 |
+
self.assertEqual(get_gpu(1), 1 if gpu_available else "cpu")
|
29 |
+
|
30 |
+
# Test list input with unique elements
|
31 |
+
self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])
|
32 |
+
|
33 |
+
# Test list input with duplicate elements
|
34 |
+
self.assertEqual(get_gpu([0, 0, "gpu"]), [0] if gpu_available else ["cpu", "cpu", "cpu"])
|
35 |
+
|
36 |
+
# Test list input with duplicate elements of different types
|
37 |
+
self.assertEqual(get_gpu([True, 0, "gpu"]), [0] if gpu_available else ["cpu", "cpu", "cpu"])
|
38 |
+
|
39 |
+
# Test list input with all integers
|
40 |
+
self.assertEqual(get_gpu(list(range(gpu_count))),
|
41 |
+
list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"])
|
42 |
+
|
43 |
+
with self.assertRaises(ValueError):
|
44 |
+
get_gpu("invalid")
|
45 |
+
|
46 |
+
with self.assertRaises(ValueError):
|
47 |
+
get_gpu(torch.cuda.device_count())
|
48 |
+
|
49 |
+
def test_slice_embeddings(self):
|
50 |
+
embeddings = np.random.rand(10, 5)
|
51 |
+
num_sentences = [3, 2, 5]
|
52 |
+
expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
|
53 |
+
self.assertTrue(
|
54 |
+
all(np.array_equal(a, b) for a, b in zip(slice_embeddings(embeddings, num_sentences),
|
55 |
+
expected_output))
|
56 |
+
)
|
57 |
+
|
58 |
+
num_sentences_nested = [[2, 1], [3, 4]]
|
59 |
+
expected_output_nested = [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
|
60 |
+
self.assertTrue(
|
61 |
+
slice_embeddings(embeddings, num_sentences_nested), expected_output_nested
|
62 |
+
)
|
63 |
+
|
64 |
+
with self.assertRaises(TypeError):
|
65 |
+
slice_embeddings(embeddings, "invalid")
|
66 |
+
|
67 |
+
def test_is_nested_list_of_type(self):
|
68 |
+
# Test case: Depth 0, single element matching element_type
|
69 |
+
self.assertTrue(is_nested_list_of_type("test", str, 0))
|
70 |
+
|
71 |
+
# Test case: Depth 0, single element not matching element_type
|
72 |
+
self.assertFalse(is_nested_list_of_type("test", int, 0))
|
73 |
+
|
74 |
+
# Test case: Depth 1, list of elements matching element_type
|
75 |
+
self.assertTrue(is_nested_list_of_type(["apple", "banana"], str, 1))
|
76 |
+
|
77 |
+
# Test case: Depth 1, list of elements not matching element_type
|
78 |
+
self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 1))
|
79 |
+
|
80 |
+
# Test case: Depth 0 (Wrong), list of elements matching element_type
|
81 |
+
self.assertFalse(is_nested_list_of_type([1, 2, 3], str, 0))
|
82 |
+
|
83 |
+
# Depth 2
|
84 |
+
self.assertTrue(is_nested_list_of_type([[1, 2], [3, 4]], int, 2))
|
85 |
+
self.assertTrue(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2))
|
86 |
+
self.assertFalse(is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2))
|
87 |
+
|
88 |
+
# Depth 3
|
89 |
+
self.assertFalse(is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3))
|
90 |
+
self.assertTrue(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3))
|
91 |
+
|
92 |
+
with self.assertRaises(ValueError):
|
93 |
+
is_nested_list_of_type([1, 2], int, -1)
|
94 |
+
|
95 |
+
def test_flatten_list(self):
|
96 |
+
self.assertEqual(flatten_list([1, [2, 3], [[4], 5]]), [1, 2, 3, 4, 5])
|
97 |
+
self.assertEqual(flatten_list([]), [])
|
98 |
+
self.assertEqual(flatten_list([1, 2, 3]), [1, 2, 3])
|
99 |
+
self.assertEqual(flatten_list([[[[1]]]]), [1])
|
100 |
+
|
101 |
+
def test_compute_f1(self):
|
102 |
+
self.assertAlmostEqual(compute_f1(0.5, 0.5), 0.5)
|
103 |
+
self.assertAlmostEqual(compute_f1(1, 0), 0.0)
|
104 |
+
self.assertAlmostEqual(compute_f1(0, 1), 0.0)
|
105 |
+
self.assertAlmostEqual(compute_f1(1, 1), 1.0)
|
106 |
+
|
107 |
+
def test_scores(self):
|
108 |
+
scores = Scores(precision=0.8, recall=[0.7, 0.9])
|
109 |
+
self.assertAlmostEqual(scores.f1, compute_f1(0.8, statistics.fmean([0.7, 0.9])))
|
110 |
+
|
111 |
+
|
112 |
+
class TestSBertEncoder(unittest.TestCase):
|
113 |
+
def setUp(self, device=None):
|
114 |
+
if device is None:
|
115 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
116 |
+
else:
|
117 |
+
self.device = device
|
118 |
+
self.model_name = "stsb-roberta-large"
|
119 |
+
self.batch_size = 8
|
120 |
+
self.verbose = False
|
121 |
+
self.encoder = SBertEncoder(self.model_name, self.device, self.batch_size, self.verbose)
|
122 |
+
|
123 |
+
def test_initialization(self):
|
124 |
+
self.assertIsInstance(self.encoder.model, SentenceTransformer)
|
125 |
+
self.assertEqual(self.encoder.device, self.device)
|
126 |
+
self.assertEqual(self.encoder.batch_size, self.batch_size)
|
127 |
+
self.assertEqual(self.encoder.verbose, self.verbose)
|
128 |
+
|
129 |
+
def test_encode_single_device(self):
|
130 |
+
sentences = ["This is a test sentence.", "Here is another sentence."]
|
131 |
+
embeddings = self.encoder.encode(sentences)
|
132 |
+
self.assertIsInstance(embeddings, np.ndarray)
|
133 |
+
self.assertEqual(embeddings.shape[0], len(sentences))
|
134 |
+
self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
|
135 |
+
|
136 |
+
def test_encode_multi_device(self):
|
137 |
+
if torch.cuda.device_count() < 2:
|
138 |
+
self.skipTest("Multi-GPU test requires at least 2 GPUs.")
|
139 |
+
else:
|
140 |
+
devices = ["cuda:0", "cuda:1"]
|
141 |
+
self.setUp(devices)
|
142 |
+
sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
|
143 |
+
embeddings = self.encoder.encode(sentences)
|
144 |
+
self.assertIsInstance(embeddings, np.ndarray)
|
145 |
+
self.assertEqual(embeddings.shape[0], 3)
|
146 |
+
self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
|
147 |
+
|
148 |
+
|
149 |
+
class TestGetEncoder(unittest.TestCase):
|
150 |
+
def test_get_sbert_encoder(self):
|
151 |
+
model_name = "stsb-roberta-large"
|
152 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
153 |
+
batch_size = 8
|
154 |
+
verbose = False
|
155 |
+
|
156 |
+
encoder = get_encoder(model_name, device, batch_size, verbose)
|
157 |
+
self.assertIsInstance(encoder, SBertEncoder)
|
158 |
+
self.assertEqual(encoder.device, device)
|
159 |
+
self.assertEqual(encoder.batch_size, batch_size)
|
160 |
+
self.assertEqual(encoder.verbose, verbose)
|
161 |
+
|
162 |
+
def test_get_use_encoder(self):
|
163 |
+
model_name = "use"
|
164 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
165 |
+
batch_size = 8
|
166 |
+
verbose = False
|
167 |
+
|
168 |
+
encoder = get_encoder(model_name, device, batch_size, verbose)
|
169 |
+
self.assertIsInstance(encoder, SBertEncoder) # SBertEncoder is returned for "use" for now
|
170 |
+
# Uncomment below when implementing USE class
|
171 |
+
# self.assertIsInstance(encoder, USE)
|
172 |
+
# self.assertEqual(encoder.model_name, model_name)
|
173 |
+
# self.assertEqual(encoder.device, device)
|
174 |
+
# self.assertEqual(encoder.batch_size, batch_size)
|
175 |
+
# self.assertEqual(encoder.verbose, verbose)
|
176 |
+
|
177 |
+
|
178 |
+
if __name__ == '__main__':
|
179 |
+
unittest.main()
|
type_aliases.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
from numpy.typing import NDArray
|
4 |
+
|
5 |
+
NumSentencesType = Union[List[int], List[List[int]]]
|
6 |
+
EmbeddingSlicesType = Union[List[NDArray], List[List[NDArray]]]
|
7 |
+
PREDICTION_TYPE = Union[List[str], List[List[str]]]
|
8 |
+
REFERENCE_TYPE = Union[List[str], List[List[str]], List[List[List[str]]]]
|
9 |
+
DEVICE_TYPE = Union[bool, str, int, List[Union[str, int]]]
|
10 |
+
ENCODER_DEVICE_TYPE = Union[str, int, List[Union[str, int]]]
|
utils.py
CHANGED
@@ -1,13 +1,81 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
import statistics
|
3 |
import sys
|
|
|
4 |
from typing import List, Union
|
5 |
|
|
|
6 |
from numpy.typing import NDArray
|
7 |
|
|
|
8 |
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
|
@@ -22,10 +90,10 @@ def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> Em
|
|
22 |
result, _ = _slice_embeddings(0, num_sentences)
|
23 |
return result
|
24 |
elif isinstance(num_sentences, list) and all(
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
):
|
30 |
nested_result = []
|
31 |
start_idx = 0
|
@@ -38,11 +106,11 @@ def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> Em
|
|
38 |
raise TypeError(f"Incorrect Type for {num_sentences=}")
|
39 |
|
40 |
|
41 |
-
def
|
42 |
if depth == 0:
|
43 |
-
return isinstance(
|
44 |
elif depth > 0:
|
45 |
-
return isinstance(
|
46 |
else:
|
47 |
raise ValueError("Depth can't be negative")
|
48 |
|
|
|
|
|
1 |
import statistics
|
2 |
import sys
|
3 |
+
from dataclasses import dataclass
|
4 |
from typing import List, Union
|
5 |
|
6 |
+
import torch
|
7 |
from numpy.typing import NDArray
|
8 |
|
9 |
+
from type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType
|
10 |
|
11 |
+
|
12 |
+
def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
|
13 |
+
"""
|
14 |
+
Determine the correct GPU device based on the provided input. In the following, output 0 means CUDA device 0.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
gpu (Union[bool, str, int, List[Union[str, int]]]): Input specifying the GPU device(s):
|
18 |
+
- bool: If True, returns 0 if CUDA is available, otherwise returns "cpu".
|
19 |
+
- str: Can be "cpu", "gpu", or "cuda" (case-insensitive). Returns 0 if CUDA is available
|
20 |
+
and the input is not "cpu", otherwise returns "cpu".
|
21 |
+
- int: Should be a valid GPU index. Returns the index if CUDA is available and valid,
|
22 |
+
otherwise returns "cpu".
|
23 |
+
- List[Union[str, int]]: List containing combinations of the str/int. Processes each
|
24 |
+
element and returns a list of corresponding results.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
Union[str, int, List[Union[str, int]]]: Depending on the input type:
|
28 |
+
- str: Returns "cpu" if no GPU is available or the input is "cpu".
|
29 |
+
- int: Returns the GPU index if valid and CUDA is available.
|
30 |
+
- List[Union[str, int]]: Returns a list of strings and/or integers based on the input list.
|
31 |
+
|
32 |
+
Raises:
|
33 |
+
ValueError: If the input gpu type is not recognized or invalid.
|
34 |
+
ValueError: If a string input is not one of ["cpu", "gpu", "cuda"].
|
35 |
+
ValueError: If an integer input is outside the valid range of GPU indices.
|
36 |
+
|
37 |
+
Notes:
|
38 |
+
- This function checks CUDA availability using torch.cuda.is_available() and counts
|
39 |
+
available GPUs using torch.cuda.device_count().
|
40 |
+
- Case insensitivity is maintained for string inputs ("cpu", "gpu", "cuda").
|
41 |
+
- The function ensures robust error handling for invalid input types or out-of-range indices.
|
42 |
+
"""
|
43 |
+
|
44 |
+
# Ensure gpu index is within the range of total available gpus
|
45 |
+
gpu_available = torch.cuda.is_available()
|
46 |
+
gpu_count = torch.cuda.device_count()
|
47 |
+
correct_strs = ["cpu", "gpu", "cuda"]
|
48 |
+
|
49 |
+
def _get_single_device(gpu_item):
|
50 |
+
if isinstance(gpu_item, bool):
|
51 |
+
return 0 if gpu_item and gpu_available else "cpu"
|
52 |
+
elif isinstance(gpu_item, str):
|
53 |
+
if gpu_item.lower() not in correct_strs:
|
54 |
+
raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
|
55 |
+
return 0 if (gpu_item.lower() != "cpu") and gpu_available else "cpu"
|
56 |
+
elif isinstance(gpu_item, int):
|
57 |
+
if gpu_item >= gpu_count:
|
58 |
+
raise ValueError(
|
59 |
+
f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
|
60 |
+
)
|
61 |
+
return gpu_item if gpu_available else "cpu"
|
62 |
+
else:
|
63 |
+
raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")
|
64 |
+
|
65 |
+
if isinstance(gpu, list):
|
66 |
+
seen_indices = set()
|
67 |
+
result = []
|
68 |
+
for item in gpu:
|
69 |
+
device = _get_single_device(item)
|
70 |
+
if isinstance(device, int):
|
71 |
+
if device not in seen_indices:
|
72 |
+
seen_indices.add(device)
|
73 |
+
result.append(device)
|
74 |
+
else:
|
75 |
+
result.append(device)
|
76 |
+
return result
|
77 |
+
else:
|
78 |
+
return _get_single_device(gpu)
|
79 |
|
80 |
|
81 |
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
|
|
|
90 |
result, _ = _slice_embeddings(0, num_sentences)
|
91 |
return result
|
92 |
elif isinstance(num_sentences, list) and all(
|
93 |
+
isinstance(sublist, list) and all(
|
94 |
+
isinstance(item, int) for item in sublist
|
95 |
+
)
|
96 |
+
for sublist in num_sentences
|
97 |
):
|
98 |
nested_result = []
|
99 |
start_idx = 0
|
|
|
106 |
raise TypeError(f"Incorrect Type for {num_sentences=}")
|
107 |
|
108 |
|
109 |
+
def is_nested_list_of_type(lst_obj, element_type, depth: int) -> bool:
|
110 |
if depth == 0:
|
111 |
+
return isinstance(lst_obj, element_type)
|
112 |
elif depth > 0:
|
113 |
+
return isinstance(lst_obj, list) and all(is_nested_list_of_type(item, element_type, depth - 1) for item in lst_obj)
|
114 |
else:
|
115 |
raise ValueError("Depth can't be negative")
|
116 |
|