inference_endpoint (#2)
added endpoint scripts (1259b94c50e3fa4f31bce132de4f3ec7596aaa5b)
Co-authored-by: Ivan Moshkov <[email protected]>
- Dockerfile +26 -0
- entrypoint.sh +32 -0
- handler.py +140 -0
- server.py +77 -0
Dockerfile
ADDED
@@ -0,0 +1,26 @@
+FROM igitman/nemo-skills-vllm:0.6.0 AS base
+
+# Install NeMo-Skills and dependencies
+RUN git clone https://github.com/NVIDIA/NeMo-Skills \
+    && cd NeMo-Skills \
+    && pip install --ignore-installed blinker \
+    && pip install -e . \
+    && pip install -r requirements/code_execution.txt
+
+# Ensure python is available
+RUN ln -s /usr/bin/python3 /usr/bin/python
+
+# Copy our custom files
+COPY handler.py server.py /usr/local/endpoint/
+
+# Expose port 80
+EXPOSE 80
+
+# Copy and set up entrypoint script
+COPY entrypoint.sh /usr/local/endpoint/
+RUN chmod +x /usr/local/endpoint/entrypoint.sh
+
+# Set working directory
+WORKDIR /usr/local/endpoint
+
+ENTRYPOINT ["/usr/local/endpoint/entrypoint.sh"]
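For local testing, the image can be built and run along these lines (a sketch, not part of the PR; the tag name and host path are illustrative, and the entrypoint's --server_gpus=2 setting assumes at least two GPUs are visible to the container):

# Build the endpoint image from the directory containing these files
docker build -t nemo-skills-endpoint .

# Run it with GPUs and the model weights mounted at /repository,
# the default MODEL_PATH used by entrypoint.sh
docker run --gpus all -p 8080:80 \
    -v /path/to/model:/repository \
    nemo-skills-endpoint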
entrypoint.sh
ADDED
@@ -0,0 +1,32 @@
+#!/bin/bash
+set -e
+
+# Default environment variables
+export MODEL_PATH=${MODEL_PATH:-"/repository"}
+
+echo "Starting NeMo Skills inference endpoint..."
+echo "Model path: $MODEL_PATH"
+
+# Function to handle cleanup on exit
+cleanup() {
+    echo "Cleaning up processes..."
+    kill $(jobs -p) 2>/dev/null || true
+    wait
+}
+trap cleanup EXIT
+
+# Start the model server in the background
+echo "Starting model server..."
+ns start_server \
+    --model="$MODEL_PATH" \
+    --server_gpus=2 \
+    --server_type=vllm \
+    --with_sandbox &
+
+# Start the HTTP endpoint
+echo "Starting HTTP endpoint on port 80..."
+python /usr/local/endpoint/server.py &
+
+# Wait for both processes
+echo "Both servers started. Waiting..."
+wait
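Since ns start_server can take a while to load the model, readiness is best probed through the /health route that server.py exposes. A sketch, assuming the port mapping from the docker run example above:

# Poll until the handler reports ready (returns HTTP 503 until then)
curl -s http://localhost:8080/health
# expected once ready: {"status": "healthy", "model_ready": true}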
handler.py
ADDED
@@ -0,0 +1,140 @@
+import os
+import logging
+import traceback
+from typing import Dict, List, Any
+
+from nemo_skills.inference.server.code_execution_model import get_code_execution_model
+from nemo_skills.code_execution.sandbox import get_sandbox
+from nemo_skills.prompt.utils import get_prompt
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class EndpointHandler:
+    """Custom endpoint handler for NeMo Skills code execution inference."""
+
+    def __init__(self):
+        """
+        Initialize the handler with the model and prompt configurations.
+        """
+        self.model = None
+        self.prompt = None
+        self.initialized = False
+
+        # Configuration
+        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
+        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")
+
+    def _initialize_components(self):
+        """Initialize the model, sandbox, and prompt components lazily."""
+        if self.initialized:
+            return
+
+        try:
+            logger.info("Initializing sandbox...")
+            sandbox = get_sandbox(sandbox_type="local")
+
+            logger.info("Initializing code execution model...")
+            self.model = get_code_execution_model(
+                server_type="vllm",
+                sandbox=sandbox,
+                host="127.0.0.1",
+                port=5000
+            )
+
+            logger.info("Initializing prompt...")
+            if self.prompt_config_path:
+                self.prompt = get_prompt(
+                    prompt_config=self.prompt_config_path,
+                    prompt_template=self.prompt_template_path
+                )
+
+            self.initialized = True
+            logger.info("All components initialized successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize components: {e}")
+            raise
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        Process inference requests.
+
+        Args:
+            data: Dictionary containing the request data
+                Expected keys:
+                - inputs: str or list of str - the input prompts/problems
+                - parameters: dict (optional) - generation parameters
+
+        Returns:
+            List of dictionaries containing the generated responses
+        """
+        try:
+            # Initialize components if not already done
+            self._initialize_components()
+
+            # Extract inputs and parameters
+            inputs = data.get("inputs", "")
+            parameters = data.get("parameters", {})
+
+            # Handle both single string and list of strings
+            if isinstance(inputs, str):
+                prompts = [inputs]
+            elif isinstance(inputs, list):
+                prompts = inputs
+            else:
+                raise ValueError("inputs must be a string or list of strings")
+
+            # If we have a prompt template configured, format the inputs
+            if self.prompt is not None:
+                formatted_prompts = []
+                for prompt_text in prompts:
+                    formatted_prompt = self.prompt.fill({"problem": prompt_text, "total_code_executions": 8})
+                    formatted_prompts.append(formatted_prompt)
+                prompts = formatted_prompts
+
+            # Get code execution arguments from prompt if available
+            extra_generate_params = {}
+            if self.prompt is not None:
+                extra_generate_params = self.prompt.get_code_execution_args()
+
+            # Set default generation parameters
+            generation_params = {
+                "tokens_to_generate": 12000,
+                "temperature": 0.0,
+                "top_p": 0.95,
+                "top_k": 0,
+                "repetition_penalty": 1.0,
+                "random_seed": 0,
+            }
+
+            # Update with provided parameters
+            generation_params.update(parameters)
+            generation_params.update(extra_generate_params)
+
+            logger.info(f"Processing {len(prompts)} prompt(s)")
+
+            # Generate responses
+            outputs = self.model.generate(
+                prompts=prompts,
+                **generation_params
+            )
+
+            # Format outputs
+            results = []
+            for output in outputs:
+                result = {
+                    "generated_text": output.get("generation", ""),
+                    "code_rounds_executed": output.get("code_rounds_executed", 0),
+                }
+                results.append(result)
+
+            logger.info(f"Successfully processed {len(results)} request(s)")
+            return results
+
+        except Exception as e:
+            logger.error(f"Error processing request: {str(e)}")
+            logger.error(traceback.format_exc())
+            return [{"error": str(e), "generated_text": ""}]
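A request against this contract might look as follows (a sketch; the math problem is illustrative, and any keys under "parameters" are merged into the defaults shown above, e.g. temperature):

curl -s -X POST http://localhost:8080/ \
    -H 'Content-Type: application/json' \
    -d '{
          "inputs": "What is the sum of the first 100 positive integers?",
          "parameters": {"temperature": 0.2, "tokens_to_generate": 2048}
        }'
# expected shape: [{"generated_text": "...", "code_rounds_executed": 1}]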
server.py
ADDED
@@ -0,0 +1,77 @@
+import json
+import logging
+from http.server import HTTPServer, BaseHTTPRequestHandler
+
+from handler import EndpointHandler
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Initialize the handler
+handler = EndpointHandler()
+
+
+class RequestHandler(BaseHTTPRequestHandler):
+    def do_POST(self):
+        try:
+            content_length = int(self.headers.get('Content-Length', 0))
+            post_data = self.rfile.read(content_length)
+            data = json.loads(post_data.decode('utf-8'))
+
+            logger.info(f'Received request with {len(data.get("inputs", []))} inputs')
+
+            # Process the request
+            result = handler(data)
+
+            # Send response
+            self.send_response(200)
+            self.send_header('Content-Type', 'application/json')
+            self.end_headers()
+            self.wfile.write(json.dumps(result).encode('utf-8'))
+
+        except Exception as e:
+            logger.error(f'Error processing request: {str(e)}')
+            self.send_response(500)
+            self.send_header('Content-Type', 'application/json')
+            self.end_headers()
+            error_response = [{'error': str(e), 'generated_text': ''}]
+            self.wfile.write(json.dumps(error_response).encode('utf-8'))
+
+    def do_GET(self):
+        if self.path == '/health':
+            # Trigger initialization on the first health check (blocks until ready).
+            if not handler.initialized:
+                try:
+                    handler._initialize_components()
+                except Exception as e:
+                    logger.error(f'Initialization failed during health check: {str(e)}')
+
+            is_ready = handler.initialized
+            health_response = {
+                'status': 'healthy' if is_ready else 'unhealthy',
+                'model_ready': is_ready
+            }
+
+            try:
+                self.send_response(200 if is_ready else 503)
+                self.send_header('Content-Type', 'application/json')
+                self.end_headers()
+                self.wfile.write(json.dumps(health_response).encode('utf-8'))
+            except BrokenPipeError:
+                # Client disconnected before we replied; safe to ignore.
+                pass
+            return
+        else:
+            self.send_response(404)
+            self.end_headers()
+
+    def log_message(self, format, *args):
+        # Suppress default HTTP server logs to keep output clean
+        pass
+
+
+if __name__ == "__main__":
+    server = HTTPServer(('0.0.0.0', 80), RequestHandler)
+    logger.info('HTTP server started on port 80')
+    server.serve_forever()
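For debugging the HTTP layer outside the container, server.py can also be run directly; a sketch, assuming a vLLM server is already listening on 127.0.0.1:5000 as handler.py expects (binding port 80 usually requires elevated privileges):

# Handler initialization is lazy, so this starts even without the model server;
# the first POST or /health call will attempt to connect to 127.0.0.1:5000.
sudo python server.py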