From 7b5e0803c84435e8005725d8f9ddceb5d1820443 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 29 Jan 2025 18:16:35 +0000 Subject: [PATCH] Move templates/ under models/ --- .editorconfig | 2 +- common/chat-handler.cpp | 2 - examples/server/tests/unit/test_tool_call.py | 10 ++-- ...reForAI-c4ai-command-r-plus-tool_use.jinja | 0 ...rch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja | 0 ...earch-Hermes-3-Llama-3.1-8B-tool_use.jinja | 0 .../templates/Qwen-Qwen2.5-7B-Instruct.jinja | 0 ...seek-ai-DeepSeek-R1-Distill-Llama-8B.jinja | 0 ...seek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja | 56 +++++++++++++++++++ ...fireworks-ai-llama-3-firefunction-v2.jinja | 0 .../templates/google-gemma-2-2b-it.jinja | 0 .../meetkai-functionary-medium-v3.1.jinja | 0 .../meetkai-functionary-medium-v3.2.jinja | 0 .../meta-llama-Llama-3.1-8B-Instruct.jinja | 0 .../meta-llama-Llama-3.2-3B-Instruct.jinja | 0 .../meta-llama-Llama-3.3-70B-Instruct.jinja | 0 .../microsoft-Phi-3.5-mini-instruct.jinja | 0 ...mistralai-Mistral-Nemo-Instruct-2407.jinja | 0 tests/test-chat-handler.cpp | 26 ++++----- 19 files changed, 75 insertions(+), 21 deletions(-) rename {tests/chat => models}/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja (100%) rename {tests/chat => models}/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja (100%) rename {tests/chat => models}/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja (100%) rename {tests/chat => models}/templates/Qwen-Qwen2.5-7B-Instruct.jinja (100%) rename {tests/chat => models}/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja (100%) create mode 100644 models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja rename {tests/chat => models}/templates/fireworks-ai-llama-3-firefunction-v2.jinja (100%) rename {tests/chat => models}/templates/google-gemma-2-2b-it.jinja (100%) rename {tests/chat => models}/templates/meetkai-functionary-medium-v3.1.jinja (100%) rename {tests/chat => models}/templates/meetkai-functionary-medium-v3.2.jinja (100%) 
rename {tests/chat => models}/templates/meta-llama-Llama-3.1-8B-Instruct.jinja (100%) rename {tests/chat => models}/templates/meta-llama-Llama-3.2-3B-Instruct.jinja (100%) rename {tests/chat => models}/templates/meta-llama-Llama-3.3-70B-Instruct.jinja (100%) rename {tests/chat => models}/templates/microsoft-Phi-3.5-mini-instruct.jinja (100%) rename {tests/chat => models}/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja (100%) diff --git a/.editorconfig b/.editorconfig index e092729bd..5d63d0a51 100644 --- a/.editorconfig +++ b/.editorconfig @@ -41,7 +41,7 @@ indent_style = tab trim_trailing_whitespace = unset insert_final_newline = unset -[tests/chat/templates/*.jinja] +[models/templates/*.jinja] indent_style = unset indent_size = unset end_of_line = unset diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp index fa255d806..069db4bee 100644 --- a/common/chat-handler.cpp +++ b/common/chat-handler.cpp @@ -450,7 +450,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_ data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; - // auto add_tool = [&](const json & tool) { foreach_function(params.tools, [&](const json & tool) { const auto & function = tool["function"]; std::string name = function["name"]; @@ -604,7 +603,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common } static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) { - // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt common_chat_data data; json tools = params.tools.is_null() ? 
params.tools : json::array(); diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index d9fcdd8e6..69d0b63bc 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -67,7 +67,7 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a # server = ServerPreset.stories15m_moe() server.jinja = True server.n_predict = n_predict - server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja' + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' server.start() res = server.make_request("POST", "/chat/completions", data={ "max_tokens": n_predict, @@ -166,7 +166,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str server.model_hf_file = hf_file if template_override: (template_hf_repo, template_variant) = template_override - server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." 
server.start() res = server.make_request("POST", "/chat/completions", data={ @@ -200,7 +200,7 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too global server server.jinja = True server.n_predict = n_predict - server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja' + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' server.start() res = server.make_request("POST", "/chat/completions", data={ "max_tokens": n_predict, @@ -270,7 +270,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[ server.model_hf_file = hf_file if template_override: (template_hf_repo, template_variant) = template_override - server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." 
server.start(timeout_seconds=15*60) res = server.make_request("POST", "/chat/completions", data={ @@ -319,7 +319,7 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_ server.model_hf_file = hf_file if template_override: (template_hf_repo, template_variant) = template_override - server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja" + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja" assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." server.start(timeout_seconds=15*60) res = server.make_request("POST", "/chat/completions", data={ diff --git a/tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja b/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja similarity index 100% rename from tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja rename to models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja diff --git a/tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja b/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja similarity index 100% rename from tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja rename to models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja diff --git a/tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja b/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja similarity index 100% rename from tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja rename to models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja diff --git 
a/tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja b/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja similarity index 100% rename from tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja rename to models/templates/Qwen-Qwen2.5-7B-Instruct.jinja diff --git a/tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja similarity index 100% rename from tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja rename to models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja new file mode 100644 index 000000000..2ebfe7c1e --- /dev/null +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja @@ -0,0 +1,56 @@ +{% if not add_generation_prompt is defined %} +{% set add_generation_prompt = false %} +{% endif %} +{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %} +{%- for message in messages %} +{%- if message['role'] == 'system' %} +{% set ns.system_prompt = message['content'] %} +{%- endif %} +{%- endfor %} +{{bos_token}} +{{ns.system_prompt}} +{%- for message in messages %} +{%- if message['role'] == 'user' %} +{%- set ns.is_tool = false -%} +{{'<|User|>' + message['content']}} +{%- endif %} +{%- if message['role'] == 'assistant' and message['content'] is none %} +{%- set ns.is_tool = false -%} +{%- for tool in message['tool_calls']%} +{%- if not ns.is_first %} +{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} +{%- set ns.is_first = true -%} +{%- else %} +{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + 
'<|tool▁call▁end|>'}} +{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} +{%- endif %} +{%- endfor %} +{%- endif %} +{%- if message['role'] == 'assistant' and message['content'] is not none %} +{%- if ns.is_tool %} +{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} +{%- set ns.is_tool = false -%} +{%- else %} +{% set content = message['content'] %} +{% if '</think>' in content %} +{% set content = content.split('</think>')[-1] %} +{% endif %} +{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}} +{%- endif %} +{%- endif %} +{%- if message['role'] == 'tool' %} +{%- set ns.is_tool = true -%} +{%- if ns.is_output_first %} +{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} +{%- set ns.is_output_first = false %} +{%- else %} +{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} +{%- endif %} +{%- endif %} +{%- endfor -%} +{% if ns.is_tool %} +{{'<|tool▁outputs▁end|>'}} +{% endif %} +{% if add_generation_prompt and not ns.is_tool %} +{{'<|Assistant|>'}} +{% endif %} \ No newline at end of file diff --git a/tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja b/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja similarity index 100% rename from tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja rename to models/templates/fireworks-ai-llama-3-firefunction-v2.jinja diff --git a/tests/chat/templates/google-gemma-2-2b-it.jinja b/models/templates/google-gemma-2-2b-it.jinja similarity index 100% rename from tests/chat/templates/google-gemma-2-2b-it.jinja rename to models/templates/google-gemma-2-2b-it.jinja diff --git a/tests/chat/templates/meetkai-functionary-medium-v3.1.jinja b/models/templates/meetkai-functionary-medium-v3.1.jinja similarity index 100% rename from tests/chat/templates/meetkai-functionary-medium-v3.1.jinja rename to models/templates/meetkai-functionary-medium-v3.1.jinja diff --git a/tests/chat/templates/meetkai-functionary-medium-v3.2.jinja 
b/models/templates/meetkai-functionary-medium-v3.2.jinja similarity index 100% rename from tests/chat/templates/meetkai-functionary-medium-v3.2.jinja rename to models/templates/meetkai-functionary-medium-v3.2.jinja diff --git a/tests/chat/templates/meta-llama-Llama-3.1-8B-Instruct.jinja b/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja similarity index 100% rename from tests/chat/templates/meta-llama-Llama-3.1-8B-Instruct.jinja rename to models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja diff --git a/tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja b/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja similarity index 100% rename from tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja rename to models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja diff --git a/tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja b/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja similarity index 100% rename from tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja rename to models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja diff --git a/tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja b/models/templates/microsoft-Phi-3.5-mini-instruct.jinja similarity index 100% rename from tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja rename to models/templates/microsoft-Phi-3.5-mini-instruct.jinja diff --git a/tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja b/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja similarity index 100% rename from tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja rename to models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja diff --git a/tests/test-chat-handler.cpp b/tests/test-chat-handler.cpp index 6ea595c3a..1beb2fa5c 100644 --- a/tests/test-chat-handler.cpp +++ b/tests/test-chat-handler.cpp @@ -314,12 +314,12 @@ static void test_template_output_parsers() { }; { - const common_chat_template 
tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "", ""); + const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "", ""); std::vector end_tokens { "" }; assert_equals(std::string("content-only"), describe(tmpl, no_tools_params)); assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params)); - assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "", ""), tools_params)); + assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "", ""), tools_params)); // Generic tool calls doesn't generate / parse content-only messages symmetrically. assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser( @@ -340,7 +340,7 @@ static void test_template_output_parsers() { "}"); } { - const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); + const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", ""); std::vector end_tokens { "" }; assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params)); @@ -351,12 +351,12 @@ static void test_template_output_parsers() { /* skip_grammar_test= */ true); } { - const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); + const common_chat_template tmpl(read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); std::vector end_tokens { "<|im_end|>" }; assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params)); - assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "", ""), 
tools_params)); - assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""), tools_params)); + assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "", ""), tools_params)); + assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""), tools_params)); test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true); test_template(tmpl, end_tokens, tool_call_message, tools, @@ -369,11 +369,11 @@ static void test_template_output_parsers() { ""); } { - const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "", ""); + const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params)); - assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""), tools_params)); + assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", ""), tools_params)); // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true); test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools, @@ -384,7 +384,7 @@ static void test_template_output_parsers() { "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { - const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "", ""); + const common_chat_template 
tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params)); @@ -395,7 +395,7 @@ static void test_template_output_parsers() { "{\"arg1\": 1}"); } { - const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", ""); + const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params)); @@ -406,7 +406,7 @@ static void test_template_output_parsers() { "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { - const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "", ""); + const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "", ""); std::vector end_tokens { "<|eom_id|>", "<|eot_id|>" }; assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params)); @@ -420,7 +420,7 @@ static void test_template_output_parsers() { "{\"arg1\": 1}"); } { - const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "", ""); + const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "", ""); std::vector end_tokens { "<|eot_id|>" }; assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params)); @@ -431,7 +431,7 @@ static void test_template_output_parsers() { " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); } { - const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "", ""); + const common_chat_template 
tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "", ""); std::vector end_tokens { "<|end▁of▁sentence|>" }; assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));