import copy
import os
from datetime import timedelta
import sys
from time import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
from accelerate import (
    Accelerator,
    DistributedType,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
from transformers import TextStreamer

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)
from lm_eval.models.huggingface import HFLM


class StopWatch(TextStreamer):
    """TextStreamer that only records timing (prefill latency, decoding time, and
    the number of decoding iterations) instead of printing streamed text."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_prefilling = None
        self.prefilling_time = None
        self.start_decoding = None
        self.decoding_time = None
        self.decoding_iterations = 0

    def put(self, value):
        # First call: `generate` streams the prompt tokens, so start the prefill timer.
        if self.start_prefilling is None:
            self.start_prefilling = time()
            return
        # Second call: the first generated token arrives, so prefill is over and
        # decoding begins.
        elif self.prefilling_time is None:
            self.prefilling_time = time() - self.start_prefilling
            self.start_decoding = time()
        self.decoding_iterations += 1
        return

    def end(self):
        # Called by `generate` when generation finishes: close the decoding timer.
        if self.decoding_time is None and self.start_decoding is not None:
            self.decoding_time = time() - self.start_decoding
        return
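

# Illustrative usage sketch (an assumption for this write-up, not part of the original
# module): StopWatch plugs into `generate` through the standard `streamer` argument,
# so no extra hooks are needed to collect timings. The "gpt2" checkpoint and the prompt
# below are placeholders.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     inputs = tok("Hello there", return_tensors="pt")
#     watch = StopWatch(tok)
#     model.generate(**inputs, max_new_tokens=16, streamer=watch)
#     print(watch.prefilling_time, watch.decoding_time, watch.decoding_iterations)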


class HFLMWithMeasurement(HFLM):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # temperature = 0.0 if not set
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )

        stop_watch = StopWatch(self.tokenizer)
        start = time()
        res = self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            streamer=stop_watch,
            **generation_kwargs,
        )
        end = time()

        batch_size = context.shape[0]
        output_length = stop_watch.decoding_iterations

        # Average the wall-clock measurements over the prompts in the batch.
        end_to_end_time = (end - start) / batch_size
        prefilling_time = stop_watch.prefilling_time / batch_size
        decoding_time = stop_watch.decoding_time / batch_size
        token_per_sec = output_length / decoding_time
        return res, end_to_end_time, prefilling_time, token_per_sec
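
    # Worked example with illustrative numbers (not from a real run): if a batch of
    # 4 prompts takes 8.0 s wall clock, end_to_end_time is reported as 2.0 s per
    # prompt; if the streamer counted 64 decoding iterations and the per-prompt
    # decoding time is 1.6 s, token_per_sec comes out to 64 / 1.6 = 40 tokens/s.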

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]
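
        # Illustrative example of the sort key (made-up token counts): contexts of
        # 30, 12 and 7 tokens produce keys (-30, ...), (-12, ...), (-7, ...), so the
        # longest request is scheduled first and any OOM at the largest padded length
        # shows up on the very first batch.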

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs.keys():
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # add EOS token to stop sequences
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if not until:
                until = [eos]
            else:
                until.append(eos)
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation
            cont, end_to_end_time, prefilling_time, token_per_sec = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]
                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append((s, end_to_end_time, prefilling_time, token_per_sec))
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)

        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        return res
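

# The block below is an illustrative sketch, not part of the original module: it shows
# one way to exercise HFLMWithMeasurement directly. The "gpt2" checkpoint, the prompt,
# and the generation kwargs are placeholder assumptions.
if __name__ == "__main__":
    lm = HFLMWithMeasurement(pretrained="gpt2", device="cpu", batch_size=1)
    request = Instance(
        request_type="generate_until",
        doc={},
        arguments=("The capital of France is", {"until": ["\n"], "max_gen_toks": 8}),
        idx=0,
    )
    for text, e2e_time, prefill_time, tok_per_sec in lm.generate_until([request]):
        print(
            f"output={text!r} end_to_end={e2e_time:.3f}s "
            f"prefill={prefill_time:.3f}s throughput={tok_per_sec:.1f} tok/s"
        )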