import os
import logging
import traceback
from typing import Dict, List, Any
from nemo_skills.inference.server.code_execution_model import get_code_execution_model
from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.prompt.utils import get_prompt
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference."""

    def __init__(self):
        """
        Initialize handler state; the model, sandbox, and prompt are created
        lazily on the first call.
        """
        self.model = None
        self.prompt = None
        self.initialized = False
        # Configuration, overridable through environment variables
        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")

    def _initialize_components(self):
        """Initialize the model, sandbox, and prompt components lazily."""
        if self.initialized:
            return
        try:
            logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")

            logger.info("Initializing code execution model...")
            self.model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host="127.0.0.1",
                port=5000,
            )

            logger.info("Initializing prompt...")
            if self.prompt_config_path:
                self.prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path,
                )

            self.initialized = True
            logger.info("All components initialized successfully")
        except Exception:
            logger.error("Failed to initialize components:\n%s", traceback.format_exc())
            # Re-raise so the failure surfaces in the request handler instead of
            # leaving a half-initialized handler behind.
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Args:
            data: Dictionary containing the request data.
                Expected keys:
                - inputs: str or list of str - the input prompts/problems
                - parameters: dict (optional) - generation parameters

        Returns:
            List of dictionaries containing the generated responses.
        """
        try:
            # Initialize components if not already done
            self._initialize_components()

            # Extract inputs and parameters
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Handle both a single string and a list of strings
            if isinstance(inputs, str):
                prompts = [inputs]
            elif isinstance(inputs, list):
                prompts = inputs
            else:
                raise ValueError("inputs must be a string or a list of strings")

            # If a prompt template is configured, format the inputs with it
            if self.prompt is not None:
                formatted_prompts = []
                for prompt_text in prompts:
                    formatted_prompt = self.prompt.fill(
                        {"problem": prompt_text, "total_code_executions": 8}
                    )
                    formatted_prompts.append(formatted_prompt)
                prompts = formatted_prompts

            # Get code execution arguments from the prompt if available
            extra_generate_params = {}
            if self.prompt is not None:
                extra_generate_params = self.prompt.get_code_execution_args()

            # Default generation parameters; request-supplied parameters and
            # prompt-derived code execution arguments override them
            generation_params = {
                "tokens_to_generate": 12000,
                "temperature": 0.0,
                "top_p": 0.95,
                "top_k": 0,
                "repetition_penalty": 1.0,
                "random_seed": 0,
            }
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            logger.info(f"Processing {len(prompts)} prompt(s)")

            # Generate responses
            outputs = self.model.generate(prompts=prompts, **generation_params)

            # Format outputs
            results = []
            for output in outputs:
                results.append(
                    {
                        "generated_text": output.get("generation", ""),
                        "code_rounds_executed": output.get("code_rounds_executed", 0),
                    }
                )

            logger.info(f"Successfully processed {len(results)} request(s)")
            return results
        except Exception as e:
            logger.error(f"Error processing request: {e}")
            logger.error(traceback.format_exc())
            return [{"error": str(e), "generated_text": ""}]
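

# A minimal local smoke test: a sketch, assuming a vLLM server is reachable at
# 127.0.0.1:5000 and a local sandbox can be started. The sample problem and
# parameter values below are illustrative, not part of the serving contract.
if __name__ == "__main__":
    handler = EndpointHandler()
    sample_request = {
        "inputs": "What is 15% of 240?",
        "parameters": {"temperature": 0.0, "tokens_to_generate": 512},
    }
    for response in handler(sample_request):
        print(response.get("generated_text") or response.get("error"))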