diff --git a/common/arg.cpp b/common/arg.cpp
index 152f671ab..ba3074350 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2006,6 +2006,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                 std::back_inserter(params.chat_template));
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
+
+    add_opt(common_arg(
+        {"--tools"}, "JINJA_TOOLS",
+        "set to a JSON array of tool definitions used for assistant function calling (requires --jinja)",
+        [](common_params & params, const std::string & value) {
+            params.jinja_tools.tools(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+
+    add_opt(common_arg(
+        {"--tool-choice"}, "JINJA_TOOL_CHOICE",
+        "set to \"auto\", \"required\", \"none\", or a JSON object specifying a tool function (default: \"auto\")",
+        [](common_params & params, const std::string & value) {
+            params.jinja_tools.choice(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+
     add_opt(common_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
diff --git a/common/common.cpp b/common/common.cpp
index 8661e164a..13e03e88f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -7,9 +7,6 @@
 
 #include "common.h"
 #include "log.h"
-// Change JSON_ASSERT from assert() to GGML_ASSERT:
-#define JSON_ASSERT GGML_ASSERT
-#include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
 #include "chat.hpp"
@@ -1772,6 +1769,46 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
 // Chat template utils
 //
 
+common_params_tools::common_params_tools(std::string tools, std::string choice) {
+    this->tools(tools);
+    this->choice(choice);
+}
+
+void common_params_tools::tools(std::string tools) {
+    if (tools.empty()) {
+        tools_.reset();
+        return;
+    }
+    try {
+        tools_ = std::make_shared<json>(json::parse(tools));
+        if (! tools_->is_array()) {
+            throw std::invalid_argument("tools must be a valid JSON array");
+        }
+
+    } catch (const json::exception & err) {
+        throw std::invalid_argument(err.what());
+    }
+}
+
+void common_params_tools::choice(std::string choice) {
+    try {
+        if (choice == "auto" || choice == "required" || choice == "none") {
+            tool_choice_ = std::move(choice);
+
+        } else {
+            auto choice_ptr = std::make_shared<json>(json::parse(choice));
+            tool_choice_ = choice_ptr;
+            if (! choice_ptr->is_object()) {
+                throw std::invalid_argument(
+                    "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\"");
+            }
+        }
+
+    } catch (const json::exception & err) {
+        throw std::invalid_argument(err.what());
+    }
+}
+
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     if (use_jinja) {
         try {
@@ -1793,20 +1830,85 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     return res >= 0;
 }
 
+static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams)
+{
+    GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
+
+    auto & dst = *update_sparams->sparams;
+    auto vocab = update_sparams->vocab;
+
+    dst.grammar = src.grammar;
+    dst.grammar_lazy = src.grammar_lazy;
+
+    for (const auto & trigger : src.grammar_triggers) {
+        auto ids = common_tokenize(vocab, trigger.word, false, true);
+
+        if (ids.size() == 1) {
+            LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
+            dst.grammar_trigger_tokens.push_back(ids[0]);
+            dst.preserved_tokens.insert(ids[0]);
+            continue;
+        }
+        LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
+        dst.grammar_trigger_words.push_back(trigger);
+    }
+
+    for (const auto & preserved : src.preserved_tokens) {
+        auto ids = common_tokenize(vocab, preserved, false, true);
+        if (ids.size() == 1) {
+            LOG_DBG("Preserved token: %d\n", ids[0]);
+            dst.preserved_tokens.insert(ids[0]);
+
+        } else {
+            // This may happen when using a tool call style meant for a model
+            // with special tokens to preserve on a model without said tokens.
+            LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n",
+                preserved.c_str());
+        }
+    }
+}
+
 std::string common_chat_apply_template(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & msgs,
     bool add_ass,
-    bool use_jinja) {
+    bool use_jinja,
+    const common_params_tools & tools,
+    common_chat_sampling_updater * update_sparams)
+{
+    const auto & tmpl_selected =
+        tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
+
     if (use_jinja) {
+        common_chat_inputs inputs;
+
         auto messages = json::array();
         for (const auto & msg : msgs) {
             messages.push_back({{"role", msg.role}, {"content", msg.content}});
         }
-        common_chat_inputs inputs;
+
+        if (tools.tools() != nullptr) {
+            inputs.tools = *tools.tools();
+        }
+
+        auto choice = tools.choice();
+        if (std::holds_alternative<std::string>(choice)) {
+            inputs.tool_choice = std::get<std::string>(choice);
+
+        } else {
+            auto choice_ptr = std::get<common_params_tools::json_ptr>(choice);
+            if (choice_ptr != nullptr) {
+                inputs.tool_choice = *choice_ptr;
+            }
+        }
+
         inputs.messages = messages;
         inputs.add_generation_prompt = add_ass;
-        return common_chat_params_init(tmpl, inputs).prompt;
+        auto chat_params = common_chat_params_init(tmpl_selected, inputs);
+        if (update_sparams) {
+            copy_chat_params(chat_params, update_sparams);
+        }
+        return chat_params.prompt;
     }
 
     int alloc_size = 0;
@@ -1819,7 +1921,7 @@ std::string common_chat_apply_template(
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1831,7 +1933,7 @@ std::string common_chat_apply_template(
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(tmpl_selected.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
@@ -1839,13 +1941,18 @@ std::string common_chat_apply_template(
 }
 
 std::string common_chat_format_single(
-    const common_chat_template & tmpl,
+    const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & past_msg,
     const common_chat_msg & new_msg,
     bool add_ass,
-    bool use_jinja) {
+    bool use_jinja,
+    const common_params_tools & tools,
+    common_chat_sampling_updater * update_sparams)
+{
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
+    auto fmt_past_msg = past_msg.empty() ? ""
"" + : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams); + std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { @@ -1853,13 +1960,13 @@ std::string common_chat_format_single( }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); } -std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) { +std::string common_chat_format_example(const common_chat_templates & tmpl, bool use_jinja) { std::vector msgs = { {"system", "You are a helpful assistant", {}}, {"user", "Hello", {}}, @@ -2195,4 +2302,3 @@ common_control_vector_data common_control_vector_load(const std::vector #include #include +#include +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "json.hpp" #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -202,6 +206,31 @@ struct common_params_vocoder { bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; +class common_params_tools { +public: + using json = nlohmann::ordered_json; + using json_ptr = std::shared_ptr; + using tool_choice_t = std::variant; + + common_params_tools(std::string tools = "", + std::string choice = "auto"); + + common_params_tools(const common_params_tools & other) = default; + common_params_tools(common_params_tools && other) noexcept = default; + common_params_tools & operator=(const common_params_tools & other) = default; + common_params_tools & operator=(common_params_tools && other) noexcept = default; + + void tools(std::string tools); + const json * tools() const { return tools_.get(); } + + void choice(std::string choice); + const tool_choice_t & choice() const { return tool_choice_; } + +private: + json_ptr tools_; + tool_choice_t tool_choice_; +}; + struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size @@ -346,6 +375,7 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; + common_params_tools jinja_tools; std::vector api_keys; @@ -641,26 +671,35 @@ struct common_chat_templates { std::unique_ptr template_tool_use; }; +struct common_chat_sampling_updater { + common_params_sampling * sparams; + const llama_vocab * vocab; +}; + // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error std::string common_chat_apply_template( - const common_chat_template & tmpl, + const common_chat_templates & tmpl, const std::vector & chat, bool add_ass, - bool use_jinja); + bool use_jinja, + const common_params_tools & tools = common_params_tools(), + common_chat_sampling_updater * update_sparams = nullptr); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( - const common_chat_template & tmpl, + const common_chat_templates & tmpl, const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, - bool use_jinja); + bool use_jinja, + 
+    const common_params_tools & tools = common_params_tools(),
+    common_chat_sampling_updater * update_sparams = nullptr);
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(
-    const common_chat_template & tmpl, bool use_jinja);
+    const common_chat_templates & tmpl, bool use_jinja);
 
 common_chat_templates common_chat_templates_from_model(const struct llama_model * model,
     const std::string & chat_template_override);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index e654d3542..0ea614bcd 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates, params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@@ -263,20 +263,31 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
+    auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab](
+        const std::string & role, const std::string & content,
+        const common_params_tools & tools = common_params_tools())
+    {
+        bool add_ass = (role == "user");
+
         common_chat_msg new_msg{role, content, {}};
-        auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
+
+        common_chat_sampling_updater updater{&sparams, vocab};
+        auto formatted =
+            common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja,
+                tools, &updater);
+
         chat_msgs.push_back({role, content, {}});
         LOG_DBG("formatted: '%s'\n", formatted.c_str());
         return formatted;
     };
 
     {
-        auto prompt = (params.conversation_mode && params.enable_chat_template)
-            // format the system prompt in conversation mode (fallback to default if empty)
-            ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
-            // otherwise use the prompt as is
+        std::string system_prompt(params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt);
+        bool use_conversation_prompt(params.conversation_mode && params.enable_chat_template);
+        auto prompt = use_conversation_prompt ?
+            chat_add_and_format("system", system_prompt, params.jinja_tools)
             : params.prompt;
+
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
             embd_inp = common_tokenize(ctx, prompt, true, true);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0718806c8..1f6230c4c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -4479,7 +4479,7 @@ int main(int argc, char ** argv) {
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
         ctx_server.chat_templates.template_default->source().c_str(),
-        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
+        common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
         ctx_server.process_single_task(task);
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 5f97df5fd..f1b4ee5b5 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -348,7 +348,7 @@ static llama_tokens format_infill(
 }
 
 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
+inline std::string format_chat(const common_chat_templates & tmpl, const std::vector<json> & messages) {
     std::vector<common_chat_msg> chat;
 
     for (size_t i = 0; i < messages.size(); ++i) {
@@ -663,7 +663,7 @@ static json oaicompat_completion_params_parse(
             llama_params["stop"].push_back(stop);
         }
     } else {
-        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
+        llama_params["prompt"] = format_chat(chat_templates, body.at("messages"));
     }
 
     // Handle "n" field
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index e0314ae1d..022205b7b 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -339,7 +339,8 @@ int main(void) {
     common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
 
     auto fmt_sys = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
+        common_chat_templates tmpl;
+        tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
         auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
         printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
@@ -366,7 +367,8 @@ int main(void) {
     common_chat_msg new_msg{"user", "How are you", {}};
 
     auto fmt_single = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
+        common_chat_templates tmpl;
+        tmpl.template_default.reset(new common_chat_template(tmpl_str, "", ""));
         auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
         printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
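Not part of the patch — a minimal usage sketch of the pieces added above (common_params_tools, common_chat_sampling_updater, and the extended common_chat_apply_template), assuming a model and vocab are already loaded and chat_templates came from common_chat_templates_from_model(); the tool schema below is a made-up placeholder:

    // Parse a tool list and a tool choice the same way --tools / --tool-choice do.
    common_params_tools tools(
        R"([{"type":"function","function":{"name":"get_weather","parameters":{"type":"object"}}}])",
        "auto");

    common_params_sampling sparams;
    common_chat_sampling_updater updater{&sparams, vocab}; // vocab: const llama_vocab * (assumed loaded)

    std::vector<common_chat_msg> msgs = {
        {"system", "You are a helpful assistant", {}},
        {"user",   "What is the weather like today?", {}},
    };

    // Renders the prompt with the model's tool-use template when one is present, and
    // copies the grammar, trigger tokens, and preserved tokens produced by
    // common_chat_params_init() into sparams via copy_chat_params().
    auto prompt = common_chat_apply_template(chat_templates, msgs,
        /* add_ass */ true, /* use_jinja */ true, tools, &updater);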