diff --git "a/llmlingua/prompt_compressor.py" "b/llmlingua/prompt_compressor.py"
new file mode 100644
--- /dev/null
+++ "b/llmlingua/prompt_compressor.py"
@@ -0,0 +1,2412 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import bisect
+import re
+from collections import defaultdict
+from typing import List
+
+import numpy as np
+import torch
+
+import nltk
+import tiktoken
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoModelForTokenClassification,
+ AutoTokenizer,
+)
+import torch.nn.functional as F
+import string
+import copy
+from torch.utils.data import DataLoader
+
+from .utils import (
+    TokenClfDataset,
+    seed_everything,
+    is_begin_of_new_word,
+    replace_added_token,
+    get_pure_token,
+)
+
+
+class PromptCompressor:
+ """
+ PromptCompressor is designed for compressing prompts based on a given language model.
+
+ This class initializes with the language model and its configuration, preparing it for prompt compression tasks.
+ The PromptCompressor class is versatile and can be adapted for various models and specific requirements in prompt processing.
+    Users can specify different model names and configurations as needed for their particular use case. The architecture is
+ based on the paper "LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models". Jiang, Huiqiang, Qianhui Wu,
+ Chin-Yew Lin, Yuqing Yang, and Lili Qiu. "Llmlingua: Compressing prompts for accelerated inference of large language models."
+ arXiv preprint arXiv:2310.05736 (2023).
+
+ Args:
+ model_name (str, optional): The name of the language model to be loaded. Default is "NousResearch/Llama-2-7b-hf".
+ device_map (str, optional): The device to load the model onto, e.g., "cuda" for GPU. Default is "cuda".
+ model_config (dict, optional): A dictionary containing the configuration parameters for the model. Default is an empty dictionary.
+ open_api_config (dict, optional): A dictionary containing configuration for openai APIs that may be used in conjunction with the model. Default is an empty dictionary.
+ use_llmlingua2 (bool, optional): Whether to use llmlingua-2 compressor based on the paper
+ "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression".
+ Zhuoshi Pan, Qianhui Wu, Huiqiang Jiang, Menglin Xia, Xufang Luo, Jue Zhang, Qingwei Lin, Victor Ruhle, Yuqing Yang, Chin-Yew Lin, H. Vicky Zhao, Lili Qiu, Dongmei Zhang.
+            "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression". arXiv preprint arXiv:2403.12968 (2024).
+ Default is True.
+ llmlingua2_config (dict, optional): A dictionary containing the configuration parameters for llmlingua-2. Default is
+ {
+ "max_batch_size": 50,
+                "max_force_token": 100, # maximum number of tokens that will be forcibly preserved
+ }
+ Example:
+ >>> compress_method = PromptCompressor(model_name="xxx/llmlingua-2-xlm-roberta-large-meetingbank", use_llmlingua2=True, )
+ >>> context = ["This is the first context sentence.", "Here is another context sentence."]
+ >>> result = compress_method.compress_prompt(context, use_context_level_filter=True, target_token=5)
+ >>> print(result["compressed_prompt"])
+ # This will print the compressed version of the context.
+
+ Note:
+ The `PromptCompressor` class requires the Hugging Face Transformers library and an appropriate environment to load and run the models.
+ """
+
+ def __init__(
+ self,
+ model_name: str = "NousResearch/Llama-2-7b-hf",
+ device_map: str = "cuda",
+ model_config: dict = {},
+ open_api_config: dict = {},
+ use_llmlingua2: bool = True,
+ llmlingua2_config: dict = {},
+ ):
+ self.model_name = model_name
+ self.use_llmlingua2 = use_llmlingua2
+ self.retrieval_model = None
+ self.retrieval_model_name = None
+ self.open_api_config = open_api_config
+ self.cache_bos_num = 10
+ self.prefix_bos_num = 100
+ self.oai_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+
+ self.load_model(model_name, device_map, model_config)
+ if use_llmlingua2:
+ self.init_llmlingua2(**llmlingua2_config)
+
+ def init_llmlingua2(
+ self,
+ max_batch_size: int = 50,
+ max_force_token: int = 100,
+ ):
+
+ seed_everything(42)
+ self.max_batch_size = max_batch_size
+ self.max_seq_len = 512
+ self.max_force_token = max_force_token
+ self.special_tokens = set(self.tokenizer.special_tokens_map.values())
+
+ self.added_tokens = [f"[NEW{i}]" for i in range(max_force_token)]
+ self.tokenizer.add_special_tokens(
+ {"additional_special_tokens": self.added_tokens}
+ )
+ self.model.resize_token_embeddings(len(self.tokenizer))
+
+ def load_model(
+ self, model_name: str, device_map: str = "cuda", model_config: dict = {}
+ ):
+ trust_remote_code = model_config.get("trust_remote_code", True)
+ if "trust_remote_code" not in model_config:
+ model_config["trust_remote_code"] = trust_remote_code
+ config = AutoConfig.from_pretrained(model_name, **model_config)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, **model_config)
+ if model_config.get("pad_to_left", True):
+ tokenizer.padding_side = "left"
+ tokenizer.pad_token_id = (
+ config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
+ )
+ MODEL_CLASS = (
+ AutoModelForTokenClassification
+ if any("ForTokenClassification" in ar for ar in config.architectures)
+ else AutoModelForCausalLM
+ )
+ self.device = (
+ device_map
+ if any(key in device_map for key in ["cuda", "cpu", "mps"])
+ else "cuda"
+ )
+ if "cuda" in device_map or "cpu" in device_map:
+ model = MODEL_CLASS.from_pretrained(
+ model_name,
+ torch_dtype=model_config.get(
+ "torch_dtype", "auto" if device_map == "cuda" else torch.float32
+ ),
+ device_map=device_map,
+ config=config,
+ ignore_mismatched_sizes=True,
+ **model_config,
+ )
+ else:
+ model = MODEL_CLASS.from_pretrained(
+ model_name,
+ device_map=device_map,
+ torch_dtype=model_config.get("torch_dtype", "auto"),
+ pad_token_id=tokenizer.pad_token_id,
+ **model_config,
+ )
+ self.tokenizer = tokenizer
+ self.model = model
+ self.context_idxs = []
+ self.max_position_embeddings = config.max_position_embeddings
+
+ def get_ppl(
+ self,
+ text: str,
+ granularity: str = "sentence",
+ input_ids=None,
+ attention_mask=None,
+ past_key_values=None,
+ return_kv=False,
+ end=None,
+ condition_mode: str = "none",
+ condition_pos_id: int = 0,
+ ):
+ if input_ids is None:
+ tokenized_text = self.tokenizer(text, return_tensors="pt")
+ input_ids = tokenized_text["input_ids"].to(self.device)
+ attention_mask = tokenized_text["attention_mask"].to(self.device)
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+ else:
+ past_length = 0
+ if end is None:
+ end = input_ids.shape[1]
+ end = min(end, past_length + self.max_position_embeddings)
+ with torch.no_grad():
+ response = self.model(
+ input_ids[:, past_length:end],
+ attention_mask=attention_mask[:, :end],
+ past_key_values=past_key_values,
+ use_cache=True,
+ )
+ past_key_values = response.past_key_values
+
+ shift_logits = response.logits[..., :-1, :].contiguous()
+ shift_labels = input_ids[..., past_length + 1 : end].contiguous()
+ # Flatten the tokens
+ active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
+ active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
+ active_labels = shift_labels.view(-1)[active]
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+ loss = loss_fct(active_logits, active_labels)
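+        # When conditioning is requested, keep only the per-token losses on one side of
+        # `condition_pos_id`, so the returned perplexity is measured on the span before /
+        # after the conditioning position.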
+ if condition_mode == "before":
+ loss = loss[:condition_pos_id]
+ elif condition_mode == "after":
+ loss = loss[condition_pos_id:]
+ res = loss.mean() if granularity == "sentence" else loss
+ return (res, past_key_values) if return_kv else res
+
+ def __call__(self, *args, **kwargs):
+ return self.compress_prompt(*args, **kwargs)
+
+ def structured_compress_prompt(
+ self,
+ context: List[str],
+ instruction: str = "",
+ question: str = "",
+ rate: float = 0.5,
+ target_token: float = -1,
+ iterative_size: int = 200,
+ force_context_ids: List[int] = None,
+ force_context_number: int = None,
+ use_sentence_level_filter: bool = False,
+ use_context_level_filter: bool = True,
+ use_token_level_filter: bool = True,
+ keep_split: bool = False,
+ keep_first_sentence: int = 0,
+ keep_last_sentence: int = 0,
+ keep_sentence_number: int = 0,
+ high_priority_bonus: int = 100,
+ context_budget: str = "+100",
+ token_budget_ratio: float = 1.4,
+ condition_in_question: str = "none",
+ reorder_context: str = "original",
+ dynamic_context_compression_ratio: float = 0.0,
+ condition_compare: bool = False,
+ add_instruction: bool = False,
+ rank_method: str = "llmlingua",
+ concate_question: bool = True,
+ ):
+ """
+ Compresses the given prompt context based on a specified structure.
+
+        Each element of context should be segmented using one or more non-nested '<llmlingua></llmlingua>' tags.
+        Each '<llmlingua>' tag can include optional parameters 'rate' and 'compress' (e.g., '<llmlingua, rate=0.3, compress=True>'),
+ indicating the compression rate for that segment. Default values are 'rate=rate' and 'compress=True'.
+ When 'compress' is set to False, it overrides the 'rate' parameter, resulting in no compression for that segment.
+
+ Args:
+ context (List[str]): List of context strings divided by '' tags with optional compression settings.
+ instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string.
+ question (str, optional): A specific question that the prompt is addressing. Default is an empty string.
+ rate (float, optional): The compression rate is defined the same as in paper "Language Modeling Is Compression".
+ Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne, Elliot Catt, Tim Genewein, Christopher Mattern,
+ Jordi Grau-Moya et al. "Language modeling is compression." arXiv preprint arXiv:2309.10668 (2023):
+ .. math::\text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}}
+ Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be
+ fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal
+                to 1.0, representing the target compression rate. ``rate`` is applicable only within the context-level filter
+ and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global rate.
+ However, for segments where no specific rate is defined, the global rate serves as the default value. The final
+ compression rate of the entire text is a composite result of multiple compression rates applied across different sections.
+ target_token (float, optional): The global maximum number of tokens to be achieved. Default is -1, indicating no
+ specific target. The actual number of tokens after compression should generally be less than the specified target_token,
+ but there can be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
+                the sole criterion, overriding the ``rate``. ``target_token`` is applicable only within the context-level
+ filter and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global target token.
+ However, for segments where no specific rate is defined, the global rate calculated from global target token serves
+ as the default value. The final target token of the entire text is a composite result of multiple compression rates
+ applied across different sections.
+ iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200.
+ force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+ force_context_number (int, optional): The number of context sections to forcibly include. Default is None.
+ use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False.
+ use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True.
+ use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True.
+ keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False.
+ keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0.
+ keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0.
+ keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0.
+ high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100.
+ context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100".
+ token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4.
+ condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none".
+ reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original".
+ dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0.
+ condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False.
+ add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False.
+ rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
+ concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
+
+ Returns:
+ dict: A dictionary containing:
+ - "compressed_prompt" (str): The resulting compressed prompt.
+ - "origin_tokens" (int): The original number of tokens in the input.
+ - "compressed_tokens" (int): The number of tokens in the compressed output.
+ - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression.
+ - "rate" (str): The compression rate achieved, in a human-readable format.
+ - "saving" (str): Estimated savings in GPT-4 token usage.
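+
+        Example (an illustrative sketch; the context string, question, and settings below are placeholders):
+            >>> llm_lingua = PromptCompressor(use_llmlingua2=False)
+            >>> context = [
+            ...     "<llmlingua, compress=False>Keep this instruction verbatim.</llmlingua>"
+            ...     "<llmlingua, rate=0.3>This background passage may be compressed aggressively.</llmlingua>"
+            ... ]
+            >>> result = llm_lingua.structured_compress_prompt(context, question="What must be kept?", rate=0.5)
+            >>> print(result["compressed_prompt"])
+            # The segment marked compress=False is kept as-is; the rate=0.3 segment is compressed.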
+ """
+ if not context:
+ context = [" "]
+ if isinstance(context, str):
+ context = [context]
+ context = [
+ self.tokenizer.decode(self.tokenizer(c, add_special_tokens=False).input_ids)
+ for c in context
+ ]
+ context_tokens_length = [self.get_token_length(c) for c in context]
+ instruction_tokens_length, question_tokens_length = self.get_token_length(
+ instruction
+ ), self.get_token_length(question)
+ if target_token == -1:
+ target_token = (
+ (
+ instruction_tokens_length
+ + question_tokens_length
+ + sum(context_tokens_length)
+ )
+ * rate
+ - instruction_tokens_length
+ - (question_tokens_length if concate_question else 0)
+ )
+ else:
+ rate = target_token / sum(context_tokens_length)
+ (
+ context,
+ context_segs,
+ context_segs_rate,
+ context_segs_compress,
+ ) = self.segment_structured_context(context, rate)
+ return self.compress_prompt(
+ context,
+ instruction,
+ question,
+ rate,
+ target_token,
+ iterative_size,
+ force_context_ids,
+ force_context_number,
+ use_sentence_level_filter,
+ use_context_level_filter,
+ use_token_level_filter,
+ keep_split,
+ keep_first_sentence,
+ keep_last_sentence,
+ keep_sentence_number,
+ high_priority_bonus,
+ context_budget,
+ token_budget_ratio,
+ condition_in_question,
+ reorder_context,
+ dynamic_context_compression_ratio,
+ condition_compare,
+ add_instruction,
+ rank_method,
+ concate_question,
+ context_segs=context_segs,
+ context_segs_rate=context_segs_rate,
+ context_segs_compress=context_segs_compress,
+ )
+
+ def compress_prompt(
+ self,
+ context: List[str],
+ instruction: str = "",
+ question: str = "",
+ rate: float = 0.5,
+ target_token: float = -1,
+ iterative_size: int = 200,
+ force_context_ids: List[int] = None,
+ force_context_number: int = None,
+ use_sentence_level_filter: bool = False,
+ use_context_level_filter: bool = True,
+ use_token_level_filter: bool = True,
+ keep_split: bool = False,
+ keep_first_sentence: int = 0,
+ keep_last_sentence: int = 0,
+ keep_sentence_number: int = 0,
+ high_priority_bonus: int = 100,
+ context_budget: str = "+100",
+ token_budget_ratio: float = 1.4,
+ condition_in_question: str = "none",
+ reorder_context: str = "original",
+ dynamic_context_compression_ratio: float = 0.0,
+ condition_compare: bool = False,
+ add_instruction: bool = False,
+ rank_method: str = "llmlingua",
+ concate_question: bool = True,
+ context_segs: List[str] = None,
+ context_segs_rate: List[float] = None,
+ context_segs_compress: List[bool] = None,
+ target_context: int = -1,
+ context_level_rate: float = 1.0,
+ context_level_target_token: int = -1,
+ return_word_label: bool = False,
+ word_sep: str = "\t\t|\t\t",
+ label_sep: str = " ",
+ token_to_word: str = "mean",
+ force_tokens: List[str] = [],
+ force_reserve_digit: bool = False,
+ drop_consecutive: bool = False,
+ chunk_end_tokens: List[str] = [".", "\n"],
+ ):
+ """
+ Compresses the given context.
+
+ Args:
+ context (List[str]): List of context strings that form the basis of the prompt.
+ instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string.
+ question (str, optional): A specific question that the prompt is addressing. Default is an empty string.
+ rate (float, optional): The maximum compression rate target to be achieved. The compression rate is defined
+ the same as in paper "Language Modeling Is Compression". Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne,
+ Elliot Catt, Tim Genewein, Christopher Mattern, Jordi Grau-Moya et al. "Language modeling is compression."
+ arXiv preprint arXiv:2309.10668 (2023):
+ .. math::\text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}}
+ Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be
+ fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal
+ to 1.0, representing the target compression rate.
+ target_token (float, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target.
+ The actual number of tokens after compression should generally be less than the specified target_token, but there can
+ be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
+ the sole criterion, overriding the ``rate``.
+ iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200.
+ force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+ force_context_number (int, optional): The number of context sections to forcibly include. Default is None.
+ use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False.
+ use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True.
+ use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True.
+ keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False.
+ keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0.
+ keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0.
+ keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0.
+ high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100.
+ context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100".
+ token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4.
+ condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none".
+ reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original".
+ dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0.
+ condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False.
+ add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False.
+ rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
+ concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
+
+ target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
+        context_level_rate (float, optional): The minimum compression rate target to be achieved at the context level. Default is 1.0.
+        context_level_target_token (float, optional): The maximum number of tokens to be achieved in context-level compression.
+            Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario.
+ force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+        return_word_label (bool, optional): Whether to return the words along with their corresponding labels. Default is False.
+ word_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
+ label_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition word and label. Default is " ".
+ token_to_word (str, optional): How to convert token probability to word probability. Default is "mean".
+ force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
+        force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0-9). Default is False.
+        drop_consecutive (bool, optional): Whether to drop tokens that are in 'force_tokens' but appear consecutively in the compressed prompt.
+            Default is False.
+        chunk_end_tokens (List[str], optional): The early-stop tokens used to segment chunks. Default is [".", "\n"].
+ Returns:
+ dict: A dictionary containing:
+ - "compressed_prompt" (str): The resulting compressed prompt.
+            - "compressed_prompt_list" (List[str]): List of the resulting compressed prompts. Only used in llmlingua2.
+            - "fn_labeled_original_prompt" (str): The original words along with their labels
+                indicating whether to keep them in the compressed prompt, in the format (word label_sep label).
+                Only used in llmlingua2 when return_word_label = True.
+ - "origin_tokens" (int): The original number of tokens in the input.
+ - "compressed_tokens" (int): The number of tokens in the compressed output.
+ - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression.
+ - "rate" (str): The compression rate achieved, in a human-readable format.
+ - "saving" (str): Estimated savings in GPT-4 token usage.
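+
+        Example (an illustrative sketch; the model, contexts, and question below are placeholders):
+            >>> llm_lingua = PromptCompressor(use_llmlingua2=False)
+            >>> contexts = ["First retrieved passage ...", "Second retrieved passage ..."]
+            >>> result = llm_lingua.compress_prompt(
+            ...     contexts, instruction="Answer the question.", question="What is discussed?", rate=0.5
+            ... )
+            >>> print(result["origin_tokens"], result["compressed_tokens"], result["rate"])
+            # Prints the original/compressed token counts and the achieved compression rate.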
+ """
+ if self.use_llmlingua2:
+ return self.compress_prompt_llmlingua2(
+ context,
+ rate=rate,
+ target_token=target_token,
+ use_context_level_filter=use_context_level_filter,
+ use_token_level_filter=use_token_level_filter,
+ target_context=target_context,
+ context_level_rate=context_level_rate,
+ context_level_target_token=context_level_target_token,
+ force_context_ids=force_context_ids,
+ return_word_label=return_word_label,
+ word_sep=word_sep,
+ label_sep=label_sep,
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ chunk_end_tokens=chunk_end_tokens,
+ )
+ assert (
+ rate <= 1.0
+ ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]."
+
+ if not context:
+ context = [" "]
+ if isinstance(context, str):
+ context = [context]
+ assert not (
+ rank_method == "longllmlingua" and not question
+ ), "In the LongLLMLingua, it is necessary to set a question."
+ if condition_compare and "_condition" not in condition_in_question:
+ condition_in_question += "_condition"
+ if rank_method == "longllmlingua":
+ if condition_in_question == "none":
+ condition_in_question = "after"
+ elif rank_method == "llmlingua":
+ condition_in_question = (
+ "none"
+ if "_condition" not in condition_in_question
+ else "none_condition"
+ )
+ origin_tokens = len(
+ self.oai_tokenizer.encode(
+ "\n\n".join([instruction] + context + [question]).strip()
+ )
+ )
+ context_tokens_length = [self.get_token_length(c) for c in context]
+ instruction_tokens_length, question_tokens_length = self.get_token_length(
+ instruction
+ ), self.get_token_length(question)
+ if target_token == -1:
+ target_token = (
+ (
+ instruction_tokens_length
+ + question_tokens_length
+ + sum(context_tokens_length)
+ )
+ * rate
+ - instruction_tokens_length
+ - (question_tokens_length if concate_question else 0)
+ )
+ condition_flag = "_condition" in condition_in_question
+ condition_in_question = condition_in_question.replace("_condition", "")
+
+ if len(context) > 1 and use_context_level_filter:
+ context, dynamic_ratio, context_used = self.control_context_budget(
+ context,
+ context_tokens_length,
+ target_token,
+ force_context_ids,
+ force_context_number,
+ question,
+ condition_in_question,
+ reorder_context=reorder_context,
+ dynamic_context_compression_ratio=dynamic_context_compression_ratio,
+ rank_method=rank_method,
+ context_budget=context_budget,
+ context_segs=context_segs,
+ context_segs_rate=context_segs_rate,
+ context_segs_compress=context_segs_compress,
+ )
+ if context_segs is not None:
+ context_segs = [context_segs[idx] for idx in context_used]
+ context_segs_rate = [context_segs_rate[idx] for idx in context_used]
+ context_segs_compress = [
+ context_segs_compress[idx] for idx in context_used
+ ]
+ else:
+ dynamic_ratio = [0.0] * len(context)
+
+ segments_info = []
+ if use_sentence_level_filter:
+ context, segments_info = self.control_sentence_budget(
+ context,
+ target_token,
+ keep_first_sentence=keep_first_sentence,
+ keep_last_sentence=keep_last_sentence,
+ keep_sentence_number=keep_sentence_number,
+ high_priority_bonus=high_priority_bonus,
+ token_budget_ratio=token_budget_ratio,
+ question=question,
+ condition_in_question=condition_in_question,
+ rank_method=rank_method,
+ context_segs=context_segs,
+ context_segs_rate=context_segs_rate,
+ context_segs_compress=context_segs_compress,
+ )
+ elif context_segs is not None:
+ for context_idx in range(len(context)):
+ segments_info.append(
+ [
+ (len(seg_text), seg_rate, seg_compress)
+ for seg_text, seg_rate, seg_compress in zip(
+ context_segs[context_idx],
+ context_segs_rate[context_idx],
+ context_segs_compress[context_idx],
+ )
+ ]
+ )
+ segments_info = [
+ self.concate_segment_info(segment_info) for segment_info in segments_info
+ ]
+
+ if condition_flag:
+ prefix = question + "\n\n" + instruction if add_instruction else question
+ if (
+ self.get_token_length(prefix + "\n\n") + iterative_size * 2
+ > self.max_position_embeddings
+ ):
+ tokens = self.tokenizer(prefix, add_special_tokens=False).input_ids
+ prefix = self.tokenizer.decode(
+ tokens[: self.prefix_bos_num]
+ + tokens[
+ len(tokens)
+ - self.max_position_embeddings
+ + 2
+ + self.prefix_bos_num
+ + 2 * iterative_size :
+ ]
+ )
+ start = self.get_prefix_length(prefix + "\n\n", context[0])
+ context = [prefix] + context
+ else:
+ start = 0
+
+ if use_token_level_filter:
+ context = self.iterative_compress_prompt(
+ context,
+ target_token,
+ iterative_size=iterative_size,
+ keep_split=keep_split,
+ start=start,
+ dynamic_ratio=dynamic_ratio,
+ condition_compare=condition_compare,
+ segments_info=segments_info,
+ )
+ compressed_prompt = (
+ self.tokenizer.batch_decode(context[0])[0]
+                .replace("<s> ", "")
+                .replace("<s>", "")
+ )
+ else:
+ if condition_flag:
+ context = context[1:]
+ compressed_prompt = "\n\n".join(context)
+
+ res = []
+ if instruction:
+ res.append(instruction)
+ if compressed_prompt.strip():
+ res.append(compressed_prompt)
+ if question and concate_question:
+ res.append(question)
+
+ compressed_prompt = "\n\n".join(res)
+
+ compressed_tokens = len(self.oai_tokenizer.encode(compressed_prompt))
+ saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
+ ratio = 1 if compressed_tokens == 0 else origin_tokens / compressed_tokens
+ rate = 1 / ratio
+ return {
+ "compressed_prompt": compressed_prompt,
+ "origin_tokens": origin_tokens,
+ "compressed_tokens": compressed_tokens,
+ "ratio": f"{ratio:.1f}x",
+ "rate": f"{rate * 100:.1f}%",
+ "saving": f", Saving ${saving:.1f} in GPT-4.",
+ }
+
+ def compress_prompt_llmlingua2(
+ self,
+ context: List[str],
+ rate: float = 0.5,
+ target_token: int = -1,
+ use_context_level_filter: bool = False,
+ use_token_level_filter: bool = True,
+ target_context: int = -1,
+ context_level_rate: float = 1.0,
+ context_level_target_token: int = -1,
+ force_context_ids: List[int] = [],
+ return_word_label: bool = False,
+ word_sep: str = "\t\t|\t\t",
+ label_sep: str = " ",
+ token_to_word: str = "mean",
+ force_tokens: List[str] = [],
+ force_reserve_digit: bool = False,
+ drop_consecutive: bool = False,
+ chunk_end_tokens: List[str] = [".", "\n"],
+ ):
+ """
+        Compresses the given context.
+
+ Args:
+ context (List[str]): List of context strings that form the basis of the prompt.
+ rate (float, optional): The minimum compression rate target to be achieved. Default is 0.5. The actual compression rate
+ generally exceeds the specified target, but there can be fluctuations due to differences in tokenizers. If specified,
+                it should be a float less than or equal to 1.0, representing the target compression rate.
+ target_token (int, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target.
+ The actual number of tokens after compression should generally be less than the specified target_token, but there can
+ be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
+ the sole criterion, overriding the rate.
+ target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
+ Only used in the coarse-to-fine compression.
+            context_level_rate (float, optional): The minimum compression rate target to be achieved at the context level. Default is 1.0.
+                Only used in the coarse-to-fine compression.
+            context_level_target_token (float, optional): The maximum number of tokens to be achieved in context-level compression.
+                Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario.
+ force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+            return_word_label (bool, optional): Whether to return the words along with their corresponding labels. Default is False.
+ word_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
+ label_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition word and label. Default is " ".
+ token_to_word (str, optional): How to convert token probability to word probability. Default is "mean".
+ force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
+            force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0-9). Default is False.
+            drop_consecutive (bool, optional): Whether to drop tokens that are in 'force_tokens' but appear consecutively in the compressed prompt.
+                Default is False.
+            chunk_end_tokens (List[str], optional): The early-stop tokens used to segment chunks. Default is [".", "\n"].
+ Returns:
+ dict: A dictionary containing:
+ - "compressed_prompt" (str): The resulting compressed prompt.
+            - "compressed_prompt_list" (List[str]): List of the resulting compressed prompts.
+            - "fn_labeled_original_prompt" (str): The original words along with their labels
+                indicating whether to keep them in the compressed prompt, in the format (word label_sep label).
+ - "origin_tokens" (int): The original number of tokens in the input.
+ - "compressed_tokens" (int): The number of tokens in the compressed output.
+ - "ratio" (str): The compression ratio achieved, in a human-readable format.
+ - "rate" (str): The compression rate achieved, in a human-readable format.
+ - "saving" (str): Estimated savings in GPT-4 token usage.
+
+ """
+ assert len(force_tokens) <= self.max_force_token
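+        # Force tokens that the tokenizer splits into more than one piece are mapped to the
+        # reserved special tokens ([NEW0], [NEW1], ...) added in init_llmlingua2, so each one
+        # is handled as a single unit during compression and mapped back to its original text
+        # afterwards.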
+ token_map = {}
+ for i, t in enumerate(force_tokens):
+ if len(self.tokenizer.tokenize(t)) != 1:
+ token_map[t] = self.added_tokens[i]
+ chunk_end_tokens = copy.deepcopy(chunk_end_tokens)
+ for c in chunk_end_tokens:
+ if c in token_map:
+ chunk_end_tokens.append(token_map[c])
+ chunk_end_tokens = set(chunk_end_tokens)
+
+        if isinstance(context, str):
+ context = [context]
+ context = copy.deepcopy(context)
+
+ if len(context) == 1 and use_context_level_filter:
+ use_context_level_filter = False
+
+ n_original_token = 0
+ context_chunked = []
+ for i in range(len(context)):
+            n_original_token += self.get_token_length(
+                context[i], use_oai_tokenizer=True
+            )
+            for ori_token, new_token in token_map.items():
+                context[i] = context[i].replace(ori_token, new_token)
+            context_chunked.append(
+                self.__chunk_context(context[i], chunk_end_tokens=chunk_end_tokens)
+            )
+
+ if use_context_level_filter:
+            # use_context_level_filter was requested but no context-level parameters were given:
+            # derive context_level_rate = (rate + 1.0) / 2 when `rate` is specified, or
+            # context_level_target_token = target_token * 2 when `target_token` is specified.
+ if (
+ target_context <= 0
+ and context_level_rate >= 1.0
+ and context_level_target_token <= 0
+ ):
+ if target_token < 0 and rate < 1.0:
+ context_level_rate = (
+ (rate + 1.0) / 2 if use_token_level_filter else rate
+ )
+ print(
+ f"set context level compression rate to {context_level_rate}."
+ )
+ if target_token >= 0:
+ context_level_target_token = (
+ target_token * 2 if use_token_level_filter else target_token
+ )
+ print(
+ f"set context level target token to {context_level_target_token}."
+ )
+
+ if target_context >= 0:
+ context_level_rate = min(target_context / len(context), 1.0)
+ # print(f'override context level compression rate to {context_level_rate} because you specified target_context = {target_context}.')
+ if context_level_target_token >= 0:
+ context_level_rate = min(
+ context_level_target_token / n_original_token, 1.0
+ )
+ # print(f'override context level compression rate to {context_level_rate} because you specified context_level_target_token = {context_level_target_token}.')
+
+ context_probs, context_words = self.__get_context_prob(
+ context_chunked,
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ )
+
+ threshold = np.percentile(
+ context_probs, int(100 * (1 - context_level_rate))
+ )
+
+ reserved_context = []
+ context_label = [False] * len(context_probs)
+ for i, p in enumerate(context_probs):
+ if p >= threshold or (
+ force_context_ids is not None and i in force_context_ids
+ ):
+ reserved_context.append(context_chunked[i])
+ context_label[i] = True
+ n_reserved_token = 0
+ for chunks in reserved_context:
+ for c in chunks:
+ n_reserved_token += self.get_token_length(c, use_oai_tokenizer=True)
+ if target_token >= 0:
+ rate = min(target_token / n_reserved_token, 1.0)
+ print(
+ f"override compression rate to {rate} because you specified target_token = {target_token}."
+ )
+
+ if use_token_level_filter:
+ compressed_context, word_list, word_label_list = self.__compress(
+ reserved_context,
+ reduce_rate=max(0, 1 - rate),
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ )
+ else:
+ compressed_context, word_list, word_label_list = self.__compress(
+ reserved_context,
+ reduce_rate=0,
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ )
+ print(
+                "return the original text because you specified use_token_level_filter=False"
+ )
+
+ n_compressed_token = 0
+ for c in compressed_context:
+ n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
+ saving = (n_original_token - n_compressed_token) * 0.06 / 1000
+ ratio = (
+ 1 if n_compressed_token == 0 else n_original_token / n_compressed_token
+ )
+ res = {
+ "compressed_prompt": "\n\n".join(compressed_context),
+ "compressed_prompt_list": compressed_context,
+ "origin_tokens": n_original_token,
+ "compressed_tokens": n_compressed_token,
+ "ratio": f"{ratio:.1f}x",
+ "rate": f"{1 / ratio * 100:.1f}%",
+ "saving": f", Saving ${saving:.1f} in GPT-4.",
+ }
+ if return_word_label:
+ words = []
+ labels = []
+ j = 0
+ for i in range(len(context)):
+ if context_label[i]:
+ words.extend(word_list[j])
+ labels.extend(word_label_list[j])
+ j += 1
+ else:
+ words.extend(context_words[i])
+ labels.extend([0] * len(context_words[i]))
+ word_label_lines = word_sep.join(
+ [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
+ )
+ res["fn_labeled_original_prompt"] = word_label_lines
+ return res
+
+ if target_token > 0:
+ rate = min(target_token / n_original_token, 1.0)
+            print(
+                f"override compression rate to {rate} "
+                f"because you specified target_token = {target_token}."
+            )
+
+ if use_token_level_filter:
+ compressed_context, word_list, word_label_list = self.__compress(
+ context_chunked,
+ reduce_rate=max(0, 1 - rate),
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ )
+ else:
+ compressed_context, word_list, word_label_list = self.__compress(
+ context_chunked,
+ reduce_rate=0,
+ token_to_word=token_to_word,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ drop_consecutive=drop_consecutive,
+ )
+ print(
+                "return the original text because you specified use_token_level_filter=False"
+ )
+
+ n_compressed_token = 0
+ for c in compressed_context:
+ n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
+ saving = (n_original_token - n_compressed_token) * 0.06 / 1000
+ ratio = 1 if n_compressed_token == 0 else n_original_token / n_compressed_token
+ res = {
+ "compressed_prompt": "\n\n".join(compressed_context),
+ "compressed_prompt_list": compressed_context,
+ "origin_tokens": n_original_token,
+ "compressed_tokens": n_compressed_token,
+ "ratio": f"{ratio:.1f}x",
+ "rate": f"{1 / ratio * 100:.1f}%",
+ "saving": f", Saving ${saving:.1f} in GPT-4.",
+ }
+ if return_word_label:
+ words = []
+ labels = []
+ for w_list, l_list in zip(word_list, word_label_list):
+ words.extend(w_list)
+ labels.extend(l_list)
+
+ # new_words = []
+ # new_labels = []
+ # for i in range(len(words)):
+ # word, label = words[i], labels[i]
+ # if word in string.punctuation:
+ # if labels[i-1] == 1 and label == 1 and i > 0:
+ # new_words[-1] += word
+ # else:
+ # new_words.append(word)
+ # new_labels.append(label)
+ # word_label_lines = word_sep.join([f'{word}{label_sep}{label}' for word, label in zip(new_words, new_labels)])
+
+ word_label_lines = word_sep.join(
+ [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
+ )
+ res["fn_labeled_original_prompt"] = word_label_lines
+ return res
+
+    def get_token_length(
+        self,
+        text: str,
+        add_special_tokens: bool = True,
+        use_oai_tokenizer: bool = False,
+    ):
+ if use_oai_tokenizer:
+ return len(self.oai_tokenizer.encode(text))
+ else:
+ return len(
+ self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
+ )
+
+ def get_prefix_length(self, prefix: str, text: str):
+ possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1)
+ full_input_ids = self.tokenizer(
+ prefix + text[:100], add_special_tokens=False
+ ).input_ids
+ for i in range(possible_prefix_token, len(full_input_ids)):
+ cur_prefix = self.tokenizer.decode(full_input_ids[:i])
+ if cur_prefix == prefix:
+ break
+ assert self.tokenizer.decode(full_input_ids[i:]) == text[:100]
+ return i
+
+ def get_condition_ppl(
+ self,
+ text: str,
+ question: str,
+ condition_in_question: str = "none",
+ granularity: str = "sentence",
+ ):
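+        # "before": the question is prepended and perplexity is measured on the text tokens
+        #           given the question.
+        # "after":  the question is appended and perplexity is measured on the question tokens
+        #           given the text (the question-conditioned relevance used when
+        #           rank_method == "longllmlingua").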
+ if condition_in_question == "none":
+ return self.get_ppl(text, granularity=granularity)
+ elif condition_in_question == "before":
+ return self.get_ppl(
+ question + text,
+ granularity=granularity,
+ condition_mode="after",
+ condition_pos_id=self.get_token_length(question) - 1,
+ )
+ elif condition_in_question == "after":
+ return self.get_ppl(
+ text + question,
+ granularity=granularity,
+ condition_mode="after",
+ condition_pos_id=self.get_token_length(text) - 1,
+ )
+
+ def get_dynamic_compression_ratio(
+ self,
+ context: list,
+ target_token: float,
+ iterative_size: int,
+ dynamic_ratio: list,
+ start: int,
+ seg_info: List[List[tuple]] = None,
+ ):
+ def get_ratio(base: float, delta: float):
+ return max(min(1, base + delta), 0)
+
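+        # tau is the base keep rate: target_token divided by the total context length
+        # (each context counts its token length plus 2, presumably for the "\n\n" joiners).
+        # The contexts are then cut into windows of `iterative_size` tokens, and each window
+        # is described as a list of (token_count, keep_rate) pairs, where keep_rate is tau
+        # shifted by that context's dynamic_ratio delta and clipped to [0, 1].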
+ context_length = [self.get_token_length(ii, False) + 2 for ii in context]
+ if start:
+ context_length = context_length[1:]
+ tau = target_token / (sum(context_length) + 1)
+ res, idx, last, last_target = [], 0, 1, []
+ while idx < len(context_length):
+ if last + context_length[idx] >= iterative_size:
+ last_target.append(
+ (iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
+ )
+ res.append(last_target)
+ last = last + context_length[idx] - iterative_size
+ if last > iterative_size:
+ k = last // iterative_size
+ res.extend(
+ [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
+ )
+ last -= k * iterative_size
+
+ last_target = (
+ [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
+ )
+ else:
+ last += context_length[idx]
+ last_target.append(
+ (context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
+ )
+ idx += 1
+ if last_target:
+ res.append(last_target)
+ return res
+
+ def get_structured_dynamic_compression_ratio(
+ self,
+ context: list,
+ iterative_size: int,
+ dynamic_ratio: list,
+ start: int,
+ seg_info: List[List[tuple]] = None,
+ ):
+ if start:
+ pure_context = context[1:]
+ else:
+ pure_context = context
+ global_dynamic_rate, global_dynamic_compress, segments = [], [], []
+ for context_idx, text in enumerate(pure_context):
+ text_seen = 0
+ for seg_idx, (seg_len, seg_rate, seg_compress) in enumerate(
+ seg_info[context_idx]
+ ):
+ seg_text = text[text_seen : text_seen + seg_len]
+ if (
+ seg_idx == len(seg_info[context_idx]) - 1
+ and context_idx != len(pure_context) - 1
+ ):
+ seg_text += "\n\n"
+ segments.append(seg_text)
+ if seg_compress:
+ global_dynamic_rate.append(seg_rate)
+ else:
+ global_dynamic_rate.append(1.0)
+ global_dynamic_compress.append(seg_compress)
+ text_seen += seg_len
+ origin_text = "\n\n".join(pure_context)
+ assert len("".join(segments)) == len(origin_text)
+ assert len(segments) == len(global_dynamic_rate) == len(global_dynamic_compress)
+
+ text_input_ids = self.tokenizer(
+ "\n\n".join(context), add_special_tokens=False
+ ).input_ids[start:]
+ assert self.tokenizer.decode(text_input_ids) == origin_text
+ dynamic_compression_ratio = self.token_segment(
+ text_input_ids,
+ iterative_size,
+ segments,
+ global_dynamic_rate,
+ global_dynamic_compress,
+ )
+ return dynamic_compression_ratio
+
+ def token_segment(
+ self,
+ text_input_ids: List[int],
+ iterative_size: int,
+ segments: List[str],
+ global_dynamic_rate: List[float],
+ global_dynamic_compress: List[bool],
+ ):
+ decode_window = 3
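+        # Decode each token together with a small window of preceding tokens to recover its
+        # surface text, then assign it the keep rate of the character segment it falls in.
+        # A token spanning several segments takes the smallest of their rates, unless one of
+        # them is marked compress=False, in which case the rate is forced to 1.0 (keep).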
+ seg_idx, seg_seen, token_seen_num, last_rate = 0, 0, 0, -1
+ dynamic_compression_rate, local_compresssion_rate = [], []
+ for i in range(len(text_input_ids)):
+ if i < decode_window:
+ id_pre, id_cur = text_input_ids[:i], text_input_ids[: i + 1]
+ else:
+ id_pre, id_cur = (
+ text_input_ids[i - decode_window + 1 : i],
+ text_input_ids[i - decode_window + 1 : i + 1],
+ )
+ cur_word = self.tokenizer.decode(id_cur)[
+ len(self.tokenizer.decode(id_pre)) :
+ ]
+ cur_word_len = len(cur_word)
+ if cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen:
+ possible_rate, possible_compress = [], []
+ while (
+ cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen
+ ):
+ possible_rate.append(global_dynamic_rate[seg_idx])
+ possible_compress.append(global_dynamic_compress[seg_idx])
+ cur_word_len -= len(segments[seg_idx]) - seg_seen
+ seg_idx += 1
+ seg_seen = 0
+ if cur_word_len:
+ possible_rate.append(global_dynamic_rate[seg_idx])
+ possible_compress.append(global_dynamic_compress[seg_idx])
+ new_rate = 1.0 if False in possible_compress else min(possible_rate)
+ else:
+ new_rate = global_dynamic_rate[seg_idx]
+ if new_rate != last_rate and i - token_seen_num:
+ local_compresssion_rate.append((i - token_seen_num, last_rate))
+ token_seen_num = i
+ last_rate = new_rate
+ seg_seen += cur_word_len
+ if (i + 1) % iterative_size == 0:
+ if token_seen_num != i + 1:
+ local_compresssion_rate.append((i + 1 - token_seen_num, last_rate))
+ token_seen_num = i + 1
+ dynamic_compression_rate.append(local_compresssion_rate[:])
+ local_compresssion_rate = []
+ if token_seen_num != len(text_input_ids):
+ local_compresssion_rate.append(
+ (len(text_input_ids) - token_seen_num, last_rate)
+ )
+ if local_compresssion_rate != []:
+ dynamic_compression_rate.append(local_compresssion_rate[:])
+ return dynamic_compression_rate
+
+ def control_context_budget(
+ self,
+ context: List[str],
+ context_tokens_length: List[int],
+ target_token: float,
+ force_context_ids: List[int] = None,
+ force_context_number: int = None,
+ question: str = "",
+ condition_in_question: str = "none",
+ reorder_context: str = "original",
+ dynamic_context_compression_ratio: float = 0.0,
+ rank_method: str = "longllmlingua",
+ context_budget: str = "+100",
+ context_segs: List[List[str]] = None,
+ context_segs_rate: List[List[float]] = None,
+ context_segs_compress: List[List[bool]] = None,
+ ):
+ demostrations_sort = self.get_rank_results(
+ context,
+ question,
+ rank_method,
+ condition_in_question,
+ context_tokens_length,
+ )
+
+ if target_token < 0:
+ target_token = 100
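+        # `context_budget` is an expression suffix evaluated against target_token,
+        # e.g. the default "+100" yields target_token + 100.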
+ target_token = eval("target_token" + context_budget)
+ res = []
+ used = force_context_ids if force_context_ids is not None else []
+ if context_segs is not None:
+ for idx, _ in enumerate(context):
+ if False in context_segs_compress[idx]:
+ used.append(idx)
+
+ self.context_idxs.append([x for idx, (x, _) in enumerate(demostrations_sort)])
+ for idx, _ in demostrations_sort:
+ if idx >= len(context_tokens_length):
+ continue
+ target_token -= context_tokens_length[idx]
+ if idx not in used:
+ used.append(idx)
+ if target_token < 0 or (
+ force_context_number is not None and len(res) >= force_context_number
+ ):
+ break
+ original_used = used
+ if reorder_context == "original":
+ used = sorted(used)
+ elif reorder_context == "two_stage":
+ l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [
+ _ for idx, _ in enumerate(used) if idx % 2 == 1
+ ]
+ used = l + r[::-1]
+
+ if dynamic_context_compression_ratio > 0:
+ N = len(used)
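+            # Assign deltas linearly from +dynamic_context_compression_ratio down to its
+            # negative across the selected contexts (forced contexts first, then by rank),
+            # so earlier-selected contexts receive a higher keep rate when the per-chunk
+            # compression ratios are computed later.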
+ dynamic_ratio = [
+ i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
+ for i in range(-(N - 1), N, 2)
+ ][::-1]
+ dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
+ dynamic_ratio = [dynamic_ratio_map[i] for i in used]
+ else:
+ dynamic_ratio = [0.0] * len(used)
+
+ res = [context[idx] for idx in used if idx < len(context)]
+ return res, dynamic_ratio, used
+
+ def control_sentence_budget(
+ self,
+ context: List[str],
+ target_token: float,
+ keep_first_sentence: int = 0,
+ keep_last_sentence: int = 0,
+ keep_sentence_number: int = 0,
+ high_priority_bonus: int = 100,
+ token_budget_ratio: float = 1.4,
+ question: str = "",
+ condition_in_question: str = "none",
+ rank_method: str = "longllmlingua",
+ context_segs: List[List[str]] = None,
+ context_segs_rate: List[List[float]] = None,
+ context_segs_compress: List[List[bool]] = None,
+ ):
+ def keep_sentence(dem_idx: int, sent_keep: int):
+ idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
+ for idx in idxs:
+ sentence_ppl[idx] += high_priority_bonus
+
+ def sync_sentence(segments, text):
+ seg_num = len(segments)
+ new_segments = []
+ text_seen = 0
+ seg_idx, cur_seg_seen = 0, 0
+ for i, s in enumerate(text):
+ while seg_idx < seg_num and s != segments[seg_idx][cur_seg_seen]:
+ if cur_seg_seen < len(segments[seg_idx]) - 1:
+ cur_seg_seen += 1
+ continue
+ new_segments.append(text[text_seen:i])
+ text_seen = i
+ seg_idx += 1
+ cur_seg_seen = 0
+ cur_seg_seen += 1
+ if seg_idx == seg_num:
+ break
+ if cur_seg_seen == len(segments[seg_idx]):
+ new_segments.append(text[text_seen : i + 1])
+ text_seen = i + 1
+ seg_idx += 1
+ cur_seg_seen = 0
+ if text_seen < len(text):
+ new_segments.append(text[text_seen:])
+ assert len("".join(new_segments)) == len(text)
+ return new_segments
+
+ sentences = [nltk.sent_tokenize(c) for c in context]
+ dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0
+ for idx_d, s in enumerate(sentences):
+ for _ in s:
+ dem_g[idx_d].add(idx)
+ s2de[idx] = idx_d
+ idx += 1
+
+ if context_segs is not None:
+ context_segs = [
+ sync_sentence(s, "".join(c)) for s, c in zip(context_segs, sentences)
+ ]
+ sen2seg_ratio = {}
+ idx = 0
+ for idx_d, sentences_each_context in enumerate(sentences):
+ segments_length = [len(s) for s in context_segs[idx_d]]
+ seg_idx, cur_seg_seen = 0, 0
+ for sentence in sentences_each_context:
+ sentence_seg_ratio = []
+ remain = len(sentence)
+ while remain:
+ if segments_length[seg_idx] - cur_seg_seen <= remain:
+ new_seg_len = segments_length[seg_idx] - cur_seg_seen
+ sentence_seg_ratio.append(
+ (
+ new_seg_len,
+ context_segs_rate[idx_d][seg_idx],
+ context_segs_compress[idx_d][seg_idx],
+ )
+ )
+ seg_idx += 1
+ cur_seg_seen = 0
+ remain -= new_seg_len
+ else:
+ sentence_seg_ratio.append(
+ (
+ remain,
+ context_segs_rate[idx_d][seg_idx],
+ context_segs_compress[idx_d][seg_idx],
+ )
+ )
+ cur_seg_seen += remain
+ remain = 0
+ sen2seg_ratio[idx] = sentence_seg_ratio
+ idx += 1
+
+ context_sentences = [s for ii in sentences for s in ii]
+ sentence_tokens_length = [
+ self.get_token_length(sentence) for sentence in context_sentences
+ ]
+ N = len(context_sentences)
+ flags = list(range(len(context_sentences)))
+ if len(sentence_tokens_length) == 1:
+            # With a single sentence there is nothing to filter; still return a pair so the
+            # caller can unpack (context, segments_info).
+            return context, []
+ if rank_method == "longllmlingua":
+ sentence_ppl = [
+ self.get_condition_ppl(sentence, question, condition_in_question)
+ .cpu()
+ .numpy()
+ .item()
+ for sentence in context_sentences
+ ]
+ if keep_first_sentence:
+ sentence_ppl[:keep_first_sentence] = [
+ ii + high_priority_bonus
+ for ii in sentence_ppl[:keep_first_sentence]
+ ]
+ if keep_last_sentence:
+ sentence_ppl[-keep_last_sentence:] = [
+ ii + high_priority_bonus
+ for ii in sentence_ppl[-keep_last_sentence:]
+ ]
+ if keep_sentence_number:
+ for dem_idx in range(len(sentences)):
+ keep_sentence(dem_idx, keep_sentence_number)
+ sort_direct = -1 if condition_in_question == "none" else 1
+ sent_sort = sorted(
+ enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
+ )
+ else:
+ sent_sort = self.get_rank_results(
+ context_sentences,
+ question,
+ rank_method,
+ condition_in_question,
+ [0] * len(context_sentences),
+ )
+
+ sentence_flags = [False] * N
+ if target_token < 0:
+ target_token = 100
+ target_token *= token_budget_ratio
+ res = []
+ for idx, _ in sent_sort:
+ idx = flags[idx]
+ target_token -= sentence_tokens_length[idx]
+ sentence_flags[idx] = True
+ if target_token < 0:
+ break
+
+ if context_segs is not None:
+ for idx in range(N):
+ preserved = [sen_seg_info[2] for sen_seg_info in sen2seg_ratio[idx]]
+ if False in preserved:
+ sentence_flags[idx] = True
+
+ idx = 0
+ res = []
+ new_segments_info = []
+ for s in sentences:
+ tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]]
+ res.append("".join(tmp))
+ if context_segs is not None:
+ segment_ratio = []
+ for ii in range(len(s)):
+ if sentence_flags[idx + ii]:
+ segment_ratio.extend(sen2seg_ratio[idx + ii])
+ new_segments_info.append(segment_ratio)
+ idx += len(s)
+ if context_segs is not None:
+ new_segments_info = [
+ self.concate_segment_info(segment_info)
+ for segment_info in new_segments_info
+ ]
+ return res, new_segments_info
+
+ def get_compressed_input(
+ self,
+ loss,
+ input_ids,
+ attention_mask,
+ end=200,
+ iterative_size=200,
+ threshold=0.5,
+ keep_flag=None,
+ split_token_id: int = 13,
+ start: int = 0,
+ self_loss=None,
+ self_input_ids=None,
+ self_attention_mask=None,
+ ):
+ if self_loss is not None:
+ need_idx = torch.concat(
+ [
+ loss[:start] > 0,
+ self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
+ loss[:1] > 0,
+ ]
+ )
+ else:
+ need_idx = torch.concat([loss > threshold, loss[:1] > 0])
+ need_idx[end:] = 1
+ need_idx[: end - iterative_size] = 1
+ loss = loss[need_idx[:-1]]
+ if self_loss is not None:
+ if need_idx.shape[0] < self_loss.shape[0] + start + 1:
+ need_idx = torch.cat(
+ [
+ need_idx,
+ torch.ones(
+ self_loss.shape[0] - need_idx.shape[0] + start + 1,
+ dtype=torch.bool,
+ ).to(need_idx.device),
+ ]
+ )
+ self_loss = self_loss[need_idx[start:-1]]
+
+ if need_idx.shape[0] < input_ids.shape[1]:
+ need_idx = torch.cat(
+ [
+ need_idx,
+ torch.ones(
+ input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
+ ).to(need_idx.device),
+ ]
+ )
+ elif need_idx.shape[0] > input_ids.shape[1]:
+ need_idx = need_idx[: input_ids.shape[1]]
+
+ if keep_flag is not None:
+ need_idx[keep_flag == 1] = 1
+ last = -1
+ if keep_flag is not None:
+ for ii in range(max(0, end - iterative_size), end):
+ if need_idx[ii] != 1:
+ continue
+ now = input_ids[0][ii].detach().cpu().item()
+ if (
+ now == split_token_id
+ and last == split_token_id
+ and keep_flag[ii].detach().cpu().item() == 0
+ ):
+ need_idx[ii] = 0
+ else:
+ last = now
+ compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
+ compressed_attention_mask = attention_mask[attention_mask == 1][
+ need_idx
+ ].unsqueeze(0)
+
+ if self_loss is not None:
+ self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
+ need_idx[start:]
+ ].unsqueeze(0)
+ self_compressed_attention_mask = self_attention_mask[
+ self_attention_mask == 1
+ ][need_idx[start:]].unsqueeze(0)
+ else:
+ self_compressed_input_ids, self_compressed_attention_mask = None, None
+ if keep_flag is not None:
+ if len(keep_flag) > len(need_idx):
+ keep_flag = torch.cat(
+ [
+ keep_flag[:start],
+ keep_flag[start : len(need_idx) + start][need_idx],
+ keep_flag[start + len(need_idx) :],
+ ]
+ )
+ else:
+ keep_flag = keep_flag[need_idx]
+ end -= (need_idx[:end] == 0).sum()
+ return (
+ compressed_input_ids,
+ compressed_attention_mask,
+ keep_flag,
+ end,
+ loss,
+ self_loss,
+ self_compressed_input_ids,
+ self_compressed_attention_mask,
+ )
+
+ def get_estimate_threshold_base_distribution(
+ self, ppl, ratio: float, condition_flag: bool = False
+ ):
+ if ratio == 1.0:
+ return float("-inf")
+ ppl = ppl[ppl != 10000]
+ target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
+ return (
+ ppl.sort(descending=not condition_flag)
+ .values[target_token]
+ .detach()
+ .cpu()
+ .item()
+ )
+
+ def iterative_compress_prompt(
+ self,
+ context: List[str],
+ target_token: float,
+ iterative_size: int = 200,
+ keep_split: bool = False,
+ split_token_id: int = 13,
+ start: int = 0,
+ dynamic_ratio: list = None,
+ condition_compare: bool = False,
+ segments_info: List[List[tuple]] = None,
+ ):
+ if segments_info is None or segments_info == []:
+ iterative_ratios = self.get_dynamic_compression_ratio(
+ context, target_token, iterative_size, dynamic_ratio, start
+ )
+ else:
+ iterative_ratios = self.get_structured_dynamic_compression_ratio(
+ context, iterative_size, dynamic_ratio, start, segments_info
+ )
+ context = "\n\n".join(context)
+ tokenized_text = self.tokenizer(
+ context, return_tensors="pt", add_special_tokens=False
+ )
+ input_ids = tokenized_text["input_ids"].to(self.device)
+ attention_mask = tokenized_text["attention_mask"].to(self.device)
+
+ N = (attention_mask == 1).sum()
+ compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
+ if condition_compare:
+ self_input_ids, self_attention_mask = (
+ input_ids[:, start:],
+ attention_mask[:, start:],
+ )
+ self_compressed_input_ids, self_compressed_attention_mask = (
+ self_input_ids,
+ self_attention_mask,
+ )
+
+ end = min(iterative_size + start, compressed_input_ids.shape[1])
+ threshold, keep_flag = None, None
+ if keep_split:
+ input_ids_numpy = input_ids.cpu().detach().numpy()[0]
+ N = len(input_ids_numpy)
+ keep_flag = [
+ int(
+ (
+ ii > 0
+ and input_ids_numpy[ii] == split_token_id
+ and input_ids_numpy[ii - 1] == split_token_id
+ )
+ or (
+ ii < N - 1
+ and input_ids_numpy[ii] == split_token_id
+ and input_ids_numpy[ii + 1] == split_token_id
+ )
+ )
+ for ii in range(N)
+ ]
+ keep_flag = torch.tensor(keep_flag).to(self.device)
+ past_key_values, past_loss, ready_end = None, None, 0
+ self_past_key_values, self_past_loss, self_ready_end = None, None, 0
+ pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
+ idx = 0
+ while end <= compressed_input_ids.shape[1]:
+ if end > self.max_position_embeddings and past_key_values is not None:
+ # KV-Cache Compression
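+ # Once the processed length exceeds the model's position window, move the
+ # already-decided prefix tokens out of the working tensors and evict the
+ # corresponding `e` entries from the KV cache, keeping the first `s`
+ # (BOS / instruction) positions intact.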
+ e, s = end - self.max_position_embeddings, min(
+ self.cache_bos_num + start, self.max_position_embeddings
+ )
+ if pop_compressed_input_ids is None:
+ pop_compressed_input_ids = compressed_input_ids[:, :e]
+ else:
+ pop_compressed_input_ids = torch.cat(
+ [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
+ )
+ compressed_input_ids = compressed_input_ids[:, e:]
+ compressed_attention_mask = compressed_attention_mask[:, e:]
+ past_key_values = [
+ [
+ torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
+ torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
+ ]
+ for k, v in past_key_values
+ ]
+ if keep_flag is not None:
+ keep_flag = keep_flag[e:]
+ end, ready_end = end - e, ready_end - e
+ if condition_compare:
+ s = min(s, self_past_key_values[0][0].shape[2] - e)
+ self_ready_end -= e
+ if pop_self_compressed_input_ids is None:
+ pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
+ else:
+ pop_self_compressed_input_ids = torch.cat(
+ [
+ pop_self_compressed_input_ids,
+ self_compressed_input_ids[:, :e],
+ ],
+ dim=-1,
+ )
+ self_compressed_input_ids = self_compressed_input_ids[:, e:]
+ self_compressed_attention_mask = self_compressed_attention_mask[
+ :, e:
+ ]
+ self_past_key_values = [
+ [
+ torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
+ torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
+ ]
+ for k, v in self_past_key_values
+ ]
+
+ loss, past_key_values = self.get_ppl(
+ "",
+ "token",
+ compressed_input_ids,
+ compressed_attention_mask,
+ past_key_values=past_key_values,
+ return_kv=True,
+ end=end if idx else None,
+ )
+ if loss.shape[0] == 0:
+ break
+ if past_loss is not None:
+ if end - 1 > len(past_loss):
+ past_loss = torch.cat(
+ [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
+ )
+ past_loss[ready_end : end - 1] = loss
+ loss = past_loss
+ else:
+ past_loss = loss
+ if idx:
+ past_key_values = [
+ [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
+ for k, v in past_key_values
+ ]
+ else:
+ past_key_values = None
+
+ if condition_compare:
+ self_loss, self_past_key_values = self.get_ppl(
+ "",
+ "token",
+ self_compressed_input_ids,
+ self_compressed_attention_mask,
+ past_key_values=self_past_key_values,
+ return_kv=True,
+ end=end - start if idx else None,
+ )
+ if self_past_loss is not None:
+ if end - start - 1 > len(self_past_loss):
+ self_past_loss = torch.cat(
+ [
+ self_past_loss,
+ torch.zeros_like(self_loss)[
+ : end - 1 - start - len(self_past_loss)
+ ],
+ ]
+ )
+ self_past_loss[self_ready_end : end - start - 1] = self_loss
+ self_loss = self_past_loss
+ else:
+ self_past_loss = self_loss
+ if idx:
+ self_past_key_values = [
+ [
+ k[:, :, : end - iterative_size - start],
+ v[:, :, : end - iterative_size - start],
+ ]
+ for k, v in self_past_key_values
+ ]
+ else:
+ self_past_key_values = None
+
+ self_ready_end = (
+ end - start - iterative_size if not (start and idx == 0) else 0
+ )
+ ready_end = end - iterative_size if not (start and idx == 0) else 0
+
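+ # Prune the current window segment by segment: derive a threshold for each
+ # segment's ratio and drop the tokens whose (conditional) loss does not clear it.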
+ for delta_end, ratio in iterative_ratios[idx]:
+ loss = past_loss
+ if condition_compare:
+ self_loss = self_past_loss
+ threshold = self.get_estimate_threshold_base_distribution(
+ self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
+ )
+ else:
+ threshold = self.get_estimate_threshold_base_distribution(
+ loss, ratio, False
+ )
+
+ (
+ compressed_input_ids,
+ compressed_attention_mask,
+ keep_flag,
+ end,
+ past_loss,
+ self_past_loss,
+ self_compressed_input_ids,
+ self_compressed_attention_mask,
+ ) = self.get_compressed_input(
+ loss,
+ compressed_input_ids,
+ compressed_attention_mask,
+ end - iterative_size + delta_end,
+ iterative_size=delta_end,
+ threshold=threshold,
+ keep_flag=keep_flag,
+ split_token_id=split_token_id,
+ start=start,
+ self_loss=self_loss if condition_compare else None,
+ self_input_ids=(
+ self_compressed_input_ids if condition_compare else None
+ ),
+ self_attention_mask=(
+ self_compressed_attention_mask if condition_compare else None
+ ),
+ )
+ end += iterative_size
+ idx += 1
+ if pop_compressed_input_ids is not None:
+ compressed_input_ids = torch.cat(
+ [pop_compressed_input_ids, compressed_input_ids], dim=-1
+ )
+ return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
+
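+ # Map maximal spans of the LLM response that appear verbatim in the compressed
+ # prompt back to the corresponding text in the original prompt, so that
+ # fragments copied from the compressed context are restored to their
+ # uncompressed wording.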
+ def recover(
+ self,
+ original_prompt: str,
+ compressed_prompt: str,
+ response: str,
+ ):
+ def match_from_compressed(response_word):
+ response_input_ids = self.tokenizer(
+ response_word, add_special_tokens=False
+ )["input_ids"]
+ response_set, response_c = set(response_input_ids), defaultdict(list)
+ for idx in range(M):
+ if original_input_ids[idx] in response_set:
+ response_c[original_input_ids[idx]].append(idx)
+ res, res_min, res_c = None, float("inf"), 1
+ n = len(response_input_ids)
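+ # For every candidate position of the first response token in the original
+ # prompt, greedily chain occurrences of the following tokens (allowing gaps
+ # of at most 10 ids) and keep the span that covers the most tokens, breaking
+ # ties by the shortest span.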
+ for l in response_c[response_input_ids[0]]:
+ x, y, c = 0, l, 1
+ for x in range(1, n):
+ idx = bisect.bisect_right(response_c[response_input_ids[x]], y)
+ if (
+ idx >= len(response_c[response_input_ids[x]])
+ or response_c[response_input_ids[x]][idx] - y > 10
+ ):
+ continue
+ c += 1
+ y = response_c[response_input_ids[x]][idx]
+ if c > res_c:
+ res_c = c
+ res_min = y - l + 1
+ res = (l, y + 1)
+ elif c == res_c and y - l + 1 < res_min:
+ res_min = y - l + 1
+ res = (l, y + 1)
+
+ if res is None:
+ return response_word
+ # while l > 0 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
+ # l -= 1
+ # while r < M - 1 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
+ # l -= 1
+ return self.tokenizer.decode(original_input_ids[res[0] : res[1]])
+
+ response_words = response.split(" ")
+
+ original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
+ "input_ids"
+ ]
+ N, M = len(response_words), len(original_input_ids)
+ recovered_response_words = []
+ l = 0
+ while l < N:
+ if response_words[l] not in compressed_prompt:
+ recovered_response_words.append(response_words[l])
+ l += 1
+ continue
+ r = l
+ while (
+ r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
+ ):
+ r += 1
+
+ match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
+ recovered_response_words.append(match_words)
+ l = r + 1
+ return " ".join(recovered_response_words)
+
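+ # Coarse-grained ranking: dispatch to the selected retrieval / rerank method
+ # (BM25, gzip distance, dense embedders, rerankers, or the LongLLMLingua
+ # condition-PPL ranker) and return (index, score) pairs ordered from most to
+ # least relevant.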
+ def get_rank_results(
+ self,
+ context: list,
+ question: str,
+ rank_method: str,
+ condition_in_question: str,
+ context_tokens_length: list,
+ ):
+ def get_distance_bm25(corpus, query):
+ from rank_bm25 import BM25Okapi
+
+ tokenized_corpus = [doc.split(" ") for doc in corpus]
+ bm25 = BM25Okapi(tokenized_corpus)
+ tokenized_query = query.split(" ")
+ doc_scores = bm25.get_scores(tokenized_query)
+ idx = [(ii, 0) for ii in (-doc_scores).argsort()]
+ return idx
+
+ def get_distance_gzip(corpus, query):
+ def get_score(x, y):
+ cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
+ cxy = len(gzip.compress(f"{x} {y}".encode()))
+ return (cxy - min(cx, cy)) / max(cx, cy)
+
+ import gzip
+
+ doc_scores = [get_score(doc, query) for doc in corpus]
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+ return idx
+
+ def get_distance_sentbert(corpus, query):
+ from sentence_transformers import SentenceTransformer, util
+
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+ self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
+ self.retrieval_model_name = rank_method
+ doc_embeds = self.retrieval_model.encode(corpus)
+ query = self.retrieval_model.encode(query)
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+ return idx
+
+ def get_distance_openai(corpus, query):
+ import openai
+ from sentence_transformers import util
+
+ openai.api_key = self.open_api_config.get("api_key", "")
+ openai.api_base = self.open_api_config.get(
+ "api_base", "https://api.openai.com/v1"
+ )
+ openai.api_type = self.open_api_config.get("api_type", "open_ai")
+ openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
+ engine = self.open_api_config.get("engine", "text-embedding-ada-002")
+
+ def get_embed(text):
+ return openai.Embedding.create(
+ input=[text.replace("\n", " ")], engine=engine
+ )["data"][0]["embedding"]
+
+ doc_embeds = [get_embed(i) for i in corpus]
+ query = get_embed(query)
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+ return idx
+
+ def get_distance_sentbert_bge(corpus, query):
+ from sentence_transformers import SentenceTransformer, util
+
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+ self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
+ self.retrieval_model_name = rank_method
+ doc_embeds = self.retrieval_model.encode(
+ [i for i in corpus], normalize_embeddings=True
+ )
+ query = self.retrieval_model.encode(query, normalize_embeddings=True)
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+ return idx
+
+ def get_distance_bge_ranker(corpus, query):
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ pairs = [[i, query] for i in corpus]
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
+ model = (
+ AutoModelForSequenceClassification.from_pretrained(
+ "BAAI/bge-reranker-large"
+ )
+ .eval()
+ .to(self.device)
+ )
+ self.retrieval_model = [tokenizer, model]
+ self.retrieval_model_name = rank_method
+ with torch.no_grad():
+ inputs = self.retrieval_model[0](
+ pairs,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ max_length=512,
+ ).to(self.device)
+ scores = (
+ self.retrieval_model[1](**inputs, return_dict=True)
+ .logits.view(
+ -1,
+ )
+ .float()
+ )
+ idx = [(ii, 0) for ii in np.argsort(-scores.cpu())]
+ return idx
+
+ def get_distance_bge_llmembedder(corpus, query):
+ from transformers import AutoModel, AutoTokenizer
+
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
+ model = (
+ AutoModel.from_pretrained("BAAI/llm-embedder")
+ .eval()
+ .to(self.device)
+ )
+ self.retrieval_model = [tokenizer, model]
+ self.retrieval_model_name = rank_method
+
+ instruction_qa_query = (
+ "Represent this query for retrieving relevant documents: "
+ )
+ instruction_qa_key = "Represent this document for retrieval: "
+ queries = [instruction_qa_query + query for _ in corpus]
+ keys = [instruction_qa_key + key for key in corpus]
+ with torch.no_grad():
+ query_inputs = self.retrieval_model[0](
+ queries,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ max_length=512,
+ ).to(self.device)
+ key_inputs = self.retrieval_model[0](
+ keys,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ max_length=512,
+ ).to(self.device)
+ query_outputs = self.retrieval_model[1](**query_inputs)
+ key_outputs = self.retrieval_model[1](**key_inputs)
+ # CLS pooling
+ query_embeddings = query_outputs.last_hidden_state[:, 0]
+ key_embeddings = key_outputs.last_hidden_state[:, 0]
+ # Normalize
+ query_embeddings = torch.nn.functional.normalize(
+ query_embeddings, p=2, dim=1
+ )
+ key_embeddings = torch.nn.functional.normalize(
+ key_embeddings, p=2, dim=1
+ )
+ similarity = query_embeddings @ key_embeddings.T
+ idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())]
+ return idx
+
+ def get_distance_jinza(corpus, query):
+ from numpy.linalg import norm
+
+ from transformers import AutoModel
+
+ def cos_sim(a, b):
+ return (a @ b.T) / (norm(a) * norm(b))
+
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+ model = (
+ AutoModel.from_pretrained(
+ "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
+ )
+ .eval()
+ .to(self.device)
+ )
+ self.retrieval_model = model
+ self.retrieval_model_name = rank_method
+
+ doc_embeds = self.retrieval_model.encode(corpus)
+ query = self.retrieval_model.encode(query)
+ doc_scores = cos_sim(doc_embeds, query)
+ idx = [(ii, 0) for ii in np.argsort(-doc_scores)]
+ return idx
+
+ def get_distance_voyageai(corpus, query):
+ import voyageai
+ from sentence_transformers import util
+
+ voyageai.api_key = self.open_api_config.get("voyageai_api_key", "")
+
+ def get_embed(text):
+ return voyageai.get_embedding(text, model="voyage-01")
+
+ doc_embeds = [get_embed(i) for i in corpus]
+ query = get_embed(query)
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+ return idx
+
+ def get_distance_cohere(corpus, query):
+ import cohere
+
+ api_key = self.open_api_config.get("cohere_api_key", "")
+ co = cohere.Client(api_key)
+ results = co.rerank(
+ model="rerank-english-v2.0", query=query, documents=corpus, top_n=20
+ )
+ c_map = {jj: ii for ii, jj in enumerate(corpus)}
+ doc_rank = [c_map[ii.document["text"]] for ii in results]
+ idx = [(ii, 0) for ii in doc_rank]
+ return idx
+
+ def get_distance_longllmlingua(corpus, query):
+ context_ppl = [
+ self.get_condition_ppl(
+ d,
+ query
+ + " We can get the answer to this question in the given documents.",
+ condition_in_question,
+ )
+ - dl * 2 / 250 * 0
+ for d, dl in zip(corpus, context_tokens_length)
+ ]
+ sort_direct = -1 if condition_in_question == "none" else 1
+ ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])
+ return ys
+
+ method = None
+ if rank_method == "bm25":
+ method = get_distance_bm25
+ elif rank_method == "gzip":
+ method = get_distance_gzip
+ elif rank_method == "sentbert":
+ method = get_distance_sentbert
+ elif rank_method == "openai":
+ method = get_distance_openai
+ elif rank_method in ["longllmlingua", "llmlingua"]:
+ method = get_distance_longllmlingua
+ elif rank_method == "bge":
+ method = get_distance_sentbert_bge
+ elif rank_method == "bge_reranker":
+ method = get_distance_bge_ranker
+ elif rank_method == "bge_llmembedder":
+ method = get_distance_bge_llmembedder
+ elif rank_method == "jinza":
+ method = get_distance_jinza
+ elif rank_method == "voyageai":
+ method = get_distance_voyageai
+ elif rank_method == "cohere":
+ method = get_distance_cohere
+ return method(context, question)
+
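+ # Parse structured-prompt annotations of the form
+ # <llmlingua, rate=x, compress=y>text</llmlingua>: return the plain text plus
+ # per-segment rate / compress settings, falling back to `global_rate` for
+ # segments that are compressible but carry no explicit rate.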
+ def segment_structured_context(
+ self,
+ context: List[str],
+ global_rate: float,
+ ):
+ new_context, context_segs, context_segs_rate, context_segs_compress = (
+ [],
+ [],
+ [],
+ [],
+ )
+ for text in context:
+ if not text.startswith("<llmlingua"):
+ text = "<llmlingua>" + text
+ if not text.endswith("</llmlingua>"):
+ text = text + "</llmlingua>"
+
+ # Regular expression to match <llmlingua, rate=x, compress=y>content</llmlingua>, allowing rate and compress in any order
+ pattern = r"<llmlingua\s*,?\s*(?:rate\s*=\s*([\d\.]+)\s*,?\s*)?(?:compress\s*=\s*(True|False)\s*,?\s*)?(?:rate\s*=\s*([\d\.]+)\s*,?\s*)?(?:compress\s*=\s*(True|False)\s*,?\s*)?\s*>([^<]+)</llmlingua>"
+ matches = re.findall(pattern, text)
+
+ # Extracting segment contents
+ segments = [match[4] for match in matches]
+
+ # Extracting rate and compress, considering their possible positions
+ segs_rate = [
+ float(match[0]) if match[0] else (float(match[2]) if match[2] else None)
+ for match in matches
+ ]
+ segs_compress = [
+ (
+ match[1] == "True"
+ if match[1]
+ else (match[3] == "True" if match[3] else None)
+ )
+ for match in matches
+ ]
+
+ segs_compress = [
+ compress if compress is not None else True for compress in segs_compress
+ ]
+ segs_rate = [
+ rate if rate else (global_rate if compress else 1.0)
+ for rate, compress in zip(segs_rate, segs_compress)
+ ]
+ assert (
+ len(segments) == len(segs_rate) == len(segs_compress)
+ ), "The number of segments, rates, and compress flags should be the same."
+ assert all(
+ seg_rate <= 1.0 for seg_rate in segs_rate
+ ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]."
+
+ new_context.append("".join(segments))
+ context_segs.append(segments)
+ context_segs_rate.append(segs_rate)
+ context_segs_compress.append(segs_compress)
+
+ return new_context, context_segs, context_segs_rate, context_segs_compress
+
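+ # Merge adjacent segments that share the same rate and compress flag so that
+ # downstream compression sees fewer, longer segments.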
+ def concate_segment_info(
+ self,
+ segment_info: List[List[tuple]],
+ ):
+ new_segment_info = []
+ for i, (seg_len, seg_ratio, seg_compress) in enumerate(segment_info):
+ if (
+ new_segment_info
+ and new_segment_info[-1][1] == seg_ratio
+ and new_segment_info[-1][2] == seg_compress
+ ):
+ new_segment_info[-1] = (
+ new_segment_info[-1][0] + seg_len,
+ seg_ratio,
+ seg_compress,
+ )
+ else:
+ new_segment_info.append((seg_len, seg_ratio, seg_compress))
+ return new_segment_info
+
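+ # Score every chunk with the token-classification model, merge sub-word
+ # probabilities into word probabilities, and average them per context; this
+ # per-context mean serves as the relevance signal for the context-level filter.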
+ def __get_context_prob(
+ self,
+ context_list: list,
+ token_to_word="mean",
+ force_tokens: List[str] = [],
+ token_map: dict = {},
+ force_reserve_digit: bool = False,
+ ):
+ chunk_list = []
+ for chunks in context_list:
+ for c in chunks:
+ chunk_list.append(c)
+
+ dataset = TokenClfDataset(
+ chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
+ )
+ dataloader = DataLoader(
+ dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
+ )
+
+ chunk_probs = []
+ chunk_words = []
+ with torch.no_grad():
+ for batch in dataloader:
+ ids = batch["ids"].to(self.device, dtype=torch.long)
+ mask = batch["mask"].to(self.device, dtype=torch.long) == 1
+
+ outputs = self.model(input_ids=ids, attention_mask=mask)
+ loss, logits = outputs.loss, outputs.logits
+ probs = F.softmax(logits, dim=-1)
+
+ for j in range(ids.shape[0]):
+ _probs = probs[j, :, 1]
+ _ids = ids[j]
+ _mask = mask[j]
+
+ active_probs = torch.masked_select(_probs, _mask)
+ active_ids = torch.masked_select(_ids, _mask)
+
+ tokens = self.tokenizer.convert_ids_to_tokens(
+ active_ids.squeeze().tolist()
+ )
+ token_probs = [prob for prob in active_probs.cpu().numpy()]
+
+ (
+ words,
+ valid_token_probs,
+ valid_token_probs_no_force,
+ ) = self.__merge_token_to_word(
+ tokens,
+ token_probs,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ )
+ word_probs_no_force = self.__token_prob_to_word_prob(
+ valid_token_probs_no_force, convert_mode=token_to_word
+ )
+
+ if "xlm-roberta-large" in self.model_name:
+ for i in range(len(words)):
+ words[i] = words[i].lstrip("▁")
+ chunk_words.append(words)
+ chunk_probs.append(word_probs_no_force)
+
+ prev_idx = 0
+ context_probs = []
+ context_words = []
+ for chunk_list in context_list:
+ n_chunk = len(chunk_list)
+ context_probs.append([])
+ context_words.append([])
+ for i in range(n_chunk):
+ context_probs[-1].extend(chunk_probs[prev_idx + i])
+ context_words[-1].extend(chunk_words[prev_idx + i])
+ prev_idx = prev_idx + n_chunk
+ context_probs = [sum(probs) / len(probs) for probs in context_probs]
+ return context_probs, context_words
+
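+ # Split a long text into chunks of at most `max_seq_len` tokenizer tokens,
+ # preferring to end each chunk on one of `chunk_end_tokens` so that chunks
+ # break at natural boundaries.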
+ def __chunk_context(self, origin_text, chunk_end_tokens):
+ origin_list = []
+ origin_tokens = self.tokenizer.tokenize(origin_text)
+ n = len(origin_tokens)
+ st = 0
+ while st < n:
+ if st + self.max_seq_len > n - 1:
+ chunk = self.tokenizer.convert_tokens_to_string(origin_tokens[st:n])
+ origin_list.append(chunk)
+ break
+ else:
+ ed = st + self.max_seq_len
+ for j in range(0, ed - st):
+ if origin_tokens[ed - j] in chunk_end_tokens:
+ ed = ed - j
+ break
+ chunk = self.tokenizer.convert_tokens_to_string(
+ origin_tokens[st : ed + 1]
+ )
+ origin_list.append(chunk)
+ st = ed + 1
+ return origin_list
+
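+ # Merge sub-word tokens back into words, collecting each word's token
+ # probabilities. Tokens in `force_tokens` / `token_map` are forced to 1.0
+ # (the *_no_force lists keep the raw probability), and any token containing a
+ # digit is forced to 1.0 when `force_reserve_digit` is set.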
+ def __merge_token_to_word(
+ self, tokens, token_probs, force_tokens, token_map, force_reserve_digit
+ ):
+ words = []
+ word_probs = []
+ word_probs_no_force = []
+
+ for token, prob in zip(tokens, token_probs):
+ if token in self.special_tokens:
+ continue
+ # add a new word
+ elif is_begin_of_new_word(token, self.model_name, force_tokens, token_map):
+ pure_token = get_pure_token(token, self.model_name)
+ prob_no_force = prob
+ if pure_token in force_tokens or pure_token in set(token_map.values()):
+ prob = 1.0
+ token = replace_added_token(token, token_map)
+ words.append(token)
+ word_probs.append(
+ [
+ 1.0
+ if force_reserve_digit
+ and bool(re.search(r"\d", token))
+ else prob
+ ]
+ )
+ word_probs_no_force.append([prob_no_force])
+ # concatenate with previous token
+ else:
+ pure_token = get_pure_token(token, self.model_name)
+ words[-1] += pure_token
+ word_probs[-1].append(
+ 1.0
+ if force_reserve_digit
+ and bool(re.search(r"\d", token))
+ else prob
+ )
+ word_probs_no_force[-1].append(prob)
+
+ return words, word_probs, word_probs_no_force
+
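+ # Collapse each word's list of token probabilities to one score, either the
+ # mean over its tokens or the first token's probability.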
+ def __token_prob_to_word_prob(self, token_probs, convert_mode="mean"):
+ if convert_mode == "mean":
+ word_probs = [sum(p) / len(p) for p in token_probs]
+ elif convert_mode == "first":
+ word_probs = [p[0] for p in token_probs]
+ else:
+ raise NotImplementedError()
+
+ return word_probs
+
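+ # LLMLingua-2 token-classification compression: score every word with the
+ # classifier, turn `reduce_rate` into a percentile threshold measured in GPT-4
+ # tokenizer tokens, and keep only the words above it. A non-positive
+ # `reduce_rate` returns the contexts unchanged with every word labeled as kept.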
+ def __compress(
+ self,
+ context_list: list,
+ reduce_rate: float = 0.5,
+ token_to_word: str = "mean",
+ force_tokens: List[str] = [],
+ token_map: dict = {},
+ force_reserve_digit: bool = False,
+ drop_consecutive: bool = False,
+ ):
+ def split_string_to_words(input_string):
+ pattern = r'\b\w+\b|[<>=/!@#$%^&*()?":{}|\\`~;_+-]'
+ result = re.findall(pattern, input_string)
+ return result
+ # print(force_tokens, token_map, force_reserve_digit, drop_consecutive)
+ if reduce_rate <= 0:
+ words, word_labels = [], []
+ for i in range(len(context_list)):
+ chunk_list = context_list[i]
+ chunk_words = []
+ chunk_word_labels = []
+ for j in range(len(chunk_list)):
+ # replace to original token
+ for ori_token, new_token in token_map.items():
+ chunk_list[j] = chunk_list[j].replace(new_token, ori_token)
+ ws = split_string_to_words(chunk_list[j])
+ chunk_words.extend(ws)
+ chunk_word_labels.extend([1 for _ in range(len(ws))])
+ context_list[i] = "".join(chunk_list)
+ words.append(chunk_words)
+ word_labels.append(chunk_word_labels)
+ return context_list, words, word_labels
+
+ chunk_list = []
+ for chunks in context_list:
+ for c in chunks:
+ chunk_list.append(c)
+
+ dataset = TokenClfDataset(
+ chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
+ )
+ dataloader = DataLoader(
+ dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
+ )
+
+ compressed_chunk_list = []
+ word_list = []
+ word_label_list = []
+ with torch.no_grad():
+ for batch in dataloader:
+ ids = batch["ids"].to(self.device, dtype=torch.long)
+ mask = batch["mask"].to(self.device, dtype=torch.long) == 1
+
+ outputs = self.model(input_ids=ids, attention_mask=mask)
+ loss, logits = outputs.loss, outputs.logits
+ probs = F.softmax(logits, dim=-1)
+
+ for j in range(ids.shape[0]):
+ chunk_probs = probs[j, :, 1]
+ chunk_ids = ids[j]
+ chunk_mask = mask[j]
+
+ active_probs = torch.masked_select(chunk_probs, chunk_mask)
+ active_ids = torch.masked_select(chunk_ids, chunk_mask)
+
+ tokens = self.tokenizer.convert_ids_to_tokens(
+ active_ids.squeeze().tolist()
+ )
+ token_probs = [prob for prob in active_probs.cpu().numpy()]
+
+ words, valid_token_probs, _ = self.__merge_token_to_word(
+ tokens=tokens,
+ token_probs=token_probs,
+ force_tokens=force_tokens,
+ token_map=token_map,
+ force_reserve_digit=force_reserve_digit,
+ )
+ word_probs = self.__token_prob_to_word_prob(
+ valid_token_probs, convert_mode=token_to_word
+ )
+
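+ # drop_consecutive: if a force token reappears with no above-threshold word
+ # kept in between, zero its score so duplicated separators are not retained twice.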
+ if drop_consecutive:
+ threshold = np.percentile(word_probs, int(100 * reduce_rate))
+ is_token_between = False
+ prev = None
+ for i, (word, word_prob) in enumerate(zip(words, word_probs)):
+ if word in force_tokens:
+ if is_token_between:
+ is_token_between = False
+ elif not is_token_between and word == prev:
+ word_probs[i] = 0.0
+ prev = word
+ else:
+ is_token_between |= word_prob > threshold
+
+ # calculate compression ratio w.r.t. gpt-4 tokenizer
+ new_token_probs = []
+ for word, word_prob in zip(words, word_probs):
+ num_token = len(self.oai_tokenizer.encode(word))
+ new_token_probs.extend([word_prob for _ in range(num_token)])
+ threshold = np.percentile(
+ new_token_probs, int(100 * reduce_rate + 1)
+ )
+
+ keep_words = []
+ word_labels = []
+ assert len(words) == len(word_probs)
+ for word, word_prob in zip(words, word_probs):
+ if word_prob > threshold:
+ if (
+ drop_consecutive
+ and word in force_tokens
+ and len(keep_words) > 0
+ and keep_words[-1] == word
+ ):
+ word_labels.append(0)
+ else:
+ keep_words.append(word)
+ word_labels.append(1)
+ else:
+ word_labels.append(0)
+ keep_str = self.tokenizer.convert_tokens_to_string(keep_words)
+ if "xlm-roberta-large" in self.model_name:
+ for i in range(len(words)):
+ words[i] = words[i].lstrip("▁")
+
+ compressed_chunk_list.append(keep_str)
+ word_list.append(words[:])
+ word_label_list.append(word_labels[:])
+
+ compressed_context_list = []
+ original_word_list = []
+ original_word_label_list = []
+ prev_idx = 0
+ for chunk_list in context_list:
+ n_chunk = len(chunk_list)
+ compressed_context_list.append(
+ "".join(compressed_chunk_list[prev_idx : prev_idx + n_chunk])
+ )
+ original_word_list.append([])
+ original_word_label_list.append([])
+ for i in range(n_chunk):
+ original_word_list[-1].extend(word_list[prev_idx + i])
+ original_word_label_list[-1].extend(word_label_list[prev_idx + i])
+ prev_idx = prev_idx + n_chunk
+
+ return compressed_context_list, original_word_list, original_word_label_list