tool-call: add weather tool e2e tests
This commit is contained in:
parent
15ec01e896
commit
2efa0c27bf
1 changed files with 86 additions and 18 deletions
|
@ -221,6 +221,23 @@ PYTHON_TOOL = {
|
|||
}
|
||||
}
|
||||
|
||||
WEATHER_TOOL = {
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"get_current_weather",
|
||||
"description":"Get the current weather in a given location",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"location":{
|
||||
"type":"string",
|
||||
"description":"The city and state, e.g. San Francisco, CA"
|
||||
}
|
||||
},
|
||||
"required":["location"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||
|
@ -308,22 +325,76 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
|||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
|
||||
(PYTHON_TOOL, None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
(PYTHON_TOOL, None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(PYTHON_TOOL, None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(PYTHON_TOOL, None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
|
||||
@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
|
||||
(None, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
(None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
|
||||
# TODO: fix these models
|
||||
(PYTHON_TOOL, '{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
# (PYTHON_TOOL, None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# (PYTHON_TOOL, None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# (None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# (None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
])
|
||||
def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
def test_weather_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.n_slots = 2
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = 128
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = hf_file
|
||||
if template_override:
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
server.start(timeout_seconds=15*60)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 256,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in Istanbul?"},
|
||||
],
|
||||
"tools": [WEATHER_TOOL],
|
||||
# "temperature": 0.5,
|
||||
# "top_k": 10,
|
||||
# "top_p": 0.9,
|
||||
})
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
if expected_arguments is not None:
|
||||
assert actual_arguments == expected_arguments
|
||||
else:
|
||||
actual_arguments = json.loads(actual_arguments)
|
||||
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
|
||||
location = actual_arguments["location"]
|
||||
assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
|
||||
assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("expected_arguments,hf_repo,hf_file,template_override", [
|
||||
('{"code":"print("}', "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
|
||||
(None, "bartowski/functionary-small-v3.2-GGUF", "functionary-small-v3.2-Q8_0.gguf", ("meetkai-functionary-medium-v3.2", None)),
|
||||
(None, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
|
||||
(None, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
|
||||
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
|
||||
(None, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
|
||||
(None, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF", "Mistral-Nemo-Instruct-2407-Q6_K_L.gguf", None),
|
||||
# TODO: fix these models
|
||||
# (None, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
# (None, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q6_K_L.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
|
||||
])
|
||||
def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = 128
|
||||
|
@ -341,7 +412,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_re
|
|||
{"role": "user", "content": "say hello world with python"},
|
||||
# {"role": "user", "content": "Print a hello world message with python"},
|
||||
],
|
||||
"tools": [tool],
|
||||
"tools": [PYTHON_TOOL],
|
||||
"temperature": 0.5,
|
||||
"top_k": 10,
|
||||
"top_p": 0.9,
|
||||
|
@ -351,10 +422,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: str | None, hf_re
|
|||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
if tool["type"] == "function":
|
||||
assert tool["function"]["name"] == tool_call["function"]["name"]
|
||||
elif tool["type"] == "code_interpreter":
|
||||
assert re.match('i?python', tool_call["function"]["name"])
|
||||
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
if expected_arguments is not None:
|
||||
assert actual_arguments == expected_arguments
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue