igitman imoshkov committed on
Commit
b1a665d
·
verified ·
1 Parent(s): a1ddb28

inference_endpoint (#2)

Browse files

- added endpoint scripts (1259b94c50e3fa4f31bce132de4f3ec7596aaa5b)


Co-authored-by: Ivan Moshkov <[email protected]>

Files changed (4) hide show
  1. Dockerfile +26 -0
  2. entrypoint.sh +32 -0
  3. handler.py +139 -0
  4. server.py +77 -0
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM igitman/nemo-skills-vllm:0.6.0 as base

# Install NeMo-Skills and its code-execution dependencies from source.
RUN git clone https://github.com/NVIDIA/NeMo-Skills \
    && cd NeMo-Skills \
    && pip install --ignore-installed blinker \
    && pip install -e . \
    && pip install -r requirements/code_execution.txt

# Ensure `python` resolves to python3.
# -f: overwrite an existing link/file so the build does not fail when the
# base image already provides /usr/bin/python.
RUN ln -sf /usr/bin/python3 /usr/bin/python

# Copy the endpoint implementation.
COPY handler.py server.py /usr/local/endpoint/

# The HTTP endpoint (server.py) listens on port 80.
EXPOSE 80

# Copy and set up the entrypoint script.
COPY entrypoint.sh /usr/local/endpoint/
RUN chmod +x /usr/local/endpoint/entrypoint.sh

# Run from the endpoint directory.
WORKDIR /usr/local/endpoint

ENTRYPOINT ["/usr/local/endpoint/entrypoint.sh"]
entrypoint.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Entrypoint: starts the vLLM model server and the HTTP endpoint, and tears
# both down together when the container stops.
set -e

# Default environment variables (overridable by the deployment environment).
export MODEL_PATH=${MODEL_PATH:-"/repository"}
# Number of GPUs to serve the model on; defaults to the previous fixed value.
export SERVER_GPUS=${SERVER_GPUS:-2}

echo "Starting NeMo Skills inference endpoint..."
echo "Model path: $MODEL_PATH"

# Kill every background job when this script exits, for any reason, so a
# failure in one process does not leave the other orphaned.
cleanup() {
    echo "Cleaning up processes..."
    kill $(jobs -p) 2>/dev/null || true
    wait
}
trap cleanup EXIT

# Start the model server in the background.
echo "Starting model server..."
ns start_server \
    --model="$MODEL_PATH" \
    --server_gpus="$SERVER_GPUS" \
    --server_type=vllm \
    --with_sandbox &

# Start the HTTP endpoint in the background.
echo "Starting HTTP endpoint on port 80..."
python /usr/local/endpoint/server.py &

# Block until the background processes exit.
echo "Both servers started. Waiting..."
wait
handler.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import traceback
4
+ from typing import Dict, List, Any
5
+
6
+ from nemo_skills.inference.server.code_execution_model import get_code_execution_model
7
+ from nemo_skills.code_execution.sandbox import get_sandbox
8
+ from nemo_skills.prompt.utils import get_prompt
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference.

    Heavy components (sandbox, model client, prompt) are created lazily on the
    first request so that constructing the handler itself is cheap and safe.
    """

    def __init__(self):
        """Read configuration from the environment; defer component creation."""
        self.model = None
        self.prompt = None
        self.initialized = False

        # Prompt configuration, overridable via environment variables.
        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")

    def _initialize_components(self):
        """Initialize the model, sandbox, and prompt components lazily.

        Raises:
            RuntimeError: if any component fails to initialize. The original
                exception is chained so the root cause stays visible; callers
                (``__call__`` and the health check) catch and report it.
        """
        if self.initialized:
            return

        try:
            logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")

            logger.info("Initializing code execution model...")
            # The vLLM server is launched separately (see entrypoint.sh) and is
            # expected to listen on localhost:5000.
            self.model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host="127.0.0.1",
                port=5000,
            )

            logger.info("Initializing prompt...")
            if self.prompt_config_path:
                self.prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path,
                )

            self.initialized = True
            logger.info("All components initialized successfully")

        except Exception as e:
            # Log the root cause with a traceback and re-raise instead of
            # swallowing the error: otherwise self.model stays None and the
            # next request fails with a confusing AttributeError.
            logger.error("Failed to initialize components: %s", e)
            logger.error(traceback.format_exc())
            raise RuntimeError("Failed to initialize inference components") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Args:
            data: Dictionary containing the request data
                Expected keys:
                - inputs: str or list of str - the input prompts/problems
                - parameters: dict (optional) - generation parameters

        Returns:
            List of dictionaries containing the generated responses; on
            failure, a single-element list with an "error" key.
        """
        try:
            # Initialize components if not already done (raises on failure).
            self._initialize_components()

            # Extract inputs and parameters.
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Handle both a single string and a list of strings.
            if isinstance(inputs, str):
                prompts = [inputs]
            elif isinstance(inputs, list):
                prompts = inputs
            else:
                raise ValueError("inputs must be a string or list of strings")

            # If a prompt template is configured, format the raw inputs with it.
            if self.prompt is not None:
                formatted_prompts = []
                for prompt_text in prompts:
                    formatted_prompt = self.prompt.fill({"problem": prompt_text, "total_code_executions": 8})
                    formatted_prompts.append(formatted_prompt)
                prompts = formatted_prompts

            # Code-execution stop/continue arguments from the prompt, if any.
            extra_generate_params = {}
            if self.prompt is not None:
                extra_generate_params = self.prompt.get_code_execution_args()

            # Default generation parameters; greedy decoding by default.
            generation_params = {
                "tokens_to_generate": 12000,
                "temperature": 0.0,
                "top_p": 0.95,
                "top_k": 0,
                "repetition_penalty": 1.0,
                "random_seed": 0,
            }

            # Caller-provided parameters win over defaults; code-execution
            # arguments win over both (they are required for correct stopping).
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            logger.info(f"Processing {len(prompts)} prompt(s)")

            # Generate responses.
            outputs = self.model.generate(
                prompts=prompts,
                **generation_params
            )

            # Format outputs into a stable, JSON-serializable shape.
            results = []
            for output in outputs:
                result = {
                    "generated_text": output.get("generation", ""),
                    "code_rounds_executed": output.get("code_rounds_executed", 0),
                }
                results.append(result)

            logger.info(f"Successfully processed {len(results)} request(s)")
            return results

        except Exception as e:
            # Top-level boundary: report the error in-band so the HTTP layer
            # can return a structured JSON body.
            logger.error(f"Error processing request: {str(e)}")
            logger.error(traceback.format_exc())
            return [{"error": str(e), "generated_text": ""}]
server.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import logging
from http.server import HTTPServer, BaseHTTPRequestHandler

from handler import EndpointHandler

# Module-wide logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# One shared handler instance for the whole process; its heavy components are
# created lazily on first use.
handler = EndpointHandler()
13
+
14
+
15
+ class RequestHandler(BaseHTTPRequestHandler):
16
+ def do_POST(self):
17
+ try:
18
+ content_length = int(self.headers['Content-Length'])
19
+ post_data = self.rfile.read(content_length)
20
+ data = json.loads(post_data.decode('utf-8'))
21
+
22
+ logger.info(f'Received request with {len(data.get("inputs", []))} inputs')
23
+
24
+ # Process the request
25
+ result = handler(data)
26
+
27
+ # Send response
28
+ self.send_response(200)
29
+ self.send_header('Content-Type', 'application/json')
30
+ self.end_headers()
31
+ self.wfile.write(json.dumps(result).encode('utf-8'))
32
+
33
+ except Exception as e:
34
+ logger.error(f'Error processing request: {str(e)}')
35
+ self.send_response(500)
36
+ self.send_header('Content-Type', 'application/json')
37
+ self.end_headers()
38
+ error_response = [{'error': str(e), 'generated_text': ''}]
39
+ self.wfile.write(json.dumps(error_response).encode('utf-8'))
40
+
41
+ def do_GET(self):
42
+ if self.path == '/health':
43
+ # Trigger initialisation if needed but don't block.
44
+ if not handler.initialized:
45
+ try:
46
+ handler._initialize_components()
47
+ except Exception as e:
48
+ logger.error(f'Initialization failed during health check: {str(e)}')
49
+
50
+ is_ready = handler.initialized
51
+ health_response = {
52
+ 'status': 'healthy' if is_ready else 'unhealthy',
53
+ 'model_ready': is_ready
54
+ }
55
+
56
+ try:
57
+ self.send_response(200 if is_ready else 503)
58
+ self.send_header('Content-Type', 'application/json')
59
+ self.end_headers()
60
+ self.wfile.write(json.dumps(health_response).encode('utf-8'))
61
+ except BrokenPipeError:
62
+ # Client disconnected before we replied – safe to ignore.
63
+ pass
64
+ return
65
+ else:
66
+ self.send_response(404)
67
+ self.end_headers()
68
+
69
+ def log_message(self, format, *args):
70
+ # Suppress default HTTP server logs to keep output clean
71
+ pass
72
+
73
+
74
if __name__ == "__main__":
    # Bind on all interfaces; port 80 matches the container's EXPOSE.
    httpd = HTTPServer(('0.0.0.0', 80), RequestHandler)
    logger.info('HTTP server started on port 80')
    httpd.serve_forever()