from sentence_transformers import CrossEncoder as _CE
import math
from typing import cast
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3ForCausalLM,
    Gemma3ForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
MODEL_PATH = "zeroentropy/ze-rerank-small-v0.3.0"
PER_DEVICE_BATCH_SIZE_TOKENS = 15_000

def format_pointwise_datapoints(
    tokenizer: PreTrainedTokenizerFast,
    query_documents: list[tuple[str, str]],
) -> BatchEncoding:
    """Render (query, document) pairs through the chat template and tokenize them as one padded batch."""
    input_texts: list[str] = []
    for query, document in query_documents:
        system_prompt = f"""
{query}
""".strip()
        user_message = f"""
{document}
""".strip()
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        assert isinstance(input_text, str)
        input_texts.append(input_text)
    batch_inputs = tokenizer(
        input_texts,
        padding=True,
        return_tensors="pt",
    )
    return batch_inputs

def load_model(
    device: torch.device | None = None,
) -> tuple[
    PreTrainedTokenizerFast,
    LlamaForCausalLM
    | Gemma3ForConditionalGeneration
    | Gemma3ForCausalLM
    | Qwen3ForCausalLM,
]:
    """Load the reranker checkpoint and its tokenizer, placing the model on `device` (CPU by default)."""
    if device is None:
        device = torch.device("cpu")
    config = AutoConfig.from_pretrained(MODEL_PATH)
    assert isinstance(config, PretrainedConfig)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto",
        quantization_config=None,
        device_map={"": device},
    )
    if config.model_type == "llama":
        model.config.attn_implementation = "flash_attention_2"
    print(f"Model Type: {config.model_type}")
    assert isinstance(
        model,
        LlamaForCausalLM
        | Gemma3ForConditionalGeneration
        | Gemma3ForCausalLM
        | Qwen3ForCausalLM,
    )
    tokenizer = cast(
        AutoTokenizer,
        AutoTokenizer.from_pretrained(
            MODEL_PATH,
            padding_side="right",
        ),
    )
    assert isinstance(tokenizer, PreTrainedTokenizerFast)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model

def predict(self, query_documents: list[tuple[str, str]]) -> list[float]:
    # Lazily load the model and tokenizer onto the GPU on the first call.
    if not hasattr(self, "inner_model"):
        self.inner_tokenizer, self.inner_model = load_model(torch.device("cuda"))
        self.inner_model.gradient_checkpointing_enable()
        self.inner_model.eval()
        self.inner_yes_token_id = self.inner_tokenizer.encode("Yes", add_special_tokens=False)[0]
        print("patched")
    model = self.inner_model
    tokenizer = self.inner_tokenizer
    # Truncate overly long queries and documents (character counts).
    query_documents = [
        (query[:2_000], document[:10_000]) for query, document in query_documents
    ]
    # Sort by descending total length so similarly sized pairs share a batch.
    permutation = list(range(len(query_documents)))
    permutation.sort(key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1]))
    query_documents = [query_documents[i] for i in permutation]
    device = torch.device("cuda")
    # Greedily pack the (query, document) pairs into batches, bounding each
    # batch's approximate padded size by PER_DEVICE_BATCH_SIZE_TOKENS.
    max_length = 0
    batches: list[list[tuple[str, str]]] = []
    for query, document in query_documents:
        if (
            len(batches) == 0
            or (len(batches[-1]) + 1) * max(max_length, len(query) + len(document))
            > PER_DEVICE_BATCH_SIZE_TOKENS
        ):
            batches.append([])
            max_length = 0
        batches[-1].append((query, document))
        max_length = max(max_length, 20 + len(query) + len(document))
    # Run inference over each batch.
    all_logits: list[float] = []
    for batch in batches:
        batch_inputs = format_pointwise_datapoints(
            tokenizer,
            batch,
        )
        batch_inputs = batch_inputs.to(device)
        try:
            outputs = model(**batch_inputs, use_cache=False)
        except torch.OutOfMemoryError:
            print(f"GPU OOM! {torch.cuda.memory_reserved()}")
            torch.cuda.empty_cache()
            print(f"GPU After OOM Cache Clear: {torch.cuda.memory_reserved()}")
            outputs = model(**batch_inputs, use_cache=False)
        # Extract the "Yes" logit at the last non-padding position of each row.
        logits = cast(torch.Tensor, outputs.logits)
        attention_mask = cast(torch.Tensor, batch_inputs.attention_mask)
        last_positions = attention_mask.sum(dim=1) - 1
        batch_size = logits.shape[0]
        batch_indices = torch.arange(batch_size, device=device)
        last_logits = logits[batch_indices, last_positions]
        yes_logits = last_logits[:, self.inner_yes_token_id]
        all_logits.extend([float(logit) / 5.0 for logit in yes_logits])

    def sigmoid(x: float) -> float:
        return 1 / (1 + math.exp(-x))

    scores = [sigmoid(logit) for logit in all_logits]
    # Restore the caller's original ordering.
    scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
    return scores

# Monkey-patch sentence_transformers' CrossEncoder so that .predict() runs the
# pointwise "Yes"-logit scoring defined above.
_CE.predict = predict

from transformers import Qwen3Config

# Expose the Qwen3 config class under the name this repository uses for the
# reranker's config.
ZEConfig = Qwen3Config
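
# Hypothetical usage sketch (not part of the original module): it assumes a CUDA
# GPU is available and that sentence_transformers can instantiate a CrossEncoder
# for this checkpoint. The patched predict() above loads its own copy of the
# model on first call, so the CrossEncoder instance mainly acts as a carrier
# object for the reranking call.
if __name__ == "__main__":
    reranker = _CE(MODEL_PATH)
    example_scores = reranker.predict(
        [
            ("what is the capital of France?", "Paris is the capital of France."),
            ("what is the capital of France?", "The Nile is the longest river in Africa."),
        ]
    )
    print(example_scores)  # Higher score => more relevant document.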