nan commited on
Commit
f7df96a
·
1 Parent(s): fe4c51b

feat: return a list when the input is a list

Browse files
Files changed (1) hide show
  1. modeling_jina_embeddings_v4.py +6 -2
modeling_jina_embeddings_v4.py CHANGED
@@ -435,6 +435,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
435
  prefix=encode_kwargs.pop("prefix"),
436
  )
437
 
 
 
438
  if isinstance(texts, str):
439
  texts = [texts]
440
 
@@ -449,7 +451,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
449
  **encode_kwargs,
450
  )
451
 
452
- return embeddings if len(texts) > 1 else embeddings[0]
453
 
454
  def _load_images_if_needed(
455
  self, images: List[Union[str, Image.Image]]
@@ -497,6 +499,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
497
  encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
498
  task = self._validate_task(task)
499
 
 
 
500
  # Convert single image to list
501
  if isinstance(images, (str, Image.Image)):
502
  images = [images]
@@ -516,7 +520,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
516
  if max_pixels:
517
  self.processor.image_processor.max_pixels = default_max_pixels
518
 
519
- return embeddings if len(images) > 1 else embeddings[0]
520
 
521
  @classmethod
522
  def from_pretrained(
 
435
  prefix=encode_kwargs.pop("prefix"),
436
  )
437
 
438
+ return_list = isinstance(texts, list)
439
+
440
  if isinstance(texts, str):
441
  texts = [texts]
442
 
 
451
  **encode_kwargs,
452
  )
453
 
454
+ return embeddings if return_list else embeddings[0]
455
 
456
  def _load_images_if_needed(
457
  self, images: List[Union[str, Image.Image]]
 
499
  encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
500
  task = self._validate_task(task)
501
 
502
+ return_list = isinstance(images, list)
503
+
504
  # Convert single image to list
505
  if isinstance(images, (str, Image.Image)):
506
  images = [images]
 
520
  if max_pixels:
521
  self.processor.image_processor.max_pixels = default_max_pixels
522
 
523
+ return embeddings if return_list else embeddings[0]
524
 
525
  @classmethod
526
  def from_pretrained(