from sentence_transformers import CrossEncoder as _CE

import math
from typing import cast
import types

import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3ForCausalLM,
    Gemma3ForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast


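# Reranker checkpoint and per-device batch budget; character lengths are used
# below as a rough proxy for token counts when packing batches.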
MODEL_PATH = "zeroentropy/ze-rerank-small-v0.3.0"
PER_DEVICE_BATCH_SIZE_TOKENS = 15_000


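# Build a single right-padded batch of chat-formatted prompts: each query
# becomes the system message and its document the user message.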
def format_pointwise_datapoints(
    tokenizer: PreTrainedTokenizerFast,
    query_documents: list[tuple[str, str]],
) -> BatchEncoding:
    input_texts: list[str] = []
    for query, document in query_documents:
        system_prompt = f"""
{query}
""".strip()
        user_message = f"""
{document}
""".strip()
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        assert isinstance(input_text, str)
        input_texts.append(input_text)

    batch_inputs = tokenizer(
        input_texts,
        padding=True,
        return_tensors="pt",
    )
    return batch_inputs


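# Load the reranker checkpoint and its tokenizer, placing the model on the
# given device (CPU by default).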
def load_model(
    device: torch.device | None = None,
) -> tuple[
    PreTrainedTokenizerFast,
    LlamaForCausalLM
    | Gemma3ForConditionalGeneration
    | Gemma3ForCausalLM
    | Qwen3ForCausalLM,
]:
    if device is None:
        device = torch.device("cpu")

    config = AutoConfig.from_pretrained(MODEL_PATH)
    assert isinstance(config, PretrainedConfig)

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto",
        quantization_config=None,
        device_map={"": device},
    )
    if config.model_type == "llama":
        model.config.attn_implementation = "flash_attention_2"
    print(f"Model Type: {config.model_type}")
    assert isinstance(
        model,
        LlamaForCausalLM
        | Gemma3ForConditionalGeneration
        | Gemma3ForCausalLM
        | Qwen3ForCausalLM,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        padding_side="right",
    )
    assert isinstance(tokenizer, PreTrainedTokenizerFast)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer, model


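# Replacement for CrossEncoder.predict: scores each (query, document) pair with
# the pointwise reranker and returns one relevance score per pair, in the
# original input order.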
def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
    # Lazily load the model and tokenizer on the first call and cache them on
    # the CrossEncoder instance.
    if not hasattr(self, "inner_model"):
        self.inner_tokenizer, self.inner_model = load_model(torch.device("cuda"))
        self.inner_model.gradient_checkpointing_enable()
        self.inner_model.eval()
        self.inner_yes_token_id = self.inner_tokenizer.encode(
            "Yes", add_special_tokens=False
        )[0]
        print("patched")

    model = self.inner_model
    tokenizer = self.inner_tokenizer

    # Truncate very long inputs (character limits, not token limits).
    query_documents = [
        (query[:2_000], document[:10_000]) for query, document in query_documents
    ]

    # Process pairs in order of decreasing combined length so each batch holds
    # similarly sized inputs; the permutation is inverted before returning.
    permutation = list(range(len(query_documents)))
    permutation.sort(
        key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1])
    )
    query_documents = [query_documents[i] for i in permutation]

    device = torch.device("cuda")

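    # Greedily pack pairs into batches, keeping batch_size * longest_pair_length
    # under the PER_DEVICE_BATCH_SIZE_TOKENS budget.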
    max_length = 0
    batches: list[list[tuple[str, str]]] = []
    for query, document in query_documents:
        if (
            len(batches) == 0
            or (len(batches[-1]) + 1) * max(max_length, len(query) + len(document))
            > PER_DEVICE_BATCH_SIZE_TOKENS
        ):
            batches.append([])
            max_length = 0

        batches[-1].append((query, document))
        max_length = max(max_length, 20 + len(query) + len(document))

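    # Score each batch: run a forward pass and read the "Yes"-token logit at the
    # last non-padding position of every sequence.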
    all_logits: list[float] = []
    for batch in batches:
        batch_inputs = format_pointwise_datapoints(
            tokenizer,
            batch,
        )

        batch_inputs = batch_inputs.to(device)

        # Scoring only, so no gradients are needed; run the forward pass under
        # torch.no_grad() to keep activation memory down.
        with torch.no_grad():
            try:
                outputs = model(**batch_inputs, use_cache=False)
            except torch.OutOfMemoryError:
                # Retry once after releasing cached CUDA memory.
                print(f"GPU OOM! {torch.cuda.memory_reserved()}")
                torch.cuda.empty_cache()
                print(f"GPU After OOM Cache Clear: {torch.cuda.memory_reserved()}")
                outputs = model(**batch_inputs, use_cache=False)

        # Index the logits at the last non-padding position of each sequence
        # (the tokenizer pads on the right).
        logits = cast(torch.Tensor, outputs.logits)
        attention_mask = cast(torch.Tensor, batch_inputs.attention_mask)
        last_positions = attention_mask.sum(dim=1) - 1

        batch_size = logits.shape[0]
        batch_indices = torch.arange(batch_size, device=device)
        last_logits = logits[batch_indices, last_positions]

        # The relevance signal is the "Yes" token's logit, scaled down by 5.
        yes_logits = last_logits[:, self.inner_yes_token_id]
        all_logits.extend([float(logit) / 5.0 for logit in yes_logits])

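    # Map the scaled logits to (0, 1) relevance scores.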
    def sigmoid(x: float) -> float:
        return 1 / (1 + math.exp(-x))

    scores = [sigmoid(logit) for logit in all_logits]

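    # Undo the length-based sort so scores line up with the original input order.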
    scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]

    return scores


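# Monkey-patch sentence_transformers' CrossEncoder so every .predict() call
# routes through the pointwise reranker defined above.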
_CE.predict = predict

from transformers import Qwen3Config

ZEConfig = Qwen3Config