From 1e9acd2d312a6b3f9f005915ebecaedc97da2edb Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 3 Feb 2025 04:07:11 +0000 Subject: [PATCH] tool-call: allow `--jinja --chat-template chatml` --- common/common.cpp | 21 +++-- examples/server/tests/unit/test_tool_call.py | 96 ++++++++++++++++---- 2 files changed, 91 insertions(+), 26 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 6c81d18f9..b9d1e0e30 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1869,11 +1869,19 @@ std::string common_chat_format_example(const common_chat_template & tmpl, bool u return common_chat_apply_template(tmpl, msgs, true, use_jinja); } +#define CHATML_TEMPLATE_SRC \ + "{%- for message in messages -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%})" + common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) { auto vocab = llama_model_get_vocab(model); - std::string default_template_src = chat_template_override; - std::string template_tool_use_src = chat_template_override; + std::string default_template_src = chat_template_override == "chatml" ? CHATML_TEMPLATE_SRC : chat_template_override; + std::string template_tool_use_src = chat_template_override == "chatml" ? CHATML_TEMPLATE_SRC : ""; bool has_explicit_template = !chat_template_override.empty(); if (chat_template_override.empty()) { auto str = llama_model_chat_template(model, /* name */ nullptr); @@ -1891,14 +1899,7 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model if (!template_tool_use_src.empty()) { default_template_src = template_tool_use_src; } else { - default_template_src = R"( - {%- for message in messages -%} - {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} - {%- endfor -%} - {%- if add_generation_prompt -%} - {{- "<|im_start|>assistant\n" -}} - {%- endif -%} - )"; + default_template_src = CHATML_TEMPLATE_SRC; } } const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index e6ed9c9be..9c6e1b856 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -67,8 +67,8 @@ WEATHER_TOOL = { def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None): - n_predict = 512 global server + n_predict = 512 # server = ServerPreset.stories15m_moe() server.jinja = True server.n_predict = n_predict @@ -139,29 +139,49 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, @pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [ (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), - (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), # TODO: fix these # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None): +def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): + global server n_predict = 512 server.n_slots = 1 server.jinja = True @@ -169,10 +189,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str server.n_predict = n_predict server.model_hf_repo = hf_repo server.model_hf_file = None - if template_override: + if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) res = server.make_request("POST", "/chat/completions", data={ "max_tokens": n_predict, @@ -252,18 +274,36 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t @pytest.mark.slow @pytest.mark.parametrize("hf_repo,template_override", [ ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | None] | None): +def test_weather_tool_call(hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server server.n_slots = 1 server.jinja = True @@ -271,10 +311,12 @@ def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | Non server.n_predict = 512 server.model_hf_repo = hf_repo server.model_hf_file = None - if template_override: + if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) res = server.make_request("POST", "/chat/completions", data={ "max_tokens": 256, @@ -298,19 +340,39 @@ def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | Non @pytest.mark.slow @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ - (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), + (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + + ('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + + (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + + (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. + (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) -def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None): +def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server server.n_slots = 1 server.jinja = True @@ -318,10 +380,12 @@ def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: server.n_predict = 128 server.model_hf_repo = hf_repo server.model_hf_file = None - if template_override: + if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + elif isinstance(template_override, str): + server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) res = server.make_request("POST", "/chat/completions", data={ "max_tokens": 256,