# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from copy import deepcopy
from unittest.mock import MagicMock

import pytest
import regex as re
from pydantic import TypeAdapter

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionToolsParam)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

# Two function tools used by every test in this module. Both declare
# required parameters and forbid additional properties.
EXAMPLE_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for"
                        ", e.g. 'San Francisco'",
                    },
                },
                "required": ["city"],
                "additionalProperties": False
            },
        },
        "strict": True
    },
    {
        "type": "function",
        "function": {
            "name": "get_forecast",
            "description": "Get the weather forecast for a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description":
                        "The city to get the forecast for, e.g. 'New York'",
                    },
                    "days": {
                        "type": "integer",
                        "description":
                        "Number of days to get the forecast for (1-7)",
                    },
                },
                "required": ["city", "days"],
                "additionalProperties": False
            },
        },
        "strict": True
    },
]


def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
                       should_match: bool) -> None:
    """Assert whether ``sample_output`` is accepted by the tools' schema.

    The guided-JSON schema is obtained by calling
    ``ChatCompletionRequest._get_guided_json_from_tool`` on a mock request
    whose ``tool_choice`` is ``"required"``. The schema is converted to a
    regex and the JSON serialization of ``sample_output`` is full-matched
    against it.

    Args:
        tools: Validated tool definitions placed on the mock request.
        sample_output: Any JSON-serializable candidate tool-call output.
        should_match: Expected result of the full match.
    """
    # Only ``tool_choice`` and ``tools`` are set explicitly on the stand-in
    # request; it is passed in place of the method's ``self``.
    request = MagicMock(tool_choice="required", tools=tools)
    schema = ChatCompletionRequest._get_guided_json_from_tool(request)
    assert isinstance(schema, dict)

    # use build_regex_from_schema used in JSONLogitsProcessor to create Guide
    from outlines_core.fsm.json_schema import build_regex_from_schema
    regex = build_regex_from_schema(json.dumps(schema))
    compiled = re.compile(regex)
    matches = compiled.fullmatch(json.dumps(sample_output)) is not None

    assert matches == should_match


# (tool-call list, expected match) pairs that the schema must accept.
VALID_TOOL_OUTPUTS = [
    ([{
        "name": "get_current_weather",
        "parameters": {
            "city": "Vienna"
        }
    }], True),
    ([{
        "name": "get_current_weather",
        "parameters": {
            "city": "Vienna"
        }
    }, {
        "name": "get_current_weather",
        "parameters": {
            "city": "Berlin"
        }
    }], True),
    ([{
        "name": "get_forecast",
        "parameters": {
            "city": "Vienna",
            "days": 7
        }
    }], True),
    ([{
        "name": "get_forecast",
        "parameters": {
            "city": "Vienna",
            "days": 7
        }
    }, {
        "name": "get_current_weather",
        "parameters": {
            "city": "Vienna"
        }
    }], True),
    ([{
        "name": "get_forecast",
        "parameters": {
            "city": "Vienna",
            "days": 7
        }
    }, {
        "name": "get_current_weather",
        "parameters": {
            "city": "Vienna"
        }
    }, {
        "name": "get_forecast",
        "parameters": {
            "city": "Berlin",
            "days": 7
        }
    }, {
        "name": "get_current_weather",
        "parameters": {
            "city": "Berlin"
        }
    }], True),
]

# Just the tool-call lists, without the expected-match flag.
VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]


@pytest.mark.parametrize(
    "sample_output, should_match",
    VALID_TOOL_OUTPUTS + [
        (None, False),
        ([], False),  # empty list cannot be generated
        ({}, False),  # empty object cannot be generated
        ([{}], False),  # list with empty object cannot be generated
        (
            [{  # tool call missing the "parameters" key cannot be generated
                "name": "get_current_weather"
            }],
            False),
        (
            [{  # tool call with empty parameters cannot be generated
                "name": "get_current_weather",
                "parameters": {}
            }],
            False),
        (
            [{  # tool call with null parameters cannot be generated
                "name": "get_current_weather",
                "parameters": None
            }],
            False),
        (
            {  # tool call without lists cannot be generated
                "name": "get_current_weather",
                "parameters": {
                    "city": "Vienna"
                }
            },
            False),
        (
            [{  # tool call with extra parameters cannot be generated
                "name": "get_current_weather",
                "parameters": {
                    "city": "Vienna",
                    "extra": "value"
                }
            }],
            False),
        (
            [{  # tool call where parameters are first cannot be generated
                "parameters": {
                    "city": "Vienna"
                },
                "name": "get_current_weather"
            }],
            False),
        (
            [{  # tool call without all required parameters cannot be generated
                "name": "get_forecast",
                "parameters": {
                    "city": "Vienna"
                }
            }],
            False),
        (  # tool call with incorrect name/parameters cannot be generated
            [{
                "name": "get_weather",
                "parameters": {
                    "city": "Vienna",
                    "days": 7
                }
            }],
            False),
        (  # tool call with both valid and empty function cannot be generated
            [{
                "name": "get_current_weather",
                "parameters": {
                    "city": "Vienna"
                }
            }, {}],
            False),
    ])
def test_guided_json(sample_output, should_match):
    """Check candidate outputs against the schema built from EXAMPLE_TOOLS."""
    tools = TypeAdapter(
        list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS)
    _compile_and_check(tools=tools,
                       sample_output=sample_output,
                       should_match=should_match)


def update_parameters_none(
        tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
    """Set the tool's function parameters to ``None`` in place; return it."""
    tool.function.parameters = None
    return tool


def update_parameters_empty_dict(
        tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
    """Set the tool's function parameters to ``{}`` in place; return it."""
    tool.function.parameters = {}
    return tool


@pytest.mark.parametrize(
    "sample_output, should_match",
    [
        (None, False),
        ([], False),  # empty list cannot be generated
        ({}, False),  # empty object cannot be generated
        ([{}], False),  # list with empty object cannot be generated
        (
            [{  # tool call missing the "parameters" key cannot be generated
                "name": "get_current_weather"
            }],
            False),
        (
            [{  # tool call with null parameters cannot be generated
                "name": "get_current_weather",
                "parameters": None
            }],
            False),
        (
            [{  # function with extra parameters cannot be generated
                "name": "get_current_weather",
                "parameters": {
                    "extra": "value"
                }
            }],
            False),
        (
            [{  # only function with empty parameters object is valid
                "name": "get_current_weather",
                "parameters": {}
            }],
            True),
    ])
@pytest.mark.parametrize(
    "update_parameters",
    [update_parameters_none, update_parameters_empty_dict])
def test_guided_json_without_parameters(sample_output, should_match,
                                        update_parameters):
    """Check the schema for a tool whose parameters are ``None`` or ``{}``.

    Only the first example tool is used, on a deep copy so that the shared
    ``EXAMPLE_TOOLS`` constant is never touched by the parameter update.
    """
    updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
    tools = TypeAdapter(
        list[ChatCompletionToolsParam]).validate_python(updated_tools)
    tools = list(map(update_parameters, tools))
    # Sanity-check that the update really stripped the parameters.
    assert all(
        tool.function.parameters is None or tool.function.parameters == {}
        for tool in tools)
    _compile_and_check(tools=tools,
                       sample_output=sample_output,
                       should_match=should_match)


@pytest.mark.parametrize("output", VALID_TOOLS)
@pytest.mark.parametrize("empty_params", [False, True])
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_streaming_output_valid(output, empty_params, delta_len):
    """Stream serialized tool calls in chunks and verify the reassembly.

    The JSON text of ``output`` is fed to
    ``OpenAIServingChat.extract_tool_call_required_streaming`` in slices of
    ``delta_len`` characters. The emitted delta messages are stitched back
    into a JSON array that must equal the original, both as parsed data and
    as re-serialized text.
    """
    # A bare mock is passed in place of the method's ``self``.
    serving_chat = MagicMock()

    output = deepcopy(output)
    if empty_params:
        output = [{"name": o["name"], "parameters": {}} for o in output]
    output_json = json.dumps(output)

    previous_text = ""
    function_name_returned = False
    messages = []
    for i in range(0, len(output_json), delta_len):
        delta_text = output_json[i:i + delta_len]
        current_text = previous_text + delta_text

        delta_message, function_name_returned = (
            OpenAIServingChat.extract_tool_call_required_streaming(
                serving_chat,
                previous_text=previous_text,
                current_text=current_text,
                delta_text=delta_text,
                function_name_returned=function_name_returned))

        if delta_message:
            messages.append(delta_message)

        previous_text = current_text

    assert len(messages) > 0

    combined_messages = "["
    for message in messages:
        function = message.tool_calls[0].function
        if function.name:
            # A delta carrying a name starts a new tool call; close the
            # previous object first if anything has been written already.
            if len(combined_messages) > 1:
                combined_messages += "},"

            combined_messages += ('{"name": "' + function.name +
                                  '", "parameters": ' + function.arguments)
        else:
            # A delta without a name only contributes argument text.
            combined_messages += function.arguments
    combined_messages += "}]"

    assert json.loads(combined_messages) == output
    assert json.dumps(json.loads(combined_messages)) == output_json