tools: greedy sampling in tests

This commit is contained in:
ochafik 2025-01-18 02:39:10 +00:00
parent 2ceabee0f8
commit 259d9e4511

View file

@ -298,20 +298,20 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [ @pytest.mark.parametrize("tool,expected_arguments,hf_repo,hf_file,template_override", [
(PYTHON_TOOL, {"code": "print('Hello World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
(PYTHON_TOOL, {"code": "print(\"Hello World!\")"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), (PYTHON_TOOL, {"code": "print(\"Hello World!\")"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
(PYTHON_TOOL, {"code": "print('Hello World')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), (PYTHON_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
(PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (PYTHON_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, {"code": "print('hello world')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), (PYTHON_TOOL, {"code": "print('hello world')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, {"code": "print('Hello, world!')"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, {"code": "print('Hello, World!'}"}, "bartowski/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, {"code": "print("}, "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), (PYTHON_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None), (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", None),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World!')"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None), (CODE_INTEPRETER_TOOL, {"code": "print(\"Hello World!\")"}, "bartowski/Qwen2.5-7B-Instruct-GGUF", "Qwen2.5-7B-Instruct-Q4_K_M.gguf", None),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello World')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None), (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "bartowski/Phi-3.5-mini-instruct-GGUF", "Phi-3.5-mini-instruct-Q4_K_M.gguf", None),
(CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")), (CODE_INTEPRETER_TOOL, {"code": "print('Hello, world!')"}, "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf", ("NousResearch-Hermes-2-Pro-Llama-3-8B", "tool_use")),
(CODE_INTEPRETER_TOOL, {"code": "print('hello world')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), (CODE_INTEPRETER_TOOL, {"code": "print('hello world')"}, "NousResearch/Hermes-3-Llama-3.1-8B-GGUF", "Hermes-3-Llama-3.1-8B.Q4_K_M.gguf", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(CODE_INTEPRETER_TOOL, {"code": "print('hello world')"}, "lmstudio-community/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), (CODE_INTEPRETER_TOOL, {"code": "print('Hello, World!')"}, "lmstudio-community/Llama-3.2-1B-Instruct-GGUF", "Llama-3.2-1B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
(CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)), (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", ("meta-llama-Llama-3.2-3B-Instruct", None)),
(CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None), (CODE_INTEPRETER_TOOL, {"code": "print("}, "lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", None),
# TODO: fix tool call handling of these models # TODO: fix tool call handling of these models
@ -331,8 +331,6 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
(template_hf_repo, template_variant) = template_override (template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja" server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_hf_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
# else:
# server.chat_template_file = None
server.start(timeout_seconds=15*60) server.start(timeout_seconds=15*60)
res = server.make_request("POST", "/chat/completions", data={ res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 256, "max_tokens": 256,
@ -341,6 +339,10 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
{"role": "user", "content": "say hello world with python"}, {"role": "user", "content": "say hello world with python"},
], ],
"tools": [tool], "tools": [tool],
# Greedy sampling
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
}) })
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0] choice = res.body["choices"][0]