feat: remove redundant words from the vector_type values
Browse files
modeling_jina_embeddings_v4.py
CHANGED
|
@@ -31,7 +31,7 @@ class PromptType(str, Enum):
|
|
| 31 |
|
| 32 |
|
| 33 |
PREFIX_DICT = {"query": "Query", "passage": "Passage"}
|
| 34 |
-
VECTOR_TYPES = ["
|
| 35 |
|
| 36 |
|
| 37 |
class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
|
|
@@ -284,8 +284,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 284 |
attention_mask (torch.Tensor): The attention mask tensor.
|
| 285 |
Returns:
|
| 286 |
JinaEmbeddingsV4ModelOutput:
|
| 287 |
-
|
| 288 |
-
|
| 289 |
"""
|
| 290 |
# Forward pass through the VLM
|
| 291 |
hidden_states = self.get_last_hidden_states(
|
|
@@ -320,7 +320,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 320 |
task_label: Union[str, List[str]],
|
| 321 |
processor_fn: Callable,
|
| 322 |
desc: str,
|
| 323 |
-
vector_type: str = "
|
| 324 |
return_numpy: bool = False,
|
| 325 |
batch_size: int = 32,
|
| 326 |
truncate_dim: Optional[int] = None,
|
|
@@ -340,7 +340,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 340 |
device_type=torch.device(self.device).type, dtype=torch.bfloat16
|
| 341 |
):
|
| 342 |
embeddings = self(**batch, task_label=task_label)
|
| 343 |
-
if vector_type == "
|
| 344 |
embeddings = embeddings.single_vec_emb
|
| 345 |
if truncate_dim is not None:
|
| 346 |
embeddings = embeddings[:, :truncate_dim]
|
|
@@ -374,7 +374,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 374 |
else PREFIX_DICT["query"]
|
| 375 |
)
|
| 376 |
|
| 377 |
-
vector_type = vector_type or "
|
| 378 |
if vector_type not in VECTOR_TYPES:
|
| 379 |
raise ValueError(
|
| 380 |
f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
|
|
@@ -425,7 +425,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 425 |
texts: List of text strings to encode
|
| 426 |
max_length: Maximum token length for text processing
|
| 427 |
batch_size: Number of texts to process at once
|
| 428 |
-
vector_type: Type of embedding vector to generate ('
|
| 429 |
return_numpy: Whether to return numpy arrays instead of torch tensors
|
| 430 |
truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
|
| 431 |
prompt_name: Type of text being encoded ('query' or 'passage')
|
|
@@ -488,7 +488,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
|
|
| 488 |
Args:
|
| 489 |
images: List of PIL images, URLs, or local file paths to encode
|
| 490 |
batch_size: Number of images to process at once
|
| 491 |
-
vector_type: Type of embedding vector to generate ('
|
| 492 |
return_numpy: Whether to return numpy arrays instead of torch tensors
|
| 493 |
truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
|
| 494 |
max_pixels: Maximum number of pixels to process per image
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
PREFIX_DICT = {"query": "Query", "passage": "Passage"}
|
| 34 |
+
VECTOR_TYPES = ["single", "multi"]
|
| 35 |
|
| 36 |
|
| 37 |
class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
|
|
|
|
| 284 |
attention_mask (torch.Tensor): The attention mask tensor.
|
| 285 |
Returns:
|
| 286 |
JinaEmbeddingsV4ModelOutput:
|
| 287 |
+
single_vec_emb (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
|
| 288 |
+
multi_vec_emb (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
|
| 289 |
"""
|
| 290 |
# Forward pass through the VLM
|
| 291 |
hidden_states = self.get_last_hidden_states(
|
|
|
|
| 320 |
task_label: Union[str, List[str]],
|
| 321 |
processor_fn: Callable,
|
| 322 |
desc: str,
|
| 323 |
+
vector_type: str = "single",
|
| 324 |
return_numpy: bool = False,
|
| 325 |
batch_size: int = 32,
|
| 326 |
truncate_dim: Optional[int] = None,
|
|
|
|
| 340 |
device_type=torch.device(self.device).type, dtype=torch.bfloat16
|
| 341 |
):
|
| 342 |
embeddings = self(**batch, task_label=task_label)
|
| 343 |
+
if vector_type == "single":
|
| 344 |
embeddings = embeddings.single_vec_emb
|
| 345 |
if truncate_dim is not None:
|
| 346 |
embeddings = embeddings[:, :truncate_dim]
|
|
|
|
| 374 |
else PREFIX_DICT["query"]
|
| 375 |
)
|
| 376 |
|
| 377 |
+
vector_type = vector_type or "single"
|
| 378 |
if vector_type not in VECTOR_TYPES:
|
| 379 |
raise ValueError(
|
| 380 |
f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
|
|
|
|
| 425 |
texts: List of text strings to encode
|
| 426 |
max_length: Maximum token length for text processing
|
| 427 |
batch_size: Number of texts to process at once
|
| 428 |
+
vector_type: Type of embedding vector to generate ('single' or 'multi')
|
| 429 |
return_numpy: Whether to return numpy arrays instead of torch tensors
|
| 430 |
truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
|
| 431 |
prompt_name: Type of text being encoded ('query' or 'passage')
|
|
|
|
| 488 |
Args:
|
| 489 |
images: List of PIL images, URLs, or local file paths to encode
|
| 490 |
batch_size: Number of images to process at once
|
| 491 |
+
vector_type: Type of embedding vector to generate ('single' or 'multi')
|
| 492 |
return_numpy: Whether to return numpy arrays instead of torch tensors
|
| 493 |
truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
|
| 494 |
max_pixels: Maximum number of pixels to process per image
|