from typing import Any, Dict, List, Literal, Optional, Union

import torch
from PIL import Image
from torch import nn
from transformers import AutoConfig, AutoModel, AutoProcessor


class Transformer(nn.Module):
    """Custom sentence-transformers module that wraps jina-embeddings-v4 and handles mixed text/image batches."""

    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str = 'jinaai/jina-embeddings-v4',
        max_seq_length: Optional[int] = None,
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        cache_dir: Optional[str] = None,
        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
        **kwargs,
    ) -> None:
        super(Transformer, self).__init__()
        if backend != 'torch':
            raise ValueError(
                f"Backend '{backend}' is not supported, please use 'torch' instead"
            )
        config_kwargs = config_args or {}
        model_kwargs = model_args or {}
        tokenizer_kwargs = tokenizer_args or {}

        self.config = AutoConfig.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **config_kwargs
        )
        # Pop from the normalized kwargs dict so this also works when model_args is None.
        self.default_task = model_kwargs.pop('default_task', None)
        if self.default_task and self.default_task not in self.config.task_names:
            raise ValueError(
                f'Invalid task: {self.default_task}. Must be one of {self.config.task_names}.'
            )

        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_kwargs
        )
        self.max_seq_length = max_seq_length or 8192

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: Union[str, bool] = True
    ) -> Dict[str, torch.Tensor]:
        """Split a mixed batch into text and image inputs and preprocess each group separately."""
        encoding = {}
        text_indices = []
        image_indices = []

        for i, text in enumerate(texts):
            if isinstance(text, str):
                text_indices.append(i)
            elif isinstance(text, Image.Image):
                image_indices.append(i)
            else:
                raise ValueError(f'Invalid input type: {type(text)}')

        if text_indices:
            _texts = [texts[i] for i in text_indices]
            text_features = self.processor.process_texts(_texts, max_length=self.max_seq_length)
            for key, value in text_features.items():
                encoding[f'text_{key}'] = value
            encoding['text_indices'] = text_indices

        if image_indices:
            _images = [texts[i] for i in image_indices]
            img_features = self.processor.process_images(_images)
            for key, value in img_features.items():
                encoding[f'image_{key}'] = value
            encoding['image_indices'] = image_indices

        return encoding

    def forward(
        self, features: Dict[str, torch.Tensor], task: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        self.model.eval()

        if task is None:
            if self.default_task is None:
                raise ValueError(
                    "Task must be specified before encoding data. You can set it either during "
                    "loading the model (e.g., model_kwargs={'default_task': 'retrieval'}) or "
                    "pass it as an argument to the encode method (e.g., model.encode(texts, task='retrieval'))."
                )
            task = self.default_task
        elif task not in self.config.task_names:
            raise ValueError(f'Invalid task: {task}. Must be one of {self.config.task_names}.')

        device = self.model.device.type
        all_embeddings = []

        with torch.no_grad():
            if any(k.startswith('text_') for k in features.keys()):
                text_batch = {
                    k[len('text_'):]: v.to(device)
                    for k, v in features.items()
                    if k.startswith('text_') and k != 'text_indices'
                }
                text_indices = features.get('text_indices', [])

                with torch.autocast(device_type=device):
                    text_embeddings = self.model(**text_batch, task_label=task).single_vec_emb
                    if self.config.truncate_dim:
                        text_embeddings = text_embeddings[:, :self.config.truncate_dim]

                for i, embedding in enumerate(text_embeddings):
                    all_embeddings.append((text_indices[i], embedding))

            if any(k.startswith('image_') for k in features.keys()):
                image_batch = {
                    k[len('image_'):]: v.to(device)
                    for k, v in features.items()
                    if k.startswith('image_') and k != 'image_indices'
                }
                image_indices = features.get('image_indices', [])

                with torch.autocast(device_type=device):
                    img_embeddings = self.model(**image_batch, task_label=task).single_vec_emb
                    if self.config.truncate_dim:
                        img_embeddings = img_embeddings[:, :self.config.truncate_dim]

                for i, embedding in enumerate(img_embeddings):
                    all_embeddings.append((image_indices[i], embedding))

        if not all_embeddings:
            raise RuntimeError('No embeddings were generated')

        # Restore the original input order before stacking.
        all_embeddings.sort(key=lambda x: x[0])
        combined_embeddings = torch.stack([emb for _, emb in all_embeddings])
        features['sentence_embedding'] = combined_embeddings

        return features
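

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes this module is exposed as the
# model's custom sentence-transformers code on the Hugging Face Hub and that
# the `sentence-transformers` package is installed. The 'default_task' /
# 'retrieval' values follow the error message in `forward` above; adjust them
# to the task names your model config actually exposes.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from sentence_transformers import SentenceTransformer

    # trust_remote_code=True lets sentence-transformers load this custom
    # Transformer class; 'default_task' is consumed in __init__ above.
    model = SentenceTransformer(
        'jinaai/jina-embeddings-v4',
        trust_remote_code=True,
        model_kwargs={'default_task': 'retrieval'},
    )

    # Plain strings are routed through processor.process_texts; PIL images in
    # the same list would be routed through processor.process_images.
    embeddings = model.encode(['A short query', 'A longer passage to embed'])
    print(embeddings.shape)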