import math
import random

from distilabel.models import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
from distilabel.steps.tasks import TextGeneration

from synthetic_dataset_generator.constants import (
    API_KEYS,
    DEFAULT_BATCH_SIZE,
    HUGGINGFACE_BASE_URL,
    MODEL,
    OLLAMA_BASE_URL,
    OPENAI_BASE_URL,
    TOKENIZER_ID,
    VLLM_BASE_URL,
)

TOKEN_INDEX = 0


def _get_next_api_key():
    global TOKEN_INDEX
    api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
    TOKEN_INDEX += 1
    return api_key
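
# Illustrative note, not part of the original module: assuming API_KEYS is a
# non-empty list such as ["key_a", "key_b"], successive calls return
# "key_a", "key_b", "key_a", ... so requests are spread round-robin across the
# configured tokens.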


def _get_prompt_rewriter():
    generation_kwargs = {
        "temperature": 1,
    }
    system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
    # TextGeneration task whose backend LLM paraphrases the user's prompt
    prompt_rewriter = TextGeneration(
        llm=_get_llm(generation_kwargs=generation_kwargs),
        system_prompt=system_prompt,
        use_system_prompt=True,
    )
    prompt_rewriter.load()
    return prompt_rewriter


def get_rewritten_prompts(prompt: str, num_rows: int):
    prompt_rewriter = _get_prompt_rewriter()
    # create one prompt rewrite per 100 requested rows
    inputs = [
        {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
        for _ in range(math.floor(num_rows / 100))
    ]
    n_processed = 0
    prompt_rewrites = [prompt]
    while n_processed < len(inputs):
        batch = list(
            prompt_rewriter.process(
                inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
            )
        )
        prompt_rewrites += [entry["generation"] for entry in batch[0]]
        n_processed += DEFAULT_BATCH_SIZE
        # re-seed the global RNG with a fresh random value between batches
        random.seed(a=random.randint(0, 2**32 - 1))
    return prompt_rewrites
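

# Illustrative usage sketch, not part of the original module: the prompt text
# and row count below are invented for the example. One rewrite is requested
# per 100 rows, and the original prompt is always kept as the first entry.
def _example_prompt_rewrites():
    return get_rewritten_prompts(
        prompt="Generate short customer-support questions about billing.",
        num_rows=200,
    )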


def _get_llm_class() -> str:
    if OPENAI_BASE_URL:
        return "OpenAILLM"
    elif OLLAMA_BASE_URL:
        return "OllamaLLM"
    elif HUGGINGFACE_BASE_URL:
        return "InferenceEndpointsLLM"
    elif VLLM_BASE_URL:
        return "ClientvLLM"
    else:
        return "InferenceEndpointsLLM"


def _get_llm(use_magpie_template=False, **kwargs):
    if OPENAI_BASE_URL:
        # normalize generation kwargs to the names the OpenAI client expects
        # before instantiating the LLM
        if "generation_kwargs" in kwargs:
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = OpenAILLM(
            model=MODEL,
            base_url=OPENAI_BASE_URL,
            api_key=_get_next_api_key(),
            **kwargs,
        )
    elif OLLAMA_BASE_URL:
        if "generation_kwargs" in kwargs:
            # map generic kwargs to the names Ollama expects
            if "max_new_tokens" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["num_predict"] = kwargs[
                    "generation_kwargs"
                ]["max_new_tokens"]
                del kwargs["generation_kwargs"]["max_new_tokens"]
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
            # Ollama expects sampling parameters nested under "options"
            kwargs["generation_kwargs"] = {"options": kwargs["generation_kwargs"]}
        llm = OllamaLLM(
            model=MODEL,
            host=OLLAMA_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    elif HUGGINGFACE_BASE_URL:
        # Inference Endpoints do not sample unless do_sample is set explicitly
        kwargs.setdefault("generation_kwargs", {})["do_sample"] = True
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            base_url=HUGGINGFACE_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    elif VLLM_BASE_URL:
        if "generation_kwargs" in kwargs:
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = ClientvLLM(
            base_url=VLLM_BASE_URL,
            model=MODEL,
            tokenizer=TOKENIZER_ID or MODEL,
            api_key=_get_next_api_key(),
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    else:
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            tokenizer_id=TOKENIZER_ID or MODEL,
            model_id=MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    return llm
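

# Illustrative usage sketch with hypothetical generation kwargs: _get_llm picks
# the backend from the configured base URLs and renames parameters as needed
# (e.g. "stop_sequences" -> "stop" for OpenAI, "max_new_tokens" ->
# "num_predict" nested under "options" for Ollama).
def _example_build_llm():
    llm = _get_llm(
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 256,
            "stop_sequences": ["\nUser:"],
        }
    )
    llm.load()
    return llm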


# smoke-test the configured backend at import time so misconfiguration fails fast
try:
    llm = _get_llm()
    llm.load()
    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
except Exception as e:
    raise Exception(f"Error loading {_get_llm_class()}: {e}") from e