nan committed on
Commit
7f10796
·
1 Parent(s): 455d3b0

feat: avoid the redundant words in the variables

Browse files
Files changed (1) hide show
  1. modeling_jina_embeddings_v4.py +8 -8
modeling_jina_embeddings_v4.py CHANGED
@@ -31,7 +31,7 @@ class PromptType(str, Enum):
31
 
32
 
33
  PREFIX_DICT = {"query": "Query", "passage": "Passage"}
34
- VECTOR_TYPES = ["single_vector", "multi_vector"]
35
 
36
 
37
  class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
@@ -284,8 +284,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
284
  attention_mask (torch.Tensor): The attention mask tensor.
285
  Returns:
286
  JinaEmbeddingsV4ModelOutput:
287
- single_vector (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
288
- multi_vector (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
289
  """
290
  # Forward pass through the VLM
291
  hidden_states = self.get_last_hidden_states(
@@ -320,7 +320,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
320
  task_label: Union[str, List[str]],
321
  processor_fn: Callable,
322
  desc: str,
323
- vector_type: str = "single_vector",
324
  return_numpy: bool = False,
325
  batch_size: int = 32,
326
  truncate_dim: Optional[int] = None,
@@ -340,7 +340,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
340
  device_type=torch.device(self.device).type, dtype=torch.bfloat16
341
  ):
342
  embeddings = self(**batch, task_label=task_label)
343
- if vector_type == "single_vector":
344
  embeddings = embeddings.single_vec_emb
345
  if truncate_dim is not None:
346
  embeddings = embeddings[:, :truncate_dim]
@@ -374,7 +374,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
374
  else PREFIX_DICT["query"]
375
  )
376
 
377
- vector_type = vector_type or "single_vector"
378
  if vector_type not in VECTOR_TYPES:
379
  raise ValueError(
380
  f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
@@ -425,7 +425,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
425
  texts: List of text strings to encode
426
  max_length: Maximum token length for text processing
427
  batch_size: Number of texts to process at once
428
- vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
429
  return_numpy: Whether to return numpy arrays instead of torch tensors
430
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
431
  prompt_name: Type of text being encoded ('query' or 'passage')
@@ -488,7 +488,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
488
  Args:
489
  images: List of PIL images, URLs, or local file paths to encode
490
  batch_size: Number of images to process at once
491
- vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
492
  return_numpy: Whether to return numpy arrays instead of torch tensors
493
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
494
  max_pixels: Maximum number of pixels to process per image
 
31
 
32
 
33
  PREFIX_DICT = {"query": "Query", "passage": "Passage"}
34
+ VECTOR_TYPES = ["single", "multi"]
35
 
36
 
37
  class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
 
284
  attention_mask (torch.Tensor): The attention mask tensor.
285
  Returns:
286
  JinaEmbeddingsV4ModelOutput:
287
+ single (torch.Tensor): Single-vector embeddings of shape (batch_size, dim).
288
+ multi (torch.Tensor): Multi-vector embeddings of shape (batch_size, num_tokens, dim).
289
  """
290
  # Forward pass through the VLM
291
  hidden_states = self.get_last_hidden_states(
 
320
  task_label: Union[str, List[str]],
321
  processor_fn: Callable,
322
  desc: str,
323
+ vector_type: str = "single",
324
  return_numpy: bool = False,
325
  batch_size: int = 32,
326
  truncate_dim: Optional[int] = None,
 
340
  device_type=torch.device(self.device).type, dtype=torch.bfloat16
341
  ):
342
  embeddings = self(**batch, task_label=task_label)
343
+ if vector_type == "single":
344
  embeddings = embeddings.single_vec_emb
345
  if truncate_dim is not None:
346
  embeddings = embeddings[:, :truncate_dim]
 
374
  else PREFIX_DICT["query"]
375
  )
376
 
377
+ vector_type = vector_type or "single"
378
  if vector_type not in VECTOR_TYPES:
379
  raise ValueError(
380
  f"Invalid vector_type: {vector_type}. Must be one of {VECTOR_TYPES}."
 
425
  texts: List of text strings to encode
426
  max_length: Maximum token length for text processing
427
  batch_size: Number of texts to process at once
428
+ vector_type: Type of embedding vector to generate ('single' or 'multi')
429
  return_numpy: Whether to return numpy arrays instead of torch tensors
430
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
431
  prompt_name: Type of text being encoded ('query' or 'passage')
 
488
  Args:
489
  images: List of PIL images, URLs, or local file paths to encode
490
  batch_size: Number of images to process at once
491
+ vector_type: Type of embedding vector to generate ('single' or 'multi')
492
  return_numpy: Whether to return numpy arrays instead of torch tensors
493
  truncate_dim: Dimension to truncate embeddings to (128, 256, 512, or 1024)
494
  max_pixels: Maximum number of pixels to process per image