| | import json |
| | from collections.abc import Sequence |
| | from random import choices |
| | from string import ascii_letters, digits |
| | from typing import Optional, Union |
| |
|
| | import partial_json_parser |
| | import regex as re |
| | from partial_json_parser.core.options import Allow |
| | from pydantic import Field |
| |
|
| | from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, |
| | DeltaFunctionCall, DeltaMessage, |
| | DeltaToolCall, |
| | ExtractedToolCallInformation, |
| | FunctionCall, ToolCall) |
| | from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
| | ToolParser, ToolParserManager) |
| | from vllm.logger import init_logger |
| | from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer |
| |
|
# Module-level logger, named after this module per vLLM convention.
logger = init_logger(__name__)

# Character pool (a-z, A-Z, 0-9) used to generate random tool-call IDs.
ALPHANUMERIC = ascii_letters + digits
| |
|
| |
|
class NemotronToolCall(ToolCall):
    """Tool call whose ``id`` defaults to a random 9-char alphanumeric string."""

    # Default factory delegates to the static helper so the ID scheme is
    # defined in exactly one place.
    id: str = Field(
        default_factory=lambda: NemotronToolCall.generate_random_id())

    @staticmethod
    def generate_random_id():
        """Return a fresh 9-character alphanumeric identifier."""
        picked = choices(ALPHANUMERIC, k=9)
        return "".join(picked)

    @staticmethod
    def is_valid_id(id: str) -> bool:
        """Return True iff *id* is exactly nine alphanumeric characters."""
        return len(id) == 9 and id.isalnum()
| |
|
| |
|
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
    """Whether *model_tokenizer* supports fn-name regex parsing (Mistral v11+)."""
    if not isinstance(model_tokenizer, MistralTokenizer):
        return False
    return model_tokenizer.version >= 11
| |
|
| |
|
@ToolParserManager.register_module("nemotron_json")
class NemotronToolParser(ToolParser):
    """
    Tool call parser for Nemotron-Nano-V2.

    The model emits tool calls between literal <TOOLCALL> and </TOOLCALL>
    tags, normally as a JSON array of {"name": ..., "arguments": ...}
    objects (or, for Mistral tokenizers v11+, as ``name{json}`` pairs).

    Used when --enable-auto-tool-choice --tool-call-parser nemotron_json are all set
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        # Snapshot of the tool-call array parsed on the previous streaming
        # step; consulted to detect argument changes between steps.
        self.prev_tool_call_arr: list[dict] = []
        # Index of the tool call currently being streamed (-1 = none yet).
        self.current_tool_id: int = -1
        # Whether the name delta for the current tool has already been sent.
        self.current_tool_name_sent: bool = False
        # Per-tool JSON argument text already streamed to the client.
        self.streamed_args_for_tool: list[str] = []
        # Per-tool flag: True once any argument chunk has been emitted.
        self.tool_args_emitted: list[bool] = []
        # Literal tag that opens a tool-call section in the model output.
        self.bot_token = "<TOOLCALL>"
        # Token id of the tag if it exists as a single vocab entry
        # (self.vocab is provided by the ToolParser base class).
        self.bot_token_id = self.vocab.get(self.bot_token)
        logger.info(f"Nemotron Tool Parser: bot_token: {self.bot_token}, bot_token_id: {self.bot_token_id}")
        # Fallback matcher: grab an outermost [...] JSON array of objects.
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
        if _is_fn_name_regex_support(self.model_tokenizer):
            # Mistral v11+ emits calls as `name{...json...}` pairs.
            self.fn_name_regex = re.compile(
                r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL)
        else:
            self.fn_name_regex = None

        # Characters held back while deciding whether a partial '<...'
        # sequence is a tool tag that must be suppressed from content.
        self._pending_tag_buffer: str = ""

    @staticmethod
    def _strip_trailing_auto_closers(chunk: str) -> str:
        """
        Remove parser auto-completed closing braces/brackets plus trailing whitespace.
        These should be flushed only when a tool call completes to avoid duplicate
        argument fragments.
        """
        idx = len(chunk)
        # Drop trailing whitespace and closing }/] that the partial JSON
        # parser auto-completed.
        while idx > 0 and chunk[idx - 1] in " \t\r\n}]":
            idx -= 1
        # Also drop trailing unescaped quotes (auto-closed strings).
        while idx > 0 and chunk[idx - 1] == '"':
            # A backslash-escaped quote is real payload, not an auto-closer.
            if idx - 2 >= 0 and chunk[idx - 2] == '\\':
                break
            idx -= 1
        return chunk[:idx]

    @staticmethod
    def _common_prefix_len(left: str, right: str) -> int:
        """
        Return the length of the shared prefix between left and right strings.
        """
        max_len = min(len(left), len(right))
        idx = 0
        while idx < max_len and left[idx] == right[idx]:
            idx += 1
        return idx

    def _compute_arguments_delta(self, cur_arguments_json: str,
                                 end_of_call: bool) -> str:
        """
        Determine the incremental suffix to stream for the current tool call.
        Ensures we only emit monotonic chunks by trimming our tracked prefix to
        the longest common prefix with the latest JSON snapshot.

        :param cur_arguments_json: full JSON dump of the arguments parsed so far
        :param end_of_call: True when </TOOLCALL> has been seen, so trailing
            auto-closers may be flushed rather than withheld
        :return: the argument text to emit now ("" when nothing new)
        """
        tool_idx = self.current_tool_id
        # Defensive: no active tool, or tracking list not yet grown.
        if tool_idx < 0 or tool_idx >= len(self.streamed_args_for_tool):
            return ""

        streamed_prefix = self.streamed_args_for_tool[tool_idx]
        had_any = (self.tool_args_emitted[tool_idx]
                   if tool_idx < len(self.tool_args_emitted) else False)

        # If the new snapshot diverges from what we already streamed,
        # shrink our tracked prefix to the common part so output stays
        # monotonic (we never retract emitted characters).
        lcp_len = self._common_prefix_len(cur_arguments_json,
                                          streamed_prefix)
        if lcp_len != len(streamed_prefix):
            streamed_prefix = streamed_prefix[:lcp_len]
            self.streamed_args_for_tool[tool_idx] = streamed_prefix

        # Special case: very first emission where the partial parser has
        # auto-completed an empty string value ('... "key": ""}').  Emit
        # only up to and including the value's opening quote so the real
        # value can stream afterwards without duplication.
        if (not had_any and not end_of_call and lcp_len == 0
                and cur_arguments_json.endswith('": ""}')
                and '": ""' in cur_arguments_json):
            closing_pos = cur_arguments_json.rfind('": ""}')
            if closing_pos != -1:
                # +4 keeps '": ' plus the opening quote of the value.
                arguments_delta = cur_arguments_json[:closing_pos + 4]
            else:
                arguments_delta = cur_arguments_json
        else:
            arguments_delta = cur_arguments_json[lcp_len:]

        if not arguments_delta:
            return ""

        # Mid-call: hold back auto-completed closers; they are flushed only
        # once the call is complete.
        if not end_of_call:
            arguments_delta = self._strip_trailing_auto_closers(
                arguments_delta)

        # First chunk of a still-open call: trim a trailing '}' (and a
        # quote right before it) that the partial parser synthesized.
        if (not had_any and not end_of_call and arguments_delta
                and arguments_delta.endswith('}')):
            arguments_delta = arguments_delta[:-1]
            if arguments_delta.endswith('"'):
                arguments_delta = arguments_delta[:-1]

        return arguments_delta

    def _visible_delta_outside_tool(self, delta_text: str,
                                    start_token: Optional[str],
                                    end_token: Optional[str]) -> str:
        """
        Consume characters that could begin a tool tag. Only suppress the exact
        <TOOLCALL> / </TOOLCALL> sequences, and let everything else (e.g. </think>)
        pass through untouched.

        Buffered characters persist across calls in self._pending_tag_buffer
        because a tag may be split across streaming deltas.
        """
        if not delta_text:
            return delta_text

        visible: list[str] = []
        for ch in delta_text:
            if self._pending_tag_buffer or ch == '<':
                self._pending_tag_buffer += ch

                # Still a prefix of the opening tag: keep buffering; a full
                # match is swallowed entirely.
                if start_token and start_token.startswith(self._pending_tag_buffer):
                    if self._pending_tag_buffer == start_token:
                        self._pending_tag_buffer = ""
                    continue

                # Same handling for the closing tag.
                if end_token and end_token.startswith(self._pending_tag_buffer):
                    if self._pending_tag_buffer == end_token:
                        self._pending_tag_buffer = ""
                    continue

                # Not a tool tag after all: flush everything we held back.
                visible.append(self._pending_tag_buffer)
                self._pending_tag_buffer = ""
            else:
                visible.append(ch)

        return "".join(visible)

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Tweak the request before generation when tools are in play."""
        if not isinstance(
                self.model_tokenizer, MistralTokenizer
        ) and request.tools and request.tool_choice != 'none':
            # Keep special tokens in the output so the <TOOLCALL> tags
            # survive detokenization and can be parsed out here.
            # (Mistral tokenizers are excluded from this tweak.)
            request.skip_special_tokens = False
        return request

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete (non-streaming) model response.

        Parses the text after <TOOLCALL> either with the Mistral-style
        ``name{json}`` regex (when available) or as a plain JSON array,
        falling back to a bracket-matching regex on JSONDecodeError.
        """

        # No tool-call tag at all: the whole output is plain content.
        if self.bot_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        # Remove the tag(s); what's left should be the call payload.
        tool_content = model_output.replace(self.bot_token, "").strip()

        try:

            try:
                if self.fn_name_regex:
                    # Mistral v11+ format: one (name, args-json) pair per match.
                    matches = self.fn_name_regex.findall(tool_content)

                    function_call_arr = []
                    for match in matches:
                        fn_name = match[0]
                        args = match[1]

                        # Store parsed arguments; they are re-serialized
                        # uniformly below.
                        function_call_arr.append({
                            "name": fn_name,
                            "arguments": json.loads(args)
                        })
                else:
                    function_call_arr = json.loads(tool_content)
            except json.JSONDecodeError:
                # The payload wasn't clean JSON; salvage the first
                # [...]-delimited array found anywhere in it.
                raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
                function_call_arr = json.loads(raw_tool_call)

            # Normalize every parsed call into a NemotronToolCall with
            # arguments re-serialized as a JSON string.
            tool_calls: list[NemotronToolCall] = [
                NemotronToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        arguments=json.dumps(raw_function_call["arguments"],
                                             ensure_ascii=False)))
                for raw_function_call in function_call_arr
            ]

            # Anything the model said before the first <TOOLCALL> is content.
            content = model_output.split(self.bot_token)[0]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if len(content) > 0 else None)

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # Parsing failed entirely: surface the raw payload as content.
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=tool_content)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Incremental tool-call extraction for streaming responses.

        Returns a DeltaMessage with either plain content (outside tool
        tags) or tool-call name/argument deltas, or None when nothing
        should be emitted for this step.
        """
        # Filter <TOOLCALL>/</TOOLCALL> tags out of the delta so they never
        # leak into plain content.
        visible_delta_text = delta_text
        try:
            start_token = self.bot_token
            # Derive the closing tag ("</TOOLCALL>") from the opening one.
            end_token = f"</{self.bot_token[1:]}" if self.bot_token.startswith('<') else None

            visible_delta_text = self._visible_delta_outside_tool(
                delta_text, start_token, end_token)
        except Exception:
            # Fallback suppression: hold output while a partial opening tag
            # is still being produced at the end of current_text.
            if current_text.endswith('<') or current_text.endswith('<T') or current_text.endswith('<TO') or current_text.endswith('<TOOL') or current_text.endswith('<TOOLCALL'):
                return None

        # No tool section started yet: pass the (tag-filtered) delta
        # through as regular content.
        if self.bot_token not in current_text:
            if visible_delta_text:
                return DeltaMessage(content=visible_delta_text)
            return None

        # Disallow partial strings until the function name has been sent,
        # so we never act on a half-generated name.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        end_of_call: bool = False
        try:

            # Everything after the last <TOOLCALL> tag is the (possibly
            # partial) JSON tool-call payload.
            parsable_arr = current_text.split(self.bot_token)[-1]

            # A closing tag means the call is complete; parse only what
            # precedes it.
            if '</TOOLCALL>' in parsable_arr:
                end_of_call = True
                parsable_arr = parsable_arr.split('</TOOLCALL>')[0]

            # Parse the partial JSON; bail out quietly until it becomes
            # well-formed enough to interpret.
            try:
                tool_call_arr: list[dict] = partial_json_parser.loads(
                    parsable_arr, flags)
            except (partial_json_parser.core.exceptions.MalformedJSON,
                    json.JSONDecodeError, ValueError):
                return None

            # NOTE(review): with current_tool_id == -1 this indexes the
            # last element — assumed intentional (mirrors arrival order);
            # confirm against the upstream mistral parser this follows.
            current_tool_call: dict = tool_call_arr[self.current_tool_id] \
                if len(tool_call_arr) > 0 else {}

            # Case 1: nothing parsed yet.
            if len(tool_call_arr) == 0:
                return None

            # Case 2: a new tool object appeared in the array — close out
            # the previous tool's arguments, then start tracking the new one.
            elif (len(tool_call_arr) > 0
                  and len(tool_call_arr) > self.current_tool_id + 1):

                # If a previous tool was mid-stream, emit whatever argument
                # text it still owes before switching.
                if self.current_tool_id >= 0:
                    diff: Union[str, None] = current_tool_call.get("arguments")

                    if diff:
                        # NOTE(review): str.replace removes the first match
                        # anywhere, not strictly a prefix — assumed safe
                        # because streamed text is a prefix of the dump.
                        diff = json.dumps(diff, ensure_ascii=False).replace(
                            self.streamed_args_for_tool[self.current_tool_id],
                            "")
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=diff).model_dump(
                                                  exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += diff
                    else:
                        delta = None
                else:
                    delta = None

                # Re-point all per-tool state at the newest tool.
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                self.tool_args_emitted.append(False)
                return delta

            # Case 3: streaming the current tool.

            # 3a: the function name has not been emitted yet — send it as
            # soon as it is fully available.
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=NemotronToolCall.generate_random_id(),
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # 3b: name already sent — stream argument deltas.
            else:

                prev_arguments = self.prev_tool_call_arr[
                    self.current_tool_id].get("arguments")
                cur_arguments = current_tool_call.get("arguments")

                if not cur_arguments and not prev_arguments:
                    # No arguments yet on either side; nothing to emit.
                    delta = None
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None
                elif cur_arguments:
                    cur_arguments_json = json.dumps(cur_arguments,
                                                    ensure_ascii=False)
                    arguments_delta = self._compute_arguments_delta(
                        cur_arguments_json, end_of_call)
                    if arguments_delta:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=arguments_delta).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += arguments_delta
                        self.tool_args_emitted[
                            self.current_tool_id] = True
                    else:
                        # Snapshot grew only by withheld auto-closers;
                        # emit nothing this step.
                        delta = None
                else:
                    # Unreachable given the branches above; kept defensive.
                    delta = None

            # Remember this step's parse for the next diff.
            self.prev_tool_call_arr = tool_call_arr

            # On </TOOLCALL>, flush any argument suffix (closing braces /
            # quotes) that was withheld while the call was still open.
            if end_of_call and self.current_tool_id >= 0:
                try:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments is not None:
                        cur_args_json = json.dumps(cur_arguments,
                                                   ensure_ascii=False)
                        remaining_suffix = self._compute_arguments_delta(
                            cur_args_json, end_of_call=True)

                        # Only emit non-whitespace remainders; append to an
                        # existing delta when one was produced above.
                        if remaining_suffix and remaining_suffix.strip():
                            extra = DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=remaining_suffix).model_dump(
                                        exclude_none=True))
                            if delta is None:
                                delta = DeltaMessage(tool_calls=[extra])
                            else:
                                if getattr(delta, "tool_calls", None):
                                    delta.tool_calls.append(extra)
                                else:
                                    delta.tool_calls = [extra]
                            self.streamed_args_for_tool[
                                self.current_tool_id] += remaining_suffix
                            self.tool_args_emitted[self.current_tool_id] = True
                        else:
                            pass
                except Exception:
                    # Best-effort flush: a failure here must not break the
                    # already-computed delta.
                    pass

            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None
|