Move templates/ under models/
This commit is contained in:
parent
682026f84b
commit
7b5e0803c8
19 changed files with 75 additions and 21 deletions
|
@ -41,7 +41,7 @@ indent_style = tab
|
||||||
trim_trailing_whitespace = unset
|
trim_trailing_whitespace = unset
|
||||||
insert_final_newline = unset
|
insert_final_newline = unset
|
||||||
|
|
||||||
[tests/chat/templates/*.jinja]
|
[models/templates/*.jinja]
|
||||||
indent_style = unset
|
indent_style = unset
|
||||||
indent_size = unset
|
indent_size = unset
|
||||||
end_of_line = unset
|
end_of_line = unset
|
||||||
|
|
|
@ -450,7 +450,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
|
|
||||||
// auto add_tool = [&](const json & tool) {
|
|
||||||
foreach_function(params.tools, [&](const json & tool) {
|
foreach_function(params.tools, [&](const json & tool) {
|
||||||
const auto & function = tool["function"];
|
const auto & function = tool["function"];
|
||||||
std::string name = function["name"];
|
std::string name = function["name"];
|
||||||
|
@ -604,7 +603,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
|
||||||
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
|
||||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
json tools = params.tools.is_null() ? params.tools : json::array();
|
json tools = params.tools.is_null() ? params.tools : json::array();
|
||||||
|
|
|
@ -67,7 +67,7 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
|
||||||
# server = ServerPreset.stories15m_moe()
|
# server = ServerPreset.stories15m_moe()
|
||||||
server.jinja = True
|
server.jinja = True
|
||||||
server.n_predict = n_predict
|
server.n_predict = n_predict
|
||||||
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
|
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/chat/completions", data={
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
"max_tokens": n_predict,
|
"max_tokens": n_predict,
|
||||||
|
@ -166,7 +166,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
||||||
server.model_hf_file = hf_file
|
server.model_hf_file = hf_file
|
||||||
if template_override:
|
if template_override:
|
||||||
(template_hf_repo, template_variant) = template_override
|
(template_hf_repo, template_variant) = template_override
|
||||||
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/chat/completions", data={
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
|
@ -200,7 +200,7 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too
|
||||||
global server
|
global server
|
||||||
server.jinja = True
|
server.jinja = True
|
||||||
server.n_predict = n_predict
|
server.n_predict = n_predict
|
||||||
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
|
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/chat/completions", data={
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
"max_tokens": n_predict,
|
"max_tokens": n_predict,
|
||||||
|
@ -270,7 +270,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
|
||||||
server.model_hf_file = hf_file
|
server.model_hf_file = hf_file
|
||||||
if template_override:
|
if template_override:
|
||||||
(template_hf_repo, template_variant) = template_override
|
(template_hf_repo, template_variant) = template_override
|
||||||
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||||
server.start(timeout_seconds=15*60)
|
server.start(timeout_seconds=15*60)
|
||||||
res = server.make_request("POST", "/chat/completions", data={
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
|
@ -319,7 +319,7 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
|
||||||
server.model_hf_file = hf_file
|
server.model_hf_file = hf_file
|
||||||
if template_override:
|
if template_override:
|
||||||
(template_hf_repo, template_variant) = template_override
|
(template_hf_repo, template_variant) = template_override
|
||||||
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
|
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
|
||||||
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
|
||||||
server.start(timeout_seconds=15*60)
|
server.start(timeout_seconds=15*60)
|
||||||
res = server.make_request("POST", "/chat/completions", data={
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
{% if not add_generation_prompt is defined %}
|
||||||
|
{% set add_generation_prompt = false %}
|
||||||
|
{% endif %}
|
||||||
|
{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}
|
||||||
|
{%- for message in messages %}
|
||||||
|
{%- if message['role'] == 'system' %}
|
||||||
|
{% set ns.system_prompt = message['content'] %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{{bos_token}}
|
||||||
|
{{ns.system_prompt}}
|
||||||
|
{%- for message in messages %}
|
||||||
|
{%- if message['role'] == 'user' %}
|
||||||
|
{%- set ns.is_tool = false -%}
|
||||||
|
{{'<|User|>' + message['content']}}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if message['role'] == 'assistant' and message['content'] is none %}
|
||||||
|
{%- set ns.is_tool = false -%}
|
||||||
|
{%- for tool in message['tool_calls']%}
|
||||||
|
{%- if not ns.is_first %}
|
||||||
|
{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||||
|
{%- set ns.is_first = true -%}
|
||||||
|
{%- else %}
|
||||||
|
{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}
|
||||||
|
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if message['role'] == 'assistant' and message['content'] is not none %}
|
||||||
|
{%- if ns.is_tool %}
|
||||||
|
{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}
|
||||||
|
{%- set ns.is_tool = false -%}
|
||||||
|
{%- else %}
|
||||||
|
{% set content = message['content'] %}
|
||||||
|
{% if '</think>' in content %}
|
||||||
|
{% set content = content.split('</think>')[-1] %}
|
||||||
|
{% endif %}
|
||||||
|
{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if message['role'] == 'tool' %}
|
||||||
|
{%- set ns.is_tool = true -%}
|
||||||
|
{%- if ns.is_output_first %}
|
||||||
|
{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||||
|
{%- set ns.is_output_first = false %}
|
||||||
|
{%- else %}
|
||||||
|
{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor -%}
|
||||||
|
{% if ns.is_tool %}
|
||||||
|
{{'<|tool▁outputs▁end|>'}}
|
||||||
|
{% endif %}
|
||||||
|
{% if add_generation_prompt and not ns.is_tool %}
|
||||||
|
{{'<|Assistant|>'}}
|
||||||
|
{% endif %}
|
|
@ -314,12 +314,12 @@ static void test_template_output_parsers() {
|
||||||
};
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
std::vector<std::string> end_tokens { "<end_of_turn>" };
|
||||||
|
|
||||||
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
|
||||||
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
|
||||||
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
|
||||||
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
|
||||||
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
|
||||||
|
@ -340,7 +340,7 @@ static void test_template_output_parsers() {
|
||||||
"}");
|
"}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "</s>" };
|
std::vector<std::string> end_tokens { "</s>" };
|
||||||
|
|
||||||
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
|
||||||
|
@ -351,12 +351,12 @@ static void test_template_output_parsers() {
|
||||||
/* skip_grammar_test= */ true);
|
/* skip_grammar_test= */ true);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|im_end|>" };
|
std::vector<std::string> end_tokens { "<|im_end|>" };
|
||||||
|
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
|
||||||
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
|
||||||
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, tool_call_message, tools,
|
test_template(tmpl, end_tokens, tool_call_message, tools,
|
||||||
|
@ -369,11 +369,11 @@ static void test_template_output_parsers() {
|
||||||
"</tool_call>");
|
"</tool_call>");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
|
||||||
|
|
||||||
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
|
||||||
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
|
||||||
|
@ -384,7 +384,7 @@ static void test_template_output_parsers() {
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
@ -395,7 +395,7 @@ static void test_template_output_parsers() {
|
||||||
"<function=special_function>{\"arg1\": 1}</function>");
|
"<function=special_function>{\"arg1\": 1}</function>");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
@ -406,7 +406,7 @@ static void test_template_output_parsers() {
|
||||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
|
||||||
|
@ -420,7 +420,7 @@ static void test_template_output_parsers() {
|
||||||
"{\"arg1\": 1}");
|
"{\"arg1\": 1}");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
std::vector<std::string> end_tokens { "<|eot_id|>" };
|
||||||
|
|
||||||
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
|
||||||
|
@ -431,7 +431,7 @@ static void test_template_output_parsers() {
|
||||||
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
|
||||||
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
|
||||||
|
|
||||||
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue