diff --git a/common/common.cpp b/common/common.cpp
index 72b6491f1..6c81d18f9 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     if (use_jinja) {
         try {
             auto chat_template = common_chat_template(tmpl, "", "");
-            common_chat_inputs params;
-            params.messages = json::array({{
+            common_chat_inputs inputs;
+            inputs.messages = json::array({{
                 {"role", "user"},
                 {"content", "test"},
             }});
-            common_chat_params_init(chat_template, params);
+            common_chat_params_init(chat_template, inputs);
             return true;
         } catch (const std::exception & e) {
             LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
@@ -1803,11 +1803,10 @@ std::string common_chat_apply_template(
         for (const auto & msg : msgs) {
             messages.push_back({{"role", msg.role}, {"content", msg.content}});
         }
-        common_chat_inputs params;
-        params.messages = messages;
-        params.add_generation_prompt = add_ass;
-        auto data = common_chat_params_init(tmpl, params);
-        return data.prompt;
+        common_chat_inputs inputs;
+        inputs.messages = messages;
+        inputs.add_generation_prompt = add_ass;
+        return common_chat_params_init(tmpl, inputs).prompt;
     }
 
     int alloc_size = 0;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 1fe79c244..8e39ddfc8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1824,16 +1824,16 @@ struct server_context {
         if (use_jinja) {
             auto templates = common_chat_templates_from_model(model, "");
 
-            common_chat_inputs params;
-            params.messages = json::array({{
+            common_chat_inputs inputs;
+            inputs.messages = json::array({{
                 {"role", "user"},
                 {"content", "test"},
             }});
             GGML_ASSERT(templates.template_default);
             try {
-                common_chat_params_init(*templates.template_default, params);
+                common_chat_params_init(*templates.template_default, inputs);
                 if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, params);
+                    common_chat_params_init(*templates.template_tool_use, inputs);
                 }
                 return true;
             } catch (const std::exception & e) {
@@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
         std::vector<server_task> tasks;
 
         try {
-            common_chat_params chat_data;
+            common_chat_params chat_params;
             bool add_special = false;
             if (tmpl && ctx_server.params_base.use_jinja) {
-                chat_data = common_chat_params_init(*tmpl, {
+                chat_params = common_chat_params_init(*tmpl, {
                     /* .messages = */ json_value(data, "messages", json::array()),
                     /* .tools = */ json_value(data, "tools", json()),
                     /* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),
@@ -3799,28 +3799,28 @@ int main(int argc, char ** argv) {
                     /* .stream = */ json_value(data, "stream", false),
                     /* .grammar = */ json_value(data, "grammar", std::string("")),
                 });
-                LOG_INF("Chat format: %s\n", chat_data.format.c_str());
-                LOG_DBG("Prompt: %s\n", chat_data.prompt.get<std::string>().c_str());
-                LOG_DBG("Grammar: %s\n", chat_data.grammar.c_str());
+                LOG_INF("Chat format: %s\n", chat_params.format.c_str());
+                LOG_DBG("Prompt: %s\n", chat_params.prompt.get<std::string>().c_str());
+                LOG_DBG("Grammar: %s\n", chat_params.grammar.c_str());
                 if (data.contains("grammar")) {
-                    if (!chat_data.grammar.empty()) {
+                    if (!chat_params.grammar.empty()) {
                         throw std::runtime_error("Cannot provide grammar and tools");
                     }
-                    chat_data.grammar = data.at("grammar");
+                    chat_params.grammar = data.at("grammar");
                 }
                 // TODO: move inside minja:chat_template?
                 add_special = tmpl->source().find("eos_token") == std::string::npos && tmpl->source().find("bos_token") == std::string::npos;
             } else {
                 add_special = true;
-                chat_data.prompt = data.at("prompt");
+                chat_params.prompt = data.at("prompt");
                 if (data.contains("grammar")) {
-                    chat_data.grammar = data.at("grammar");
+                    chat_params.grammar = data.at("grammar");
                 } else if (data.contains("json_schema")) {
-                    chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
+                    chat_params.grammar = json_schema_to_grammar(data.at("json_schema"));
                 }
             }
 
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_params.prompt, add_special, true);
             tasks.reserve(tokenized_prompts.size());
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(type);
@@ -3838,9 +3838,9 @@ int main(int argc, char ** argv) {
                 // OAI-compat
                 task.params.oaicompat = oaicompat;
                 task.params.oaicompat_cmpl_id = completion_id;
-                task.params.sampling.grammar = chat_data.grammar;
-                task.params.sampling.grammar_lazy = chat_data.grammar_lazy;
-                for (const auto & trigger : chat_data.grammar_triggers) {
+                task.params.sampling.grammar = chat_params.grammar;
+                task.params.sampling.grammar_lazy = chat_params.grammar_lazy;
+                for (const auto & trigger : chat_params.grammar_triggers) {
                     auto ids = common_tokenize(ctx_server.vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
                     if (ids.size() == 1) {
                         LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
@@ -3850,8 +3850,8 @@ int main(int argc, char ** argv) {
                     LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
                     task.params.sampling.grammar_trigger_words.push_back(trigger);
                 }
-                task.params.antiprompt = chat_data.additional_stops;
-                task.params.chat_parser = chat_data.parser;
+                task.params.antiprompt = chat_params.additional_stops;
+                task.params.chat_parser = chat_params.parser;
                 if (task.params.sampling.grammar_lazy) {
                     GGML_ASSERT(task.params.sampling.grammar_trigger_tokens.size() > 0 || task.params.sampling.grammar_trigger_words.size() > 0);
                 }
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 74f6bd1be..77ae529de 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -169,18 +169,18 @@ struct delta_data {
 };
 
 static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
-    common_chat_inputs params;
-    params.parallel_tool_calls = true;
-    params.messages = json::array();
-    params.messages.push_back(user_message);
-    params.tools = tools;
-    auto prefix_data = common_chat_params_init(tmpl, params);
-    params.messages.push_back(delta_message);
-    params.add_generation_prompt = false;
-    auto full_data = common_chat_params_init(tmpl, params);
+    common_chat_inputs inputs;
+    inputs.parallel_tool_calls = true;
+    inputs.messages = json::array();
+    inputs.messages.push_back(user_message);
+    inputs.tools = tools;
+    auto params_prefix = common_chat_params_init(tmpl, inputs);
+    inputs.messages.push_back(delta_message);
+    inputs.add_generation_prompt = false;
+    auto params_full = common_chat_params_init(tmpl, inputs);
 
-    std::string prefix = prefix_data.prompt;
-    std::string full = full_data.prompt;
+    std::string prefix = params_prefix.prompt;
+    std::string full = params_full.prompt;
 
     // Check full starts with prefix
     if (full.find(prefix) != 0) {
@@ -203,7 +203,7 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto
             break;
         }
     }
-    return {delta, full_data.grammar, full_data.parser};
+    return {delta, params_full.grammar, params_full.parser};
 }
 
 /*
@@ -220,12 +220,6 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
[...]
-        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/google-gemma-2-2b-it.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<end_of_turn>" };
 
         assert_equals(std::string("content-only"), describe(tmpl, no_tools_params));
         assert_equals(std::string("generic tool calls"), describe(tmpl, tools_params));
-        assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file(
+            "models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
 
         // Generic tool calls doesn't generate / parse content-only messages symmetrically.
         assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
@@ -340,7 +335,8 @@ static void test_template_output_parsers() {
             "}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "</s>" };
 
         assert_equals(std::string("mistral nemo tool calls"), describe(tmpl, tools_params));
@@ -351,12 +347,15 @@ static void test_template_output_parsers() {
             /* skip_grammar_test= */ true);
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|im_end|>" };
 
         assert_equals(std::string("hermes 2 pro tool calls"), describe(tmpl, tools_params));
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
-        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
+            "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("hermes 2 pro tool calls"), describe(common_chat_template(read_file(
+            "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "<s>", "</s>"), tools_params));
 
         test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, tool_call_message, tools,
@@ -369,11 +368,13 @@ static void test_template_output_parsers() {
             "</tool_call>");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
         assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(tmpl, tools_params));
-        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
+        assert_equals(std::string("llama 3.x tool calls (w/ builtin tools)"), describe(common_chat_template(read_file(
+            "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), "<s>", "</s>"), tools_params));
 
         // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
         test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
@@ -384,7 +385,8 @@ static void test_template_output_parsers() {
             "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
         assert_equals(std::string("llama 3.x tool calls"), describe(tmpl, tools_params));
@@ -395,7 +397,8 @@ static void test_template_output_parsers() {
             "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/meetkai-functionary-medium-v3.1.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
         assert_equals(std::string("functionary v3.1 llama 3.1 tool calls"), describe(tmpl, tools_params));
@@ -406,7 +409,8 @@ static void test_template_output_parsers() {
             "<function=special_function>{\"arg1\": 1}</function>");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/meetkai-functionary-medium-v3.2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eom_id|>", "<|eot_id|>" };
 
         assert_equals(std::string("functionary v3.2 content-only"), describe(tmpl, no_tools_params));
@@ -420,7 +424,8 @@ static void test_template_output_parsers() {
             "{\"arg1\": 1}");
     }
     {
-        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|eot_id|>" };
 
         assert_equals(std::string("firefunction v2 tool calls"), describe(tmpl, tools_params));
@@ -431,7 +436,8 @@ static void test_template_output_parsers() {
             " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
    }
    {
-        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
+        const common_chat_template tmpl(read_file(
+            "models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), "<s>", "</s>");
         std::vector<std::string> end_tokens { "<|end▁of▁sentence|>" };
 
         assert_equals(std::string("deepseek r1 tool calls"), describe(tmpl, tools_params));
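
A minimal usage sketch of how the renamed pieces compose after this change, using only identifiers visible in the diff above (common_chat_inputs, common_chat_params, common_chat_params_init and the fields they reference); the header name and the wrapper function are placeholders, not upstream code:

    // Hypothetical sketch, assuming llama.cpp's common chat headers are in scope.
    #include "common.h" // placeholder include

    static std::string render_test_prompt(const common_chat_template & tmpl) {
        // Build the inputs exactly as common_chat_verify_template does above.
        common_chat_inputs inputs;
        inputs.messages = json::array({{
            {"role", "user"},
            {"content", "test"},
        }});
        inputs.add_generation_prompt = true;

        // One init call yields everything the server wires into a task:
        // prompt, grammar, grammar_lazy, grammar_triggers, additional_stops, parser.
        common_chat_params chat_params = common_chat_params_init(tmpl, inputs);
        return chat_params.prompt;
    }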