From 183029d88f4ab8c3d0747125e5fdf99488ff3254 Mon Sep 17 00:00:00 2001
From: Mason M
Date: Fri, 31 Jan 2025 14:37:23 -0400
Subject: [PATCH 1/7] Add tools option to llama-cli
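
The option is forwarded to the Jinja chat template so that templates
with function-calling support can render the tool signatures into the
prompt. Usage sketch (the model path and the get_weather definition
below are illustrative, not part of this patch):

    llama-cli -m model.gguf --jinja \
        --tools '[{"type": "function", "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"]}}}]'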
"" + : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools_json_arr); + std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { @@ -2182,4 +2200,3 @@ common_control_vector_data common_control_vector_load(const std::vector api_keys; std::string ssl_file_key = ""; // NOLINT @@ -645,7 +645,8 @@ std::string common_chat_apply_template( const common_chat_template & tmpl, const std::vector & chat, bool add_ass, - bool use_jinja); + bool use_jinja, + std::string tools_json_arr = std::string()); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( @@ -653,7 +654,8 @@ std::string common_chat_format_single( const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, - bool use_jinja); + bool use_jinja, + std::string tools_json_arr = std::string()); // Returns an example of formatted chat std::string common_chat_format_example( diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e654d3542..dd0f65bfe 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -263,20 +263,27 @@ int main(int argc, char ** argv) { std::vector embd_inp; - auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { + auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, + const std::string & content, + const std::string & tools = std::string()) { common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); + auto formatted = common_chat_format_single(*chat_templates.template_default, + chat_msgs, new_msg, + role == "user", + g_params->use_jinja, tools); + chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); return formatted; }; { - auto prompt = (params.conversation_mode && params.enable_chat_template) - // format the system prompt in conversation mode (fallback to default if empty) - ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) - // otherwise use the prompt as is + std::string system_prompt (params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt); + bool use_conversation_prompt (params.conversation_mode && params.enable_chat_template); + auto prompt = use_conversation_prompt ? 
+ chat_add_and_format("system", system_prompt, params.jinja_tools) : params.prompt; + if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { LOG_DBG("tokenize the prompt\n"); embd_inp = common_tokenize(ctx, prompt, true, true); From 4ad82587eeccd72206845589c2056993f8b8ad51 Mon Sep 17 00:00:00 2001 From: Mason M Date: Fri, 31 Jan 2025 17:57:00 -0400 Subject: [PATCH 2/7] tools_json_arr now properly passed to apply-template --- common/common.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index de4a52905..e546f78c2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1871,7 +1871,7 @@ std::string common_chat_format_single( }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools_json_arr); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); From becf9b4003809ba8e9e8d8b1f2daab665e218953 Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 4 Feb 2025 11:40:29 -0400 Subject: [PATCH 3/7] add tool-choice parameter --- common/arg.cpp | 20 +++++++++---- common/common.cpp | 65 ++++++++++++++++++++++++++++++++---------- common/common.h | 35 +++++++++++++++++++++-- examples/main/main.cpp | 14 ++++----- 4 files changed, 103 insertions(+), 31 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index b916dcfcc..dd0cc51d9 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1993,15 +1993,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::back_inserter(params.chat_template)); } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); + add_opt(common_arg( {"--tools"}, "JINJA_TOOLS", - string_format( - "set to a JSON array of tool definitions used for assistant function-calling " - "(requires --jinja)"), + "set to JSON array of tool definitions used for assistant function-calling (requires --jinja)", [](common_params ¶ms, const std::string & value) { - params.jinja_tools = value; - } - ).set_examples({LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA_TOOLS")); + params.jinja_tools.tools(value); + + }).set_examples({LLAMA_EXAMPLE_MAIN})); + + add_opt(common_arg( + {"--tool-choice"}, "JINJA_TOOL_CHOICE", + "set to \"auto\", \"required\", \"none\" or a JSON object specifying a tool function (default: \"auto\")", + [](common_params ¶ms, const std::string & value) { + params.jinja_tools.choice(value); + + }).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), diff --git a/common/common.cpp b/common/common.cpp index e546f78c2..d9657bc85 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -7,9 +7,6 @@ #include "common.h" #include "log.h" -// Change JSON_ASSERT from assert() to GGML_ASSERT: -#define JSON_ASSERT GGML_ASSERT -#include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" #include "chat.hpp" @@ -1772,6 +1769,42 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto // Chat template utils // +common_params_tools::common_params_tools(std::string tools, std::string choice) { + this->tools(tools); + 
this->choice(choice); +} + +void common_params_tools::tools(std::string tools) { + try { + tools_ = std::make_shared(json::parse(tools)); + if (! tools_->is_array()) { + throw std::invalid_argument("tools must be a valid JSON array"); + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + +void common_params_tools::choice(std::string choice) { + try { + if (choice == "auto" || choice == "required" || choice == "none") { + tool_choice_ = std::move(choice); + + } else { + auto choice_ptr = std::make_shared(json::parse(choice)); + tool_choice_ = choice_ptr; + if (! choice_ptr->is_object()) { + throw std::invalid_argument( + "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\""); + } + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -1798,7 +1831,7 @@ std::string common_chat_apply_template( const std::vector & msgs, bool add_ass, bool use_jinja, - std::string tools_json_arr) + const common_params_tools & tools) { if (use_jinja) { common_chat_inputs inputs; @@ -1807,17 +1840,19 @@ std::string common_chat_apply_template( for (const auto & msg : msgs) { messages.push_back({{"role", msg.role}, {"content", msg.content}}); } + if (tools.tools() != nullptr) { + inputs.tools = *tools.tools(); + } + auto choice = tools.choice(); + if (std::holds_alternative(choice)) { + inputs.tool_choice = std::get(choice); - if (! tools_json_arr.empty()) { - try { - inputs.tools = tools_json_arr; - - } catch (const json::exception & err) { - LOG_WRN("Failed to parse tools JSON array \"%s\": \"%s\". Ignoring tools...\n", - tools_json_arr.c_str(), err.what()); + } else { + auto choice_ptr = std::get(choice); + if (choice_ptr != nullptr) { + inputs.tool_choice = *choice_ptr; } } - inputs.messages = messages; inputs.add_generation_prompt = add_ass; return common_chat_params_init(tmpl, inputs).prompt; @@ -1858,11 +1893,11 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - std::string tools_json_arr) + const common_params_tools & tools) { std::ostringstream ss; auto fmt_past_msg = past_msg.empty() ? 
"" - : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools_json_arr); + : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version @@ -1871,7 +1906,7 @@ std::string common_chat_format_single( }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools_json_arr); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); diff --git a/common/common.h b/common/common.h index 5058d654c..a1925d774 100644 --- a/common/common.h +++ b/common/common.h @@ -8,6 +8,9 @@ #include #include #include +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "json.hpp" #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -202,6 +205,31 @@ struct common_params_vocoder { bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; +class common_params_tools { +public: + using json = nlohmann::ordered_json; + using json_ptr = std::shared_ptr; + using tool_choice_t = std::variant; + + common_params_tools(std::string tools = "", + std::string choice = "auto"); + + common_params_tools(const common_params_tools & other) = default; + common_params_tools(common_params_tools && other) noexcept = default; + common_params_tools & operator=(const common_params_tools & other) = default; + common_params_tools & operator=(common_params_tools && other) noexcept = default; + + void tools(std::string tools); + const json * tools() const { return tools_.get(); } + + void choice(std::string choice); + const tool_choice_t & choice() const { return tool_choice_; } + +private: + json_ptr tools_; + tool_choice_t tool_choice_; +}; + struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size @@ -346,7 +374,8 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; - std::string jinja_tools = ""; + common_params_tools jinja_tools; + std::vector api_keys; std::string ssl_file_key = ""; // NOLINT @@ -649,7 +678,7 @@ std::string common_chat_apply_template( const std::vector & chat, bool add_ass, bool use_jinja, - std::string tools_json_arr = std::string()); + const common_params_tools & tools = common_params_tools()); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( @@ -658,7 +687,7 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - std::string tools_json_arr = std::string()); + const common_params_tools & tools = common_params_tools()); // Returns an example of formatted chat std::string common_chat_format_example( diff --git a/examples/main/main.cpp b/examples/main/main.cpp index dd0f65bfe..d19a24331 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -263,14 +263,14 @@ int main(int argc, char ** argv) { std::vector embd_inp; - auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, - const std::string & content, - const std::string & tools = std::string()) { + auto chat_add_and_format = [&chat_msgs, &chat_templates]( + const std::string & role, const std::string & content, 
+ const common_params_tools & tools = common_params_tools()) + { common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(*chat_templates.template_default, - chat_msgs, new_msg, - role == "user", - g_params->use_jinja, tools); + auto formatted = common_chat_format_single( + *chat_templates.template_default, chat_msgs, new_msg, role == "user", + g_params->use_jinja, tools); chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); From cd16957f7d938bd5a334d2c23e4d1a7e604436cf Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 4 Feb 2025 11:54:48 -0400 Subject: [PATCH 4/7] Add variant include --- common/common.h | 1 + 1 file changed, 1 insertion(+) diff --git a/common/common.h b/common/common.h index a1925d774..417eb546c 100644 --- a/common/common.h +++ b/common/common.h @@ -8,6 +8,7 @@ #include #include #include +#include // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" From 4e8beb0c533949e7433375444344f3777deae1ed Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 4 Feb 2025 13:21:28 -0400 Subject: [PATCH 5/7] Reset tools when empty string provided --- common/common.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index d9657bc85..de2adba1f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1775,6 +1775,10 @@ common_params_tools::common_params_tools(std::string tools, std::string choice) } void common_params_tools::tools(std::string tools) { + if (tools.empty()) { + tools_.reset(); + return; + } try { tools_ = std::make_shared(json::parse(tools)); if (! tools_->is_array()) { From 34370803fbbb84a25ca1763831aba1e45d7f3dfd Mon Sep 17 00:00:00 2001 From: Mason M Date: Tue, 4 Feb 2025 15:00:58 -0400 Subject: [PATCH 6/7] Pass template group to common_chat_apply_template --- common/common.cpp | 18 ++++++++++++------ common/common.h | 6 +++--- examples/main/main.cpp | 8 ++++---- examples/server/server.cpp | 2 +- examples/server/utils.hpp | 4 ++-- tests/test-chat-template.cpp | 6 ++++-- 6 files changed, 26 insertions(+), 18 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index de2adba1f..8d6e8b0cb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1831,12 +1831,15 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { } std::string common_chat_apply_template( - const common_chat_template & tmpl, + const common_chat_templates & tmpl, const std::vector & msgs, bool add_ass, bool use_jinja, const common_params_tools & tools) { + const auto & tmpl_selected = + tools.tools() && tmpl.template_tool_use ? 
---
 common/common.cpp            | 18 ++++++++++++------
 common/common.h              |  6 +++---
 examples/main/main.cpp       |  8 ++++----
 examples/server/server.cpp   |  2 +-
 examples/server/utils.hpp    |  4 ++--
 tests/test-chat-template.cpp |  6 ++++--
 6 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index de2adba1f..8d6e8b0cb 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1831,12 +1831,15 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
 }
 
 std::string common_chat_apply_template(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & msgs,
     bool add_ass,
     bool use_jinja,
     const common_params_tools & tools)
 {
+    const auto & tmpl_selected =
+        tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
+
     if (use_jinja) {
         common_chat_inputs inputs;
 
@@ -1844,9 +1847,11 @@ std::string common_chat_apply_template(
         for (const auto & msg : msgs) {
             messages.push_back({{"role", msg.role}, {"content", msg.content}});
         }
+
         if (tools.tools() != nullptr) {
             inputs.tools = *tools.tools();
         }
+
         auto choice = tools.choice();
         if (std::holds_alternative<std::string>(choice)) {
             inputs.tool_choice = std::get<std::string>(choice);
@@ -1857,9 +1862,10 @@ std::string common_chat_apply_template(
                 inputs.tool_choice = *choice_ptr;
             }
         }
+
         inputs.messages = messages;
         inputs.add_generation_prompt = add_ass;
-        return common_chat_params_init(tmpl, inputs).prompt;
+        return common_chat_params_init(tmpl_selected, inputs).prompt;
     }
 
     int alloc_size = 0;
@@ -1872,7 +1878,7 @@ std::string common_chat_apply_template(
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1884,7 +1890,7 @@ std::string common_chat_apply_template(
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
@@ -1892,7 +1898,7 @@ std::string common_chat_apply_template(
 }
 
 std::string common_chat_format_single(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & past_msg,
     const common_chat_msg & new_msg,
     bool add_ass,
@@ -1916,7 +1922,7 @@ std::string common_chat_format_single(
     return ss.str();
 }
 
-std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
+std::string common_chat_format_example(const common_chat_templates & tmpl, bool use_jinja) {
     std::vector<common_chat_msg> msgs = {
         {"system", "You are a helpful assistant", {}},
         {"user", "Hello", {}},
diff --git a/common/common.h b/common/common.h
index 417eb546c..c3f4f66ee 100644
--- a/common/common.h
+++ b/common/common.h
@@ -675,7 +675,7 @@ struct common_chat_templates {
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
 std::string common_chat_apply_template(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & chat,
     bool add_ass,
     bool use_jinja,
@@ -683,7 +683,7 @@ std::string common_chat_apply_template(
 // Format single message, while taking into account the position of that message in chat history
 std::string common_chat_format_single(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & past_msg,
     const common_chat_msg & new_msg,
     bool add_ass,
@@ -692,7 +692,7 @@ std::string common_chat_format_single(
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(
-    const common_chat_template & tmpl, bool use_jinja);
+    const common_chat_templates & tmpl, bool use_jinja);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index d19a24331..d562522a4 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates, params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -268,9 +268,9 @@ int main(int argc, char ** argv) {
         const common_params_tools & tools = common_params_tools())
     {
         common_chat_msg new_msg{role, content, {}};
-        auto formatted = common_chat_format_single(
-            *chat_templates.template_default, chat_msgs, new_msg, role == "user",
-            g_params->use_jinja, tools);
+
+        auto formatted = common_chat_format_single(chat_templates, chat_msgs,
+            new_msg, role == "user", g_params->use_jinja, tools);
 
         chat_msgs.push_back({role, content, {}});
         LOG_DBG("formatted: '%s'\n", formatted.c_str());
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e0acc4705..3ceba0558 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -4468,7 +4468,7 @@ int main(int argc, char ** argv) {
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
         ctx_server.chat_templates.template_default->source().c_str(),
-        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
+        common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
         ctx_server.process_single_task(task);
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 5f97df5fd..f1b4ee5b5 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -348,7 +348,7 @@ static llama_tokens format_infill(
 }
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const common_chat_templates & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
@@ -663,7 +663,7 @@ static json oaicompat_completion_params_parse(
             llama_params["stop"].push_back(stop);
         }
     } else {
-        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
+        llama_params["prompt"] = format_chat(chat_templates, body.at("messages"));
     }
 
     // Handle "n" field
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index e0314ae1d..022205b7b 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -339,7 +339,8 @@ int main(void) {
     common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
 
     auto fmt_sys = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
+        common_chat_templates tmpl;
+        tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
         auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
         printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
@@ -366,7 +367,8 @@ int main(void) {
     common_chat_msg new_msg{"user", "How are you", {}};
 
     auto fmt_single = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
+        common_chat_templates tmpl;
+        tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
         auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
         printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");

From a726adaef7be99eafbfc873420e337d6f9f79b1d Mon Sep 17 00:00:00 2001
From: Mason M
Date: Wed, 5 Feb 2025 12:07:22 -0400
Subject: [PATCH 7/7] Copy sampler parameters from chat template
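
Tool-enabled templates can emit a (possibly lazy) grammar together
with trigger words and preserved tokens that should constrain
sampling. The new optional common_chat_sampling_updater lets callers
receive these: when it is passed, common_chat_apply_template() copies
the grammar into common_params_sampling, tokenizes each trigger, and
records single-token triggers as grammar_trigger_tokens while keeping
multi-token ones as grammar_trigger_words. Caller-side sketch (taken
from the main.cpp hunk below):

    common_chat_sampling_updater updater{&sparams, vocab};
    auto formatted = common_chat_format_single(chat_templates, chat_msgs,
        new_msg, add_ass, g_params->use_jinja, tools, &updater);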
---
 common/common.cpp      | 54 ++++++++++++++++++++++++++++++++++++++----
 common/common.h        | 11 +++++++--
 examples/main/main.cpp | 10 +++++---
 3 files changed, 65 insertions(+), 10 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 15d025846..13e03e88f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1830,12 +1830,51 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     return res >= 0;
 }
 
+static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams)
+{
+    GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
+
+    auto & dst   = *update_sparams->sparams;
+    auto   vocab = update_sparams->vocab;
+
+    dst.grammar      = src.grammar;
+    dst.grammar_lazy = src.grammar_lazy;
+
+    for (const auto & trigger : src.grammar_triggers) {
+        auto ids = common_tokenize(vocab, trigger.word, false, true);
+
+        if (ids.size() == 1) {
+            LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
+            dst.grammar_trigger_tokens.push_back(ids[0]);
+            dst.preserved_tokens.insert(ids[0]);
+            continue;
+        }
+        LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
+        dst.grammar_trigger_words.push_back(trigger);
+    }
+
+    for (const auto & preserved : src.preserved_tokens) {
+        auto ids = common_tokenize(vocab, preserved, false, true);
+        if (ids.size() == 1) {
+            LOG_DBG("Preserved token: %d\n", ids[0]);
+            dst.preserved_tokens.insert(ids[0]);
+
+        } else {
+            // This may happen when using a tool call style meant for a model
+            // with special tokens to preserve on a model without said tokens.
+            LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n",
+                    preserved.c_str());
+        }
+    }
+}
+
 std::string common_chat_apply_template(
     const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & msgs,
     bool add_ass,
     bool use_jinja,
-    const common_params_tools & tools)
+    const common_params_tools & tools,
+    common_chat_sampling_updater * update_sparams)
 {
     const auto & tmpl_selected =
         tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
@@ -1865,7 +1904,11 @@ std::string common_chat_apply_template(
         inputs.messages = messages;
         inputs.add_generation_prompt = add_ass;
-        return common_chat_params_init(tmpl_selected, inputs).prompt;
+        auto chat_params = common_chat_params_init(tmpl_selected, inputs);
+        if (update_sparams) {
+            copy_chat_params(chat_params, update_sparams);
+        }
+        return chat_params.prompt;
     }
 
     int alloc_size = 0;
@@ -1903,11 +1946,12 @@ std::string common_chat_format_single(
     const common_chat_msg & new_msg,
     bool add_ass,
     bool use_jinja,
-    const common_params_tools & tools)
+    const common_params_tools & tools,
+    common_chat_sampling_updater * update_sparams)
 {
     std::ostringstream ss;
     auto fmt_past_msg = past_msg.empty() ? ""
-        : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools);
+        : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams);
 
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
@@ -1916,7 +1960,7 @@ std::string common_chat_format_single(
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools);
+    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
diff --git a/common/common.h b/common/common.h
index c3f4f66ee..f7bdb4e69 100644
--- a/common/common.h
+++ b/common/common.h
@@ -671,6 +671,11 @@ struct common_chat_templates {
     std::unique_ptr<common_chat_template> template_tool_use;
 };
 
+struct common_chat_sampling_updater {
+    common_params_sampling * sparams;
+    const llama_vocab * vocab;
+};
+
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
@@ -679,7 +684,8 @@ std::string common_chat_apply_template(
     const std::vector<common_chat_msg> & chat,
     bool add_ass,
     bool use_jinja,
-    const common_params_tools & tools = common_params_tools());
+    const common_params_tools & tools = common_params_tools(),
+    common_chat_sampling_updater * update_sparams = nullptr);
 
 // Format single message, while taking into account the position of that message in chat history
 std::string common_chat_format_single(
@@ -688,7 +694,8 @@ std::string common_chat_format_single(
     const common_chat_msg & new_msg,
     bool add_ass,
     bool use_jinja,
-    const common_params_tools & tools = common_params_tools());
+    const common_params_tools & tools = common_params_tools(),
+    common_chat_sampling_updater * update_sparams = nullptr);
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index d562522a4..0ea614bcd 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -263,14 +263,18 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
&chat_templates]( + auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab]( const std::string & role, const std::string & content, const common_params_tools & tools = common_params_tools()) { + bool add_ass = (role == "user"); + common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(chat_templates, chat_msgs, - new_msg, role == "user", g_params->use_jinja, tools); + common_chat_sampling_updater updater{&sparams, vocab}; + auto formatted = + common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja, + tools, &updater); chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str());