feat: use enum for the vector type
modeling_jina_embeddings_v4.py  (+23 -14)
@@ -30,8 +30,12 @@ class PromptType(str, Enum):
     passage = "passage"
 
 
+class VectorType(str, Enum):
+    single = "single"
+    multi = "multi"
+
+
 PREFIX_DICT = {"query": "Query", "passage": "Passage"}
-VECTOR_TYPES = ["single", "multi"]
 
 
 class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
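For orientation, a minimal standalone sketch of how a `str`-mixin enum behaves (the class restates the diff above; the assertions are illustrative and not part of the change):

```python
from enum import Enum

class VectorType(str, Enum):
    single = "single"
    multi = "multi"

# str mixin: members compare equal to the raw strings callers already pass,
# so vector_type="single" keeps working unchanged.
assert VectorType.single == "single"
assert "multi" == VectorType.multi

# The enum constructor doubles as a validator for user input.
assert VectorType("single") is VectorType.single
assert [v.value for v in VectorType] == ["single", "multi"]
```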
@@ -320,7 +324,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
-        vector_type: str = "single",
+        vector_type: Union[str, VectorType] = VectorType.single,
         return_numpy: bool = False,
         batch_size: int = 32,
         truncate_dim: Optional[int] = None,
@@ -340,7 +344,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
                 device_type=torch.device(self.device).type, dtype=torch.bfloat16
             ):
                 embeddings = self(**batch, task_label=task_label)
-                if vector_type == "single":
+                vector_type_str = vector_type.value if isinstance(vector_type, VectorType) else vector_type
+                if vector_type_str == VectorType.single.value:
                     embeddings = embeddings.single_vec_emb
                     if truncate_dim is not None:
                         embeddings = embeddings[:, :truncate_dim]
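The new normalization line accepts either form. A standalone sketch (the helper name is hypothetical; the diff inlines this expression):

```python
def _to_vector_type_str(vector_type):
    # Inline expression from _process_batches, factored out for illustration.
    return vector_type.value if isinstance(vector_type, VectorType) else vector_type

assert _to_vector_type_str(VectorType.multi) == "multi"
assert _to_vector_type_str("single") == "single"
```

Note that because `VectorType` subclasses `str`, comparing `vector_type == VectorType.single` would already hold for both the member and the raw string; the explicit `.value` normalization mainly makes the intent visible.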
@@ -357,7 +362,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def _validate_encoding_params(
         self,
-        vector_type: Optional[str] = None,
+        vector_type: Optional[Union[str, VectorType]] = None,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
     ) -> Dict[str, Any]:
@@ -374,13 +379,17 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             else PREFIX_DICT["query"]
         )
 
-        vector_type = vector_type or "single"
-        if vector_type not in VECTOR_TYPES:
-            raise ValueError(
-                f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
-            )
+        vector_type = vector_type or VectorType.single
+        if isinstance(vector_type, VectorType):
+            encode_kwargs["vector_type"] = vector_type.value
         else:
-            encode_kwargs["vector_type"] = vector_type
+            try:
+                vector_type_enum = VectorType(vector_type)
+                encode_kwargs["vector_type"] = vector_type_enum.value
+            except ValueError:
+                raise ValueError(
+                    f"Invalid vector_type: {vector_type}. Must be one of {[v.value for v in VectorType]}."
+                )
 
         truncate_dim = truncate_dim or self.config.truncate_dim
         if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
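Round-tripping through the enum constructor gives validation for free. A standalone sketch of the three cases this block handles (`resolve` is a hypothetical stand-in for the method; `encode_kwargs` stands in for its local dict):

```python
def resolve(vector_type=None):
    encode_kwargs = {}
    vector_type = vector_type or VectorType.single
    if isinstance(vector_type, VectorType):
        encode_kwargs["vector_type"] = vector_type.value
    else:
        try:
            encode_kwargs["vector_type"] = VectorType(vector_type).value
        except ValueError:
            raise ValueError(
                f"Invalid vector_type: {vector_type}. Must be one of {[v.value for v in VectorType]}."
            )
    return encode_kwargs

assert resolve() == {"vector_type": "single"}                 # default
assert resolve("multi") == {"vector_type": "multi"}           # legacy string
assert resolve(VectorType.multi) == {"vector_type": "multi"}  # enum member
# resolve("dense") raises:
# ValueError: Invalid vector_type: dense. Must be one of ['single', 'multi'].
```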
@@ -413,7 +422,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
-        vector_type: Optional[str] = None,
+        vector_type: Optional[Union[str, VectorType]] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         prompt_name: Optional[str] = None,
@@ -425,7 +434,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             texts: List of text strings to encode
             max_length: Maximum token length for text processing
             batch_size: Number of texts to process at once
-            vector_type: Type of embedding vector to generate ('single' or 'multi')
+            vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             prompt_name: Type of text being encoded ('query' or 'passage')
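Caller-side, both spellings remain valid. A hedged usage sketch (the model id and task name are illustrative; `encode_text` is assumed to be the public method whose signature changes above):

```python
from transformers import AutoModel

# Loading via trust_remote_code pulls in modeling_jina_embeddings_v4.py.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

embeddings = model.encode_text(
    texts=["What does Matryoshka truncation do?"],
    task="retrieval",
    vector_type="single",   # a VectorType member is accepted equally
    truncate_dim=512,
    prompt_name="query",
)
```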
@@ -477,7 +486,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         images: List[Union[str, Image.Image]],
         task: Optional[str] = None,
         batch_size: int = 8,
-        vector_type: Optional[str] = None,
+        vector_type: Optional[Union[str, VectorType]] = None,
         return_numpy: bool = False,
         truncate_dim: Optional[int] = None,
         max_pixels: Optional[int] = None,
@@ -488,7 +497,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Args:
             images: List of PIL images, URLs, or local file paths to encode
             batch_size: Number of images to process at once
-            vector_type: Type of embedding vector to generate ('single' or 'multi')
+            vector_type: Type of embedding vector to generate (VectorType.single or VectorType.multi)
             return_numpy: Whether to return numpy arrays instead of torch tensors
             truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
             max_pixels: Maximum number of pixels to process per image
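The image path takes the same values; continuing the sketch above (the URL is a placeholder, and `encode_image` is assumed to match the signature in the diff):

```python
multi_vecs = model.encode_image(
    images=["https://example.com/page.png"],  # placeholder
    task="retrieval",
    vector_type="multi",   # validated via VectorType("multi") internally
    return_numpy=True,
)
```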