refactor-model-loading (#4)
feat: loading through jev4 class and stylistic changes (f7cb47c6b07716483dbf5fd311928026ce7cd27a)
- modeling_jina_embeddings_v4.py +127 -87
modeling_jina_embeddings_v4.py
CHANGED
@@ -1,25 +1,23 @@
-import os
 import math
-import numpy as np
-
+import os
 from dataclasses import dataclass
+from enum import Enum
+from functools import partial
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast
-
+
+import numpy as np
 import torch
+from huggingface_hub import snapshot_download
+from peft import PeftModel
+from peft.utils.hotswap import hotswap_adapter
+from PIL import Image
 from torch import nn
 from torch.utils.data import DataLoader
-
-from functools import partial
-from PIL import Image
 from tqdm import tqdm
-from enum import Enum
-from peft.utils.hotswap import hotswap_adapter
-
 from transformers import BatchFeature
-
-from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
-
-from huggingface_hub import snapshot_download
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.qwen2_5_vl import (Qwen2_5_VLForConditionalGeneration,
+                                            Qwen2_5_VLProcessor)
 
 from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config
 
@@ -28,6 +26,13 @@ class PromptType(str, Enum):
     query = "query"
     passage = "passage"
 
+
+class TaskType(str, Enum):
+    retrieval = "retrieval"
+    code = "code"
+    text_matching = "text-matching"
+
+
 class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
     def __init__(self, *args, **kwargs) -> None:
         Qwen2_5_VLProcessor.__init__(self, *args, **kwargs)
@@ -58,8 +63,12 @@ class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
             images = cast(List[List[Image.Image]], images)
             text_doc = []
             for i in range(len(images)):
-                conversation = [{"role": "user", "content": [{"type": "image"}] * len(images[i])}]
-                template = self.apply_chat_template(conversation, add_generation_prompt=False)
+                conversation = [
+                    {"role": "user", "content": [{"type": "image"}] * len(images[i])}
+                ]
+                template = self.apply_chat_template(
+                    conversation, add_generation_prompt=False
+                )
                 text_doc.append(template[self.assistant_prefix_len :])
 
         else:
@@ -78,7 +87,16 @@ class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
         max_length = max([len(pv) for pv in pixel_values])
 
         pixel_values = [
-            torch.cat([pv, torch.zeros((max_length - len(pv), pv.shape[1]), dtype=pv.dtype, device=pv.device)])
+            torch.cat(
+                [
+                    pv,
+                    torch.zeros(
+                        (max_length - len(pv), pv.shape[1]),
+                        dtype=pv.dtype,
+                        device=pv.device,
+                    ),
+                ]
+            )
             for pv in pixel_values
         ]
 
@@ -93,7 +111,11 @@ class JinaEmbeddingsV4Processor(Qwen2_5_VLProcessor):
         padding: Optional[str] = None,
     ) -> BatchFeature:
 
-        max_length = self.text_max_length if max_length is None else min(max_length, self.text_max_length)
+        max_length = (
+            self.text_max_length
+            if max_length is None
+            else min(max_length, self.text_max_length)
+        )
         padded_texts: List[str] = []
 
         for text in texts:
@@ -127,7 +149,7 @@ class JinaEmbeddingsV4ModelOutput:
     multi_vec_emb: Optional[torch.Tensor] = None
 
 
-class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
+class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
     config_class = JinaEmbeddingsV4Config
     main_input_name: ClassVar[str] = "doc_input_ids"
 
@@ -135,7 +157,9 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
         Qwen2_5_VLForConditionalGeneration.__init__(self, config)
         self._init_projection_layers(config)
         self.post_init()
-        self.processor = JinaEmbeddingsV4Processor.from_pretrained(self.name_or_path, trust_remote_code=True)
+        self.processor = JinaEmbeddingsV4Processor.from_pretrained(
+            self.name_or_path, trust_remote_code=True
+        )
         self.single_vector_projector_dim = config.single_vector_projector_dim
         self.multi_vector_projector_dim = config.multi_vector_projector_dim
 
@@ -147,7 +171,9 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
    ) -> torch.Tensor:
        if "pixel_values" in kwargs:
            offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
-            kwargs["pixel_values"] = torch.cat([pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0)
+            kwargs["pixel_values"] = torch.cat(
+                [pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0
+            )
 
        position_ids, rope_deltas = super().get_rope_index(  # type: ignore
            input_ids=input_ids,
@@ -155,7 +181,7 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
            attention_mask=attention_mask,
        )
 
-        kwargs[
+        kwargs["output_hidden_states"] = True
 
        outputs = super().forward(
            input_ids,
@@ -199,14 +225,22 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
        Project the hidden states to single-vector embeddings.
        """
        if self._input_has_image(input_ids[0]):  # got document image
-            img_start_pos = torch.where(input_ids[0] == self.config.vision_start_token_id)[0][0]
-            img_end_pos = torch.where(input_ids[0] == self.config.vision_end_token_id)[0][0]
-            pooled_output = hidden_states[0][img_start_pos : img_end_pos + 1].mean(dim=0).unsqueeze(0)
+            img_start_pos = torch.where(
+                input_ids[0] == self.config.vision_start_token_id
+            )[0][0]
+            img_end_pos = torch.where(input_ids[0] == self.config.vision_end_token_id)[
+                0
+            ][0]
+            pooled_output = (
+                hidden_states[0][img_start_pos : img_end_pos + 1]
+                .mean(dim=0)
+                .unsqueeze(0)
+            )
 
        else:  # got query text
-            pooled_output = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1) / torch.sum(
-                attention_mask, dim=1
-            )
+            pooled_output = torch.sum(
+                hidden_states * attention_mask.unsqueeze(-1), dim=1
+            ) / torch.sum(attention_mask, dim=1, keepdim=True)
        single_vec_emb = self.single_vector_projector(pooled_output)
        return torch.nn.functional.normalize(single_vec_emb, dim=-1)
 
@@ -248,15 +282,21 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
        )  # (batch_size, seq_length, hidden_size)
 
        # Compute the embeddings
-        single_vec_emb = self.project_to_single_vector_embeddings(hidden_states, attention_mask, input_ids=input_ids)
-        multi_vec_emb = self.project_to_multi_vector_embeddings(hidden_states, attention_mask)
+        single_vec_emb = self.project_to_single_vector_embeddings(
+            hidden_states, attention_mask, input_ids=input_ids
+        )
+        multi_vec_emb = self.project_to_multi_vector_embeddings(
+            hidden_states, attention_mask
+        )
 
        return JinaEmbeddingsV4ModelOutput(
-            vlm_last_hidden_states=hidden_states if output_vlm_last_hidden_states else None,
+            vlm_last_hidden_states=(
+                hidden_states if output_vlm_last_hidden_states else None
+            ),
            single_vec_emb=single_vec_emb,
            multi_vec_emb=multi_vec_emb,
        )
-
+
    def _process_batches(
        self,
        data: List[Union[str, Image.Image]],
@@ -284,7 +324,11 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
                    embeddings = embeddings.single_vec_emb
                else:
                    embeddings = embeddings.multi_vec_emb
-                results.append(embeddings.cpu() if return_numpy else list(torch.unbind(embeddings)))
+                results.append(
+                    embeddings.cpu()
+                    if return_numpy
+                    else list(torch.unbind(embeddings))
+                )
        if return_numpy:
            return np.concatenate([result.numpy() for result in results], axis=0)
        return [item for sublist in results for item in sublist]
@@ -298,7 +342,9 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
        desc: Optional[str] = None,
        **kwargs,
    ) -> List[torch.Tensor]:
-        processor_fn = partial(self.processor.process_texts, max_length=max_length, prefix="Query")
+        processor_fn = partial(
+            self.processor.process_texts, max_length=max_length, prefix="Query"
+        )
        return self._process_batches(
            data=queries,
            processor_fn=processor_fn,
@@ -325,17 +371,6 @@ class QwenVL25Embeddings(Qwen2_5_VLForConditionalGeneration):
            **kwargs,
        )
 
-
-
-class JinaEmbeddingsV4Model:
-    """
-    Wrapper class for QwenVL25Embeddings that handles the loading of models and adapters.
-    """
-
-    def __init__(self, model, adapter_dir):
-        self.model = model
-        self.adapter_dir = adapter_dir
-
    @classmethod
    def from_pretrained(
        cls,
@@ -345,48 +380,53 @@ class JinaEmbeddingsV4Model:
    ):
        if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = "auto"
-
-        task = kwargs.pop(
+
+        task = kwargs.pop("task", TaskType.retrieval)
+
+        # Get the base model first
+        base_model = super().from_pretrained(
+            pretrained_model_name_or_path, *args, **kwargs
+        )
+
+        # Configure adapter directory
+        if os.path.isdir(base_model.name_or_path):
+            adapter_dir = os.path.join(base_model.name_or_path, "adapters")
        else:
            adapter_cache_path = snapshot_download(
-                repo_id=
-                allow_patterns=['adapters/*']
+                repo_id=base_model.name_or_path, allow_patterns=["adapters/*"]
            )
-            adapter_dir = os.path.join(adapter_cache_path,
+            adapter_dir = os.path.join(adapter_cache_path, "adapters")
+
+        # Store adapter directory for later use with set_task
+        base_model.adapter_dir = adapter_dir
+
+        # Create the PEFT model with the requested task adapter
+        peft_model = PeftModel.from_pretrained(
+            base_model, os.path.join(adapter_dir, task)
+        )
+
+        # Add set_task method to the PEFT model instance
+        def set_task_method(self, task_name: Union[str, TaskType]):
+            """
+            Set the task adapter for the model.
+
+            Args:
+                task_name (Union[str, TaskType]): The task name. Must be one of TaskType values or
+                    one of ['retrieval', 'text-matching', 'code']
+            """
+            if isinstance(task_name, str):
+                try:
+                    task_name = TaskType(task_name)
+                except ValueError:
+                    valid_tasks = [t.value for t in TaskType]
+                    raise ValueError(
+                        f"Invalid task: {task_name}. Must be one of {valid_tasks}"
+                    )
+
+            adapter_path = os.path.join(self.adapter_dir, task_name.value)
+            hotswap_adapter(self, adapter_path, adapter_name="default")
+
+        # Bind the method to the instance
+        peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))
+
+        return peft_model
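For context, below is a minimal usage sketch of the loading flow introduced by this commit. It is not part of the diff: the repo id is a placeholder and the local module import is an assumption; only JinaEmbeddingsV4Model.from_pretrained, the task kwarg, TaskType, and set_task come from the code above.

# Minimal sketch of the refactored loading path (assumptions: the file is importable
# locally as modeling_jina_embeddings_v4, and "org/repo-with-adapters" stands in for
# the real Hub repo id).
from modeling_jina_embeddings_v4 import JinaEmbeddingsV4Model, TaskType

# from_pretrained pops `task`, loads the Qwen2.5-VL backbone, resolves the
# adapters/ directory (local path or snapshot_download), and returns a PeftModel
# wrapping the base model with the requested task adapter.
model = JinaEmbeddingsV4Model.from_pretrained(
    "org/repo-with-adapters",   # placeholder repo id
    task=TaskType.retrieval,    # default when omitted
)

# set_task is bound onto the returned PeftModel instance; it validates the name
# against TaskType and hot-swaps the adapter weights in place via hotswap_adapter.
model.set_task("code")
model.set_task(TaskType.text_matching)

Compared to the old wrapper, JinaEmbeddingsV4Model is now itself the Qwen2_5_VLForConditionalGeneration subclass, and adapter handling is attached directly to from_pretrained, so callers no longer go through QwenVL25Embeddings.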