|
import os |
|
import logging |
|
import traceback |
|
from typing import Dict, List, Any |
|
|
|
from nemo_skills.inference.server.code_execution_model import get_code_execution_model |
|
from nemo_skills.code_execution.sandbox import get_sandbox |
|
from nemo_skills.prompt.utils import get_prompt |
|
|
|
|
|
# Configure root logging at import time so handler logs surface in the
# serving container's stdout/stderr.
logging.basicConfig(level=logging.INFO)

# Module-level logger, standard `__name__` convention (PEP 282 / logging docs).
logger = logging.getLogger(__name__)
|
|
|
|
|
class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference.

    Heavy components (sandbox, code-execution model, prompt) are initialized
    lazily on the first request so that importing this module stays cheap.

    Environment variables:
        PROMPT_CONFIG_PATH: prompt config name (default: ``"generic/math"``).
        PROMPT_TEMPLATE_PATH: prompt template name
            (default: ``"openmath-instruct"``).
        SERVER_HOST: vLLM server host (default: ``"127.0.0.1"``).
        SERVER_PORT: vLLM server port (default: ``5000``).
    """

    def __init__(self):
        """Read configuration from the environment; defer heavy initialization."""
        self.model = None
        self.prompt = None
        self.initialized = False

        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")
        # Server address is configurable but defaults to the previously
        # hard-coded values, so existing deployments are unaffected.
        self.server_host = os.getenv("SERVER_HOST", "127.0.0.1")
        self.server_port = int(os.getenv("SERVER_PORT", "5000"))

    def _initialize_components(self):
        """Initialize the model, sandbox, and prompt components lazily.

        Idempotent: returns immediately once initialization has succeeded.

        Raises:
            RuntimeError: if any component fails to initialize. The original
                code swallowed the exception here, leaving ``self.model`` as
                ``None`` and causing an opaque ``AttributeError`` later in
                ``__call__``; re-raising lets the caller report the real cause.
        """
        if self.initialized:
            return

        try:
            logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")

            logger.info("Initializing code execution model...")
            self.model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host=self.server_host,
                port=self.server_port,
            )

            logger.info("Initializing prompt...")
            if self.prompt_config_path:
                self.prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path,
                )

            self.initialized = True
            logger.info("All components initialized successfully")

        except Exception as e:
            # Log the full traceback and re-raise with context instead of
            # silently leaving the handler in a half-initialized state.
            logger.exception("Failed to initialize components")
            raise RuntimeError(f"Component initialization failed: {e}") from e

    @staticmethod
    def _normalize_inputs(inputs) -> List[str]:
        """Coerce the request ``inputs`` field into a list of prompts.

        Args:
            inputs: A single prompt string or a list of prompts.

        Returns:
            A list of prompts (a single string becomes a one-element list).

        Raises:
            ValueError: if ``inputs`` is neither a string nor a list.
        """
        if isinstance(inputs, str):
            return [inputs]
        if isinstance(inputs, list):
            return inputs
        raise ValueError("inputs must be a string or list of strings")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Args:
            data: Dictionary containing the request data
                Expected keys:
                - inputs: str or list of str - the input prompts/problems
                - parameters: dict (optional) - generation parameters

        Returns:
            List of dictionaries containing the generated responses; on
            failure, a single-element list with ``"error"`` and an empty
            ``"generated_text"``.
        """
        try:
            self._initialize_components()

            prompts = self._normalize_inputs(data.get("inputs", ""))
            parameters = data.get("parameters", {})

            # When a prompt template is configured, wrap each raw problem
            # statement and collect the code-execution arguments it mandates.
            if self.prompt is not None:
                prompts = [
                    self.prompt.fill({"problem": text, "total_code_executions": 8})
                    for text in prompts
                ]
                extra_generate_params = self.prompt.get_code_execution_args()
            else:
                extra_generate_params = {}

            # Precedence: defaults < user-supplied parameters < prompt-mandated
            # code-execution args (users must not override the latter).
            generation_params = {
                "tokens_to_generate": 12000,
                "temperature": 0.0,
                "top_p": 0.95,
                "top_k": 0,
                "repetition_penalty": 1.0,
                "random_seed": 0,
            }
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            logger.info("Processing %d prompt(s)", len(prompts))

            outputs = self.model.generate(prompts=prompts, **generation_params)

            results = [
                {
                    "generated_text": output.get("generation", ""),
                    "code_rounds_executed": output.get("code_rounds_executed", 0),
                }
                for output in outputs
            ]

            logger.info("Successfully processed %d request(s)", len(results))
            return results

        except Exception as e:
            logger.error("Error processing request: %s", e)
            logger.error(traceback.format_exc())
            return [{"error": str(e), "generated_text": ""}]