|
import ast |
|
import json |
|
import re |
|
from collections.abc import Sequence |
|
from typing import Union |
|
|
|
import partial_json_parser |
|
from partial_json_parser.core.options import Allow |
|
|
|
from vllm.entrypoints.openai.protocol import ( |
|
ChatCompletionRequest, |
|
DeltaFunctionCall, DeltaMessage, |
|
DeltaToolCall, |
|
ExtractedToolCallInformation, |
|
FunctionCall, |
|
ToolCall, |
|
) |
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
|
ToolParser, |
|
ToolParserManager, |
|
) |
|
from vllm.logger import init_logger |
|
from vllm.transformers_utils.tokenizer import AnyTokenizer |
|
from vllm.utils import random_uuid |
|
|
|
# Module-level logger obtained through vLLM's logging helper.
logger = init_logger(__name__)
|
|
|
|
|
@ToolParserManager.register_module("llama_nemotron_xml")
class LlamaNemotronXMLToolParser(ToolParser):
    """Parses tool calls emitted as XML blocks of the form

        <tool_call><tool>name</tool><param>value</param>...</tool_call>

    Parameter values are coerced to the JSON-schema types declared in
    ``request.tools`` when available, otherwise a best-effort literal
    detection is applied. Streaming extraction is not supported.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming bookkeeping required by the ToolParser interface
        # (unused here since streaming extraction is unsupported).
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        # Inner content of each <tool_call>...</tool_call> block.
        self.tool_call_block_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        # The function name inside <tool>...</tool>.
        self.name_regex = re.compile(r"<tool>(.*?)</tool>", re.DOTALL)
        # <name>value</name> parameter pairs; the \1 backreference forces
        # the closing tag to match the opening one.
        self.param_regex = re.compile(r"<([^/>\s]+)>(.*?)</\1>", re.DOTALL)

    def _schema_param_type(
        self,
        tool_name: str,
        param_name: str,
        request: ChatCompletionRequest,
    ) -> Union[str, None]:
        """Return the JSON-schema ``type`` declared for *param_name* of tool
        *tool_name* in ``request.tools``, or ``None`` if not declared.

        Only the first tool definition whose name matches is consulted.
        """
        if not request.tools:
            return None
        for tool_def in request.tools:
            if tool_def.function.name != tool_name:
                continue
            params = tool_def.function.parameters
            if params and isinstance(params, dict):
                props = params.get("properties")
                if isinstance(props, dict):
                    spec = props.get(param_name)
                    if isinstance(spec, dict):
                        return spec.get("type")
            # First name match wins, whether or not a type was found.
            return None
        return None

    def _coerce_param_value(
        self,
        param_name: str,
        param_value_str: str,
        target_type: Union[str, None],
    ):
        """Convert the raw XML text *param_value_str* to a Python value.

        With a declared *target_type* the conversion follows the schema;
        on failure the raw string is kept (with a warning). Without a
        declared type, a conservative literal heuristic decides whether
        to apply ``ast.literal_eval``.
        """
        if target_type:
            try:
                if target_type == "integer":
                    return int(param_value_str)
                if target_type == "number":
                    return float(param_value_str)
                if target_type == "boolean":
                    return param_value_str.lower() == 'true'
                if target_type in ("object", "array"):
                    try:
                        return json.loads(param_value_str)
                    except json.JSONDecodeError:
                        # Models may emit Python-repr containers
                        # (single quotes), which literal_eval accepts.
                        return ast.literal_eval(param_value_str)
                # "string" and any unrecognized type: keep the raw text.
                return param_value_str
            except (ValueError, SyntaxError, json.JSONDecodeError) as e:
                logger.warning(
                    "Could not convert param '%s' with value '%s' "
                    "to type '%s'. Error: %s. Using string value.",
                    param_name, param_value_str, target_type, e)
                return param_value_str

        # No schema information: only literal_eval values that plainly
        # look like Python literals; otherwise keep the raw string.
        try:
            looks_like_literal = (
                (param_value_str.startswith("'")
                 and param_value_str.endswith("'"))
                or (param_value_str.startswith('"')
                    and param_value_str.endswith('"'))
                or (param_value_str.startswith('[')
                    and param_value_str.endswith(']'))
                or (param_value_str.startswith('{')
                    and param_value_str.endswith('}'))
                or param_value_str.lower() in ('true', 'false', 'none')
                or param_value_str.replace('.', '', 1).isdigit()
                or (param_value_str.startswith('-')
                    and param_value_str[1:].replace('.', '', 1).isdigit()))
            if looks_like_literal:
                return ast.literal_eval(param_value_str)
            return param_value_str
        except (ValueError, SyntaxError):
            return param_value_str

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all XML tool calls from a complete model response.

        Returns the text before the first ``<tool_call>`` as ``content``;
        on any unexpected error the whole output is returned as content
        with no tool calls.
        """
        tool_call_start_index = model_output.find(self.tool_call_start_token)
        if tool_call_start_index == -1:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        content = model_output[:tool_call_start_index].strip()
        tool_calls_str_content = model_output[tool_call_start_index:]

        parsed_tool_calls = []
        try:
            for tool_content_str in self.tool_call_block_regex.findall(
                    tool_calls_str_content):
                name_match = self.name_regex.search(tool_content_str)
                if not name_match:
                    logger.warning(
                        "Could not find tool name in XML block: %s",
                        tool_content_str)
                    continue
                tool_name = name_match.group(1).strip()

                parsed_arguments = {}
                for match in self.param_regex.finditer(tool_content_str):
                    param_name = match.group(1).strip()
                    param_value_str = match.group(2).strip()

                    # <tool> carries the function name, not an argument.
                    if param_name == "tool":
                        continue

                    target_type = self._schema_param_type(
                        tool_name, param_name, request)
                    parsed_arguments[param_name] = self._coerce_param_value(
                        param_name, param_value_str, target_type)

                parsed_tool_calls.append(ToolCall(
                    id=f"call_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=tool_name,
                        arguments=json.dumps(parsed_arguments,
                                             ensure_ascii=False),
                    ),
                ))

            return ExtractedToolCallInformation(
                tools_called=len(parsed_tool_calls) > 0,
                tool_calls=parsed_tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting XML tool call from response. "
                "Response: %s", model_output)
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is intentionally unsupported."""
        raise NotImplementedError("Tool calling is not supported in streaming mode!")
|
|
|
|
|
@ToolParserManager.register_module("llama_nemotron_json")
class LlamaNemotronJSONToolParser(ToolParser):
    """Parses tool calls emitted as a JSON array wrapped in
    ``<TOOLCALL>...</TOOLCALL>`` tags, e.g.

        <TOOLCALL>[{"name": "f", "arguments": {"x": 1}}]</TOOLCALL>

    Streaming extraction is not supported.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming bookkeeping required by the ToolParser interface
        # (unused here since streaming extraction is unsupported).
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        # Inner content of the <TOOLCALL>...</TOOLCALL> block.
        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>",
                                          re.DOTALL)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete model response.

        Malformed individual entries are skipped; on any unexpected error
        the whole output is returned as content with no tool calls.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        try:
            str_tool_calls = self.tool_call_regex.findall(
                model_output)[0].strip()

            # Normalize to a JSON array: the model sometimes omits the
            # surrounding brackets.
            if not str_tool_calls.startswith("["):
                str_tool_calls = "[" + str_tool_calls
            if not str_tool_calls.endswith("]"):
                # BUGFIX: the closing bracket must be appended; the
                # original prepended it, making json.loads always fail
                # for this case.
                str_tool_calls = str_tool_calls + "]"

            json_tool_calls = json.loads(str_tool_calls)

            tool_calls = []
            for tool_call in json_tool_calls:
                try:
                    tool_calls.append(ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tool_call["name"],
                            arguments=json.dumps(tool_call["arguments"],
                                                 ensure_ascii=False)
                            if isinstance(tool_call["arguments"], dict)
                            else tool_call["arguments"],
                        ),
                    ))
                except Exception:
                    # Skip malformed entries instead of failing the
                    # whole response.
                    continue

            content = model_output[:model_output.rfind(
                self.tool_call_start_token)]

            return ExtractedToolCallInformation(
                # Consistent with the other parsers: only report tools
                # as called when at least one was successfully parsed.
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting tool call from response. Response: %s",
                model_output)
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is intentionally unsupported."""
        raise NotImplementedError("Tool calling is not supported in streaming mode!")
|
|
|
|
|
@ToolParserManager.register_module("llama_nemotron_pythonic")
class LlamaNemotronPythonicToolParser(ToolParser):
    """Parses tool calls written as pythonic function calls, one per line,
    inside ``<TOOLCALL>...</TOOLCALL>``, e.g.

        <TOOLCALL>get_weather(city="Paris", days=3)</TOOLCALL>

    Argument values are parsed via the ``ast`` module and coerced to the
    JSON-schema types declared in ``request.tools`` when available.
    Streaming extraction is not supported.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Streaming bookkeeping required by the ToolParser interface
        # (unused here since streaming extraction is unsupported).
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<TOOLCALL>"
        self.tool_call_end_token: str = "</TOOLCALL>"

        # Inner content of the <TOOLCALL>...</TOOLCALL> block.
        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>",
                                          re.DOTALL)

        # Matches `name(args)` where the argument list runs to the end
        # of the line.
        self.function_call_regex = re.compile(r"(\w+)\((.*?)\)$", re.DOTALL)

    @staticmethod
    def _eval_arg_node(node):
        """Best-effort evaluation of one argument AST node to a value.

        Prefers ``ast.literal_eval``; falls back to a bare identifier's
        name, a constant's value, or the unparsed source text.
        """
        try:
            return ast.literal_eval(node)
        except (ValueError, TypeError):
            if isinstance(node, ast.Name):
                # Bare identifier (e.g. an unquoted string): use its name.
                return node.id
            if isinstance(node, ast.Constant):
                return node.value
            return ast.unparse(node)

    def parse_function_arguments(self, args_str: str) -> dict:
        """Parse pythonic function arguments string into a dictionary.

        Keyword arguments keep their names; positional arguments get
        synthetic ``arg_<i>`` names since the schema parameter names are
        unknown at this point. Returns ``{}`` on any parse failure.
        """
        if not args_str.strip():
            return {}

        try:
            # Wrap the argument list in a dummy call so ast can parse it
            # as a single expression.
            parsed = ast.parse(f"dummy_func({args_str})", mode='eval')
            call_node = parsed.body
            if not isinstance(call_node, ast.Call):
                return {}

            arguments = {}

            for keyword in call_node.keywords:
                if keyword.arg is None:
                    # **kwargs expansion has no usable name; skip it.
                    continue
                arguments[keyword.arg] = self._eval_arg_node(keyword.value)

            for i, arg in enumerate(call_node.args):
                arguments[f"arg_{i}"] = self._eval_arg_node(arg)

            return arguments

        except (SyntaxError, ValueError) as e:
            logger.warning("Failed to parse function arguments '%s': %s",
                           args_str, e)
            return {}

    def _coerce_to_schema(
        self,
        function_name: str,
        parsed_arguments: dict,
        request: ChatCompletionRequest,
    ) -> None:
        """Coerce argument values in-place to the JSON-schema types declared
        for *function_name* in ``request.tools``.

        Arguments without a declared type and failed conversions are left
        unchanged (the latter with a warning).
        """
        if not request.tools:
            return
        for tool_def in request.tools:
            if tool_def.function.name != function_name:
                continue

            schema_properties = {}
            params = tool_def.function.parameters
            if (params and isinstance(params, dict)
                    and isinstance(params.get("properties"), dict)):
                schema_properties = params["properties"]

            for arg_name, arg_value in parsed_arguments.items():
                if arg_name not in schema_properties:
                    continue
                target_type = schema_properties[arg_name].get("type")

                try:
                    if (target_type == "string"
                            and not isinstance(arg_value, str)):
                        parsed_arguments[arg_name] = str(arg_value)
                    elif (target_type == "integer"
                            and not isinstance(arg_value, int)):
                        parsed_arguments[arg_name] = int(arg_value)
                    elif (target_type == "number"
                            and not isinstance(arg_value, (int, float))):
                        parsed_arguments[arg_name] = float(arg_value)
                    elif (target_type == "boolean"
                            and not isinstance(arg_value, bool)):
                        if isinstance(arg_value, str):
                            parsed_arguments[arg_name] = \
                                arg_value.lower() in ['true', '1', 'yes']
                        else:
                            parsed_arguments[arg_name] = bool(arg_value)
                    elif target_type in ["object", "array"]:
                        if isinstance(arg_value, str):
                            try:
                                parsed_arguments[arg_name] = \
                                    json.loads(arg_value)
                            except json.JSONDecodeError:
                                # Keep the raw string if it isn't JSON.
                                pass
                except (ValueError, TypeError) as e:
                    logger.warning("Type conversion failed for %s: %s",
                                   arg_name, e)

            # Only the first tool definition with a matching name is used.
            break

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all pythonic tool calls from a complete model response.

        Returns the text before ``<TOOLCALL>`` as ``content``; on any
        unexpected error the whole output is returned as content with no
        tool calls.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

        tool_call_start_index = model_output.find(self.tool_call_start_token)
        content = model_output[:tool_call_start_index].strip()

        try:
            tool_call_matches = self.tool_call_regex.findall(model_output)
            if not tool_call_matches:
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output,
                )

            tool_calls_content = tool_call_matches[0].strip()

            # One function call per non-empty line.
            function_lines = [line.strip()
                              for line in tool_calls_content.split('\n')
                              if line.strip()]

            parsed_tool_calls = []

            for func_line in function_lines:
                match = self.function_call_regex.match(func_line)
                if not match:
                    logger.warning("Could not parse function call: %s",
                                   func_line)
                    continue

                function_name = match.group(1)
                args_str = match.group(2)

                parsed_arguments = self.parse_function_arguments(args_str)
                self._coerce_to_schema(function_name, parsed_arguments,
                                       request)

                parsed_tool_calls.append(ToolCall(
                    id=f"call_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=function_name,
                        arguments=json.dumps(parsed_arguments,
                                             ensure_ascii=False),
                    ),
                ))

            return ExtractedToolCallInformation(
                tools_called=len(parsed_tool_calls) > 0,
                tool_calls=parsed_tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception(
                "Error in extracting pythonic tool call from response. "
                "Response: %s", model_output)
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output,
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Streaming extraction is intentionally unsupported."""
        raise NotImplementedError("Tool calling is not supported in streaming mode!")
|
|