import math
from typing import Optional

import torch
from vllm import LLM, SamplingParams

from utils import prompt_template, truncate


class ERank_vLLM:
    """Pointwise reranker that asks a chat LLM (served by vLLM) to score each
    (query, document) pair and sorts the documents by that score."""

    # Text that precedes the integer score in the model's output; digits are
    # only collected after this marker has been seen. It must be non-empty:
    # ``str.endswith("")`` is always True, which would make every digit in the
    # inspected tail (e.g. numbers from the reasoning text) part of the score.
    # NOTE(review): assumes utils.prompt_template asks the model to answer as
    # "<answer>N</answer>" — TODO confirm against utils.py.
    ANSWER_MARKER = "<answer>"
    # Number of trailing generated tokens searched for the marker and score.
    TAIL_TOKENS = 10
    # Largest accepted score; larger parsed values are treated as invalid.
    MAX_SCORE = 10
    # Score used when no valid integer answer can be parsed.
    INVALID_SCORE = -1

    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_vLLM reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                This can be a Hugging Face model ID or a local path.
        """
        # Shard the model across every visible GPU; clamp to 1 because a
        # tensor-parallel size of 0 is never valid.
        num_gpu = max(1, torch.cuda.device_count())
        self.ranker = LLM(
            model=model_name_or_path,
            tensor_parallel_size=num_gpu,
            gpu_memory_utilization=0.95,
            # Prompts in one rerank call come from one template and one query,
            # so they presumably share a prefix that caching can reuse.
            enable_prefix_caching=True,
        )
        self.tokenizer = self.ranker.get_tokenizer()
        # temperature=0 selects greedy decoding in vLLM; logprobs=20 makes vLLM
        # return per-token logprobs, which _parse_answer uses to weight the
        # parsed score by the probability of its digit tokens.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=4096,
            logprobs=20,
        )

    def rerank(
        self,
        query: str,
        docs: list,
        instruction: str,
        truncate_length: Optional[int] = None,
    ) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents
                a document and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to
                evaluate the documents.
            truncate_length (int, optional): The maximum length to truncate the
                query and document content to. ``None`` (or 0) disables
                truncation. Defaults to None.

        Returns:
            list: A new list of shallow copies of the input dictionaries, each
                with an added "rank_score" key, sorted by "rank_score" in
                descending order. "rank_score" is the parsed integer score
                (0..MAX_SCORE) multiplied by the probability of its digit
                tokens, or INVALID_SCORE times that probability when no valid
                score is found. An empty ``docs`` yields an empty list.
        """
        if not docs:
            return []

        def maybe_truncate(text: str) -> str:
            # A falsy truncate_length (None or 0) means "do not truncate".
            if truncate_length:
                return truncate(self.tokenizer, text, length=truncate_length)
            return text

        # The query is the same for every document, so truncate it only once.
        query_text = maybe_truncate(query)

        # One single-turn user conversation per document.
        messages = [
            [{
                "role": "user",
                "content": prompt_template.format(
                    query=query_text,
                    doc=maybe_truncate(doc["content"]),
                    instruction=instruction,
                ),
            }]
            for doc in docs
        ]

        # Batched generation; the zip below relies on vLLM returning outputs
        # in the same order as the input conversations.
        outputs = self.ranker.chat(messages, self.sampling_params)

        results = []
        for doc, output in zip(docs, outputs):
            answer, prob = self._parse_answer(output.outputs[0].logprobs)
            results.append({**doc, "rank_score": answer * prob})

        results.sort(key=lambda item: item["rank_score"], reverse=True)
        return results

    @classmethod
    def _parse_answer(cls, logprobs) -> tuple:
        """Extract the integer score and its probability from a generation.

        Only the last ``TAIL_TOKENS`` generated positions are inspected. Digit
        tokens that appear after ``ANSWER_MARKER`` are concatenated into the
        answer, and their probabilities (``exp(logprob)``) are multiplied.

        Args:
            logprobs: The per-position logprobs of one completion. Each element
                is a mapping whose values expose ``decoded_token`` and
                ``logprob``.

        Returns:
            tuple: ``(answer, prob)``. ``answer`` is an int in
                ``0..MAX_SCORE``, or ``INVALID_SCORE`` if the collected digits
                are empty or out of range. ``prob`` is the product of the
                collected digit probabilities (1.0 if none were collected).
        """
        text = ""
        digits = ""
        in_answer = False
        prob = 1.0
        for position in logprobs[-cls.TAIL_TOKENS:]:
            # NOTE(review): assumes the first entry of each position's mapping is
            # the generated token (with temperature=0 it should also be the
            # top-1 candidate) — confirm for the vLLM version in use.
            _, detail = next(iter(position.items()))
            token = detail.decoded_token
            if in_answer and token.isdigit():
                digits += token
                prob *= math.exp(detail.logprob)
            else:
                # Non-digit text is accumulated so a marker split across several
                # tokens is still detected.
                text += token
            if text.endswith(cls.ANSWER_MARKER):
                in_answer = True

        # Explicit range check instead of ``assert`` (which is stripped under
        # ``python -O``); ValueError covers empty or non-ASCII-digit strings.
        try:
            answer = int(digits)
        except ValueError:
            answer = cls.INVALID_SCORE
        else:
            if answer > cls.MAX_SCORE:
                answer = cls.INVALID_SCORE
        return answer, prob