Move templates/ under models/
This commit is contained in:
parent
682026f84b
commit
7b5e0803c8
19 changed files with 75 additions and 21 deletions
|
@ -41,7 +41,7 @@ indent_style = tab
|
|||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
[tests/chat/templates/*.jinja]
|
||||
[models/templates/*.jinja]
|
||||
indent_style = unset
|
||||
indent_size = unset
|
||||
end_of_line = unset
|
||||
|
|
|
@ -450,7 +450,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
|
|||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
// auto add_tool = [&](const json & tool) {
|
||||
foreach_function(params.tools, [&](const json & tool) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
|
@ -604,7 +603,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
|||
}
|
||||
|
||||
static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_data data;
|
||||
json tools = params.tools.is_null() ? params.tools : json::array();
|
||||
|
|
|
@ -67,7 +67,7 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
|
|||
# server = ServerPreset.stories15m_moe()
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
|
@ -166,7 +166,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
|||
server.model_hf_file = hf_file
|
||||
if template_override:
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
|
@ -200,7 +200,7 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too
|
|||
global server
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
|
@ -270,7 +270,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
|
|||
server.model_hf_file = hf_file
|
||||
if template_override:
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
server.start(timeout_seconds=15*60)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
|
@ -319,7 +319,7 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
|
|||
server.model_hf_file = hf_file
|
||||
if template_override:
|
||||
(template_hf_repo, template_variant) = template_override
|
||||
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||
server.start(timeout_seconds=15*60)
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
{% if not add_generation_prompt is defined %}
|
||||
{% set add_generation_prompt = false %}
|
||||
{% endif %}
|
||||
{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
{% set ns.system_prompt = message['content'] %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{bos_token}}
|
||||
{{ns.system_prompt}}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{{'<|User|>' + message['content']}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is none %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls']%}
|
||||
{%- if not ns.is_first %}
|
||||
{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else %}
|
||||
{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and message['content'] is not none %}
|
||||
{%- if ns.is_tool %}
|
||||
{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else %}
|
||||
{% set content = message['content'] %}
|
||||
{% if '</think>' in content %}
|
||||
{% set content = content.split('</think>')[-1] %}
|
||||
{% endif %}
|
||||
{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'tool' %}
|
||||
{%- set ns.is_tool = true -%}
|
||||
{%- if ns.is_output_first %}
|
||||
{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- set ns.is_output_first = false %}
|
||||
{%- else %}
|
||||
{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
{% if ns.is_tool %}
|
||||
{{'<|tool▁outputs▁end|>'}}
|
||||
{% endif %}
|
||||
{% if add_generation_prompt and not ns.is_tool %}
|
||||
{{'<|Assistant|>'}}
|
||||
{% endif %}
|
|
@ -314,12 +314,12 @@ static void test_template_output_parsers() {
|
|||
};
|
||||
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||
|
||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||
|
@ -340,7 +340,7 @@ static void test_template_output_parsers() {
|
|||
"}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "</s>" };
|
||||
|
||||
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||
|
@ -351,12 +351,12 @@ static void test_template_output_parsers() {
|
|||
/* skip_grammar_test= */ true);
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||
|
@ -369,11 +369,11 @@ static void test_template_output_parsers() {
|
|||
"</tool_call>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||
|
||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||
|
@ -384,7 +384,7 @@ static void test_template_output_parsers() {
|
|||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||
|
@ -395,7 +395,7 @@ static void test_template_output_parsers() {
|
|||
"<function=special_function>{\"arg1\": 1}</function>");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
||||
|
@ -406,7 +406,7 @@ static void test_template_output_parsers() {
|
|||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
||||
|
@ -420,7 +420,7 @@ static void test_template_output_parsers() {
|
|||
"{\"arg1\": 1}");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||
|
||||
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||
|
@ -431,7 +431,7 @@ static void test_template_output_parsers() {
|
|||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||
}
|
||||
{
|
||||
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||
|
||||
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue