Move templates/ under models/

This commit is contained in:
ochafik 2025-01-29 18:16:35 +00:00
parent 682026f84b
commit 7b5e0803c8
19 changed files with 75 additions and 21 deletions

View file

@ -41,7 +41,7 @@ indent_style = tab
trim_trailing_whitespace = unset
insert_final_newline = unset
[tests/chat/templates/*.jinja]
[models/templates/*.jinja]
indent_style = unset
indent_size = unset
end_of_line = unset

View file

@ -450,7 +450,6 @@ static common_chat_data common_chat_init_llama_3_2_tool_calls(const common_chat_
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
// auto add_tool = [&](const json & tool) {
foreach_function(params.tools, [&](const json & tool) {
const auto & function = tool["function"];
std::string name = function["name"];
@ -604,7 +603,6 @@ static common_chat_data common_chat_init_functionary_v3_2_tool_call(const common
}
static common_chat_data common_chat_init_functionary_v3_1_llama_3_1_tool_call(const common_chat_template & tmpl, const struct common_chat_params & params) {
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_data data;
json tools = params.tools.is_null() ? params.tools : json::array();

View file

@ -67,7 +67,7 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a
# server = ServerPreset.stories15m_moe()
server.jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predict,
@ -166,7 +166,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
server.model_hf_file = hf_file
if template_override:
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
server.start()
res = server.make_request("POST", "/chat/completions", data={
@ -200,7 +200,7 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too
global server
server.jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predict,
@ -270,7 +270,7 @@ def test_weather_tool_call(hf_repo: str, hf_file: str, template_override: Tuple[
server.model_hf_file = hf_file
if template_override:
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
server.start(timeout_seconds=15*60)
res = server.make_request("POST", "/chat/completions", data={
@ -319,7 +319,7 @@ def test_hello_world_tool_call(expected_arguments: str | None, hf_repo: str, hf_
server.model_hf_file = hf_file
if template_override:
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../tests/chat/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
server.start(timeout_seconds=15*60)
res = server.make_request("POST", "/chat/completions", data={

View file

@ -0,0 +1,56 @@
{#-
  Chat template: renders the `messages` list into a single prompt string.

  Variables read:
    messages               - list of messages; each is indexed by 'role' and 'content',
                             and assistant tool-call messages also by 'tool_calls'
                             (entries indexed by 'type' and 'function' -> 'name' / 'arguments').
    bos_token              - emitted once at the start of the prompt.
    add_generation_prompt  - optional; treated as false when undefined.

  State kept in the `ns` namespace (needed because plain `set` inside a loop does not
  survive the loop scope):
    is_first         - becomes true after the first tool call is rendered; never reset.
    is_tool          - true while the most recent rendered message was a tool output.
    is_output_first  - true until the first tool output is rendered; never reset to true.
    system_prompt    - content of the system message collected in the first pass.

  NOTE(review): `arguments` is concatenated as-is, so it is presumably already a JSON
  string rather than an object - confirm against callers.
-#}
{% if not add_generation_prompt is defined %}
{% set add_generation_prompt = false %}
{% endif %}
{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}
{#- Pass 1: remember the system message content; if several system messages exist, the last one wins. -#}
{%- for message in messages %}
{%- if message['role'] == 'system' %}
{% set ns.system_prompt = message['content'] %}
{%- endif %}
{%- endfor %}
{{bos_token}}
{{ns.system_prompt}}
{#- Pass 2: render user, assistant and tool messages in order; this loop has no branch for system messages. -#}
{%- for message in messages %}
{%- if message['role'] == 'user' %}
{%- set ns.is_tool = false -%}
{{'<User>' + message['content']}}
{%- endif %}
{#-
  Assistant message with no content: rendered as a list of tool calls.
  The first tool call ever rendered opens the '<Assistant><tool▁calls▁begin>' section;
  every later call (including calls in later messages, since ns.is_first is never reset)
  takes the else branch.
  NOTE(review): the closing '<tool▁calls▁end><end▁of▁sentence>' is emitted only in the else
  branch, so a single tool call is left unclosed and three or more calls get it repeated
  after each subsequent call - looks unintended; confirm against the upstream template.
-#}
{%- if message['role'] == 'assistant' and message['content'] is none %}
{%- set ns.is_tool = false -%}
{%- for tool in message['tool_calls']%}
{%- if not ns.is_first %}
{{'<Assistant><tool▁calls▁begin><tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}
{%- set ns.is_first = true -%}
{%- else %}
{{'\n' + '<tool▁call▁begin>' + tool['type'] + '<tool▁sep>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<tool▁call▁end>'}}
{{'<tool▁calls▁end><end▁of▁sentence>'}}
{%- endif %}
{%- endfor %}
{%- endif %}
{#-
  Assistant message with content.
  Directly after tool outputs: close the tool-output section and append the content.
  Otherwise: keep only the text after the last '</think>' (if present) and emit it as a
  regular assistant turn.
-#}
{%- if message['role'] == 'assistant' and message['content'] is not none %}
{%- if ns.is_tool %}
{{'<tool▁outputs▁end>' + message['content'] + '<end▁of▁sentence>'}}
{%- set ns.is_tool = false -%}
{%- else %}
{% set content = message['content'] %}
{% if '</think>' in content %}
{% set content = content.split('</think>')[-1] %}
{% endif %}
{{'<Assistant>' + content + '<end▁of▁sentence>'}}
{%- endif %}
{%- endif %}
{#-
  Tool output message: the first output ever rendered opens '<tool▁outputs▁begin>';
  later outputs are appended on a new line without reopening the section.
-#}
{%- if message['role'] == 'tool' %}
{%- set ns.is_tool = true -%}
{%- if ns.is_output_first %}
{{'<tool▁outputs▁begin><tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}
{%- set ns.is_output_first = false %}
{%- else %}
{{'\n<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}
{%- endif %}
{%- endif %}
{%- endfor -%}
{#-
  Epilogue: if the conversation ends on a tool output, close the tool-output section.
  ns.is_tool is not cleared afterwards, so in that case the generation prompt below is
  not appended even when add_generation_prompt is true.
-#}
{% if ns.is_tool %}
{{'<tool▁outputs▁end>'}}
{% endif %}
{% if add_generation_prompt and not ns.is_tool %}
{{'<Assistant>'}}
{% endif %}

View file

@ -314,12 +314,12 @@ static void test_template_output_parsers() {
};
{
const common_chat_template tmpl(read_file("tests/chat/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<end_of_turn>" };
assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("tests/chat/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
@ -340,7 +340,7 @@ static void test_template_output_parsers() {
"}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "</s>" };
assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
@ -351,12 +351,12 @@ static void test_template_output_parsers() {
/* skip_grammar_test= */ true);
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|im_end|>" };
assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, tool_call_message, tools,
@ -369,11 +369,11 @@ static void test_template_output_parsers() {
"</tool_call>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.1 tool calls"), describe(tmpl, tools_params));
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("tests/chat/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
assert_equals(std::string("llama 3.1 tool calls"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
// test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@ -384,7 +384,7 @@ static void test_template_output_parsers() {
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
@ -395,7 +395,7 @@ static void test_template_output_parsers() {
"<function=special_function>{\"arg1\": 1}</function>");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("llama 3.2 tool calls"), describe(tmpl, tools_params));
@ -406,7 +406,7 @@ static void test_template_output_parsers() {
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
@ -420,7 +420,7 @@ static void test_template_output_parsers() {
"{\"arg1\": 1}");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<|eot_id|>" };
assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
@ -431,7 +431,7 @@ static void test_template_output_parsers() {
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
}
{
const common_chat_template tmpl(read_file("tests/chat/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
std::vector<std::string> end_tokens { "<end▁of▁sentence>" };
assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));