File size: 5,124 Bytes
b1a665d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import logging
import traceback
from typing import Dict, List, Any

# Project-local imports: model factory, execution sandbox, prompt templating.
from nemo_skills.inference.server.code_execution_model import get_code_execution_model
from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.prompt.utils import get_prompt

# Configure logging once at import time; module-level logger per convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference.

    Heavy components (sandbox, vLLM-backed code-execution model, prompt
    template) are constructed lazily on the first request so that merely
    instantiating the handler never fails at startup.
    """

    def __init__(self):
        """Read prompt configuration from the environment; defer heavy setup.

        Attributes are placeholders until ``_initialize_components`` runs.
        """
        self.model = None
        self.prompt = None
        self.initialized = False

        # Prompt configuration, overridable via environment variables.
        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")

    def _initialize_components(self):
        """Initialize the model, sandbox, and prompt components lazily.

        Idempotent: returns immediately once initialization has succeeded.

        Raises:
            Exception: re-raised from any failed initialization step so the
                caller surfaces the real cause instead of proceeding with a
                half-initialized handler (``self.model`` would be ``None``).
        """
        if self.initialized:
            return

        try:
            logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")

            logger.info("Initializing code execution model...")
            # NOTE(review): host/port assume a vLLM server co-located with the
            # endpoint container — confirm against the deployment setup.
            self.model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host="127.0.0.1",
                port=5000
            )

            logger.info("Initializing prompt...")
            if self.prompt_config_path:
                self.prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path
                )

            self.initialized = True
            logger.info("All components initialized successfully")

        except Exception:
            # BUG FIX: the original swallowed the failure with a placeholder
            # warning (an f-string with no placeholders, the caught exception
            # unused) and left self.model = None, which later surfaced as an
            # opaque AttributeError inside __call__. Log the full traceback
            # and re-raise; __call__'s handler converts it to an error dict.
            logger.exception("Failed to initialize components")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Args:
            data: Dictionary containing the request data
                Expected keys:
                - inputs: str or list of str - the input prompts/problems
                - parameters: dict (optional) - generation parameters

        Returns:
            List of dictionaries containing the generated responses; on any
            failure, a single-element list ``[{"error": ..., "generated_text": ""}]``.
        """
        try:
            # Initialize components if not already done (raises on failure).
            self._initialize_components()

            # Extract inputs and parameters.
            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Normalize to a list of prompt strings.
            if isinstance(inputs, str):
                prompts = [inputs]
            elif isinstance(inputs, list):
                prompts = inputs
            else:
                raise ValueError("inputs must be a string or list of strings")

            # If a prompt template is configured, format each input with it.
            # NOTE(review): "total_code_executions" of 8 mirrors the template's
            # expected budget field — confirm against the prompt config.
            if self.prompt is not None:
                prompts = [
                    self.prompt.fill({"problem": prompt_text, "total_code_executions": 8})
                    for prompt_text in prompts
                ]

            # Code-execution stop/start tokens etc. come from the prompt.
            extra_generate_params = {}
            if self.prompt is not None:
                extra_generate_params = self.prompt.get_code_execution_args()

            # Default generation parameters; caller-provided values override
            # them, and prompt-derived execution args override both.
            generation_params = {
                "tokens_to_generate": 12000,
                "temperature": 0.0,
                "top_p": 0.95,
                "top_k": 0,
                "repetition_penalty": 1.0,
                "random_seed": 0,
            }
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            # Lazy %-style args avoid formatting when the level is disabled.
            logger.info("Processing %d prompt(s)", len(prompts))

            # Generate responses.
            outputs = self.model.generate(
                prompts=prompts,
                **generation_params
            )

            # Shape outputs into the endpoint's response contract.
            results = [
                {
                    "generated_text": output.get("generation", ""),
                    "code_rounds_executed": output.get("code_rounds_executed", 0),
                }
                for output in outputs
            ]

            logger.info("Successfully processed %d request(s)", len(results))
            return results

        except Exception as e:
            # Top-level boundary: report the failure in-band rather than
            # crashing the serving process.
            logger.error("Error processing request: %s", e)
            logger.error(traceback.format_exc())
            return [{"error": str(e), "generated_text": ""}]