# --- Hugging Face file-viewer metadata (scrape residue, preserved as a comment) ---
# igitman's picture
# inference_endpoint (#2)
# b1a665d verified
import os
import logging
import traceback
from typing import Dict, List, Any
from nemo_skills.inference.server.code_execution_model import get_code_execution_model
from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.prompt.utils import get_prompt
# Configure logging
# basicConfig sets up the root logger at import time; INFO-level messages from
# this module (component init, request progress) go to the default stderr handler.
logging.basicConfig(level=logging.INFO)
# Module-level logger used throughout EndpointHandler.
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference.

    Lazily initializes a sandbox, a code-execution model client, and an
    optional prompt template on the first request, then serves generation
    requests through a HuggingFace-Inference-Endpoints-style ``__call__``.
    """

    def __init__(self):
        """Read configuration; heavy components are created lazily.

        Deferring model/sandbox construction to the first request lets the
        endpoint process start even if the model server is not up yet.
        """
        self.model = None
        self.prompt = None
        self.initialized = False
        # Prompt configuration (overridable via environment variables).
        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")
        # Model server location; defaults preserve the previous hard-coded values.
        self.server_host = os.getenv("MODEL_SERVER_HOST", "127.0.0.1")
        self.server_port = int(os.getenv("MODEL_SERVER_PORT", "5000"))

    def _initialize_components(self):
        """Initialize the model, sandbox, and prompt components lazily.

        Idempotent: returns immediately once initialization has succeeded.

        Raises:
            RuntimeError: if any component fails to initialize, chained to the
                original exception so the real cause reaches the caller.
        """
        if self.initialized:
            return
        try:
            logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")
            logger.info("Initializing code execution model...")
            self.model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host=self.server_host,
                port=self.server_port,
            )
            logger.info("Initializing prompt...")
            if self.prompt_config_path:
                self.prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path,
                )
            self.initialized = True
            logger.info("All components initialized successfully")
        except Exception as e:
            # BUG FIX: the original swallowed the failure with a bare warning,
            # leaving self.model = None and producing an opaque AttributeError
            # later in generate(). Log the full traceback and re-raise so
            # __call__ reports the actual initialization error.
            logger.exception("Failed to initialize components")
            raise RuntimeError("Component initialization failed") from e

    def _format_prompts(self, prompts):
        """Apply the configured prompt template to raw problem strings.

        Returns the prompts unchanged when no template is configured.
        """
        if self.prompt is None:
            return prompts
        # total_code_executions matches the code-execution budget used by the
        # generation defaults below; presumably tied to the template's
        # placeholder — TODO confirm against the prompt config.
        return [
            self.prompt.fill({"problem": text, "total_code_executions": 8})
            for text in prompts
        ]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.
        Args:
            data: Dictionary containing the request data
                Expected keys:
                - inputs: str or list of str - the input prompts/problems
                - parameters: dict (optional) - generation parameters
        Returns:
            List of dictionaries containing the generated responses; on any
            failure, a single-element list with "error" and an empty
            "generated_text" (the endpoint never raises to the caller).
        """
        try:
            # Initialize components if not already done (no-op after success).
            self._initialize_components()

            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Accept both a single string and a list of strings.
            if isinstance(inputs, str):
                prompts = [inputs]
            elif isinstance(inputs, list):
                prompts = inputs
            else:
                raise ValueError("inputs must be a string or list of strings")

            prompts = self._format_prompts(prompts)

            # Extra generation kwargs required by the code-execution prompt
            # (e.g. stop/start markers), when a prompt is configured.
            extra_generate_params = {}
            if self.prompt is not None:
                extra_generate_params = self.prompt.get_code_execution_args()

            # Defaults; caller-provided parameters override them, and
            # prompt-mandated code-execution args override everything.
            generation_params = {
                "tokens_to_generate": 12000,
                "temperature": 0.0,
                "top_p": 0.95,
                "top_k": 0,
                "repetition_penalty": 1.0,
                "random_seed": 0,
            }
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            logger.info("Processing %d prompt(s)", len(prompts))
            outputs = self.model.generate(
                prompts=prompts,
                **generation_params
            )

            results = [
                {
                    "generated_text": output.get("generation", ""),
                    "code_rounds_executed": output.get("code_rounds_executed", 0),
                }
                for output in outputs
            ]
            logger.info("Successfully processed %d request(s)", len(results))
            return results
        except Exception as e:
            # Top-level boundary: log with traceback, return a structured error
            # payload instead of raising (endpoint contract).
            logger.exception("Error processing request")
            return [{"error": str(e), "generated_text": ""}]