diff --git a/common/arg.cpp b/common/arg.cpp index b916dcfcc..dd0cc51d9 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1993,15 +1993,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::back_inserter(params.chat_template)); } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); + add_opt(common_arg( {"--tools"}, "JINJA_TOOLS", - string_format( - "set to a JSON array of tool definitions used for assistant function-calling " - "(requires --jinja)"), + "set to JSON array of tool definitions used for assistant function-calling (requires --jinja)", [](common_params ¶ms, const std::string & value) { - params.jinja_tools = value; - } - ).set_examples({LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA_TOOLS")); + params.jinja_tools.tools(value); + + }).set_examples({LLAMA_EXAMPLE_MAIN})); + + add_opt(common_arg( + {"--tool-choice"}, "JINJA_TOOL_CHOICE", + "set to \"auto\", \"required\", \"none\" or a JSON object specifying a tool function (default: \"auto\")", + [](common_params ¶ms, const std::string & value) { + params.jinja_tools.choice(value); + + }).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), diff --git a/common/common.cpp b/common/common.cpp index e546f78c2..d9657bc85 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -7,9 +7,6 @@ #include "common.h" #include "log.h" -// Change JSON_ASSERT from assert() to GGML_ASSERT: -#define JSON_ASSERT GGML_ASSERT -#include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" #include "chat.hpp" @@ -1772,6 +1769,42 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto // Chat template utils // +common_params_tools::common_params_tools(std::string tools, std::string choice) { + this->tools(tools); + this->choice(choice); +} + +void common_params_tools::tools(std::string tools) { + try { + tools_ = std::make_shared(json::parse(tools)); + if (! tools_->is_array()) { + throw std::invalid_argument("tools must be a valid JSON array"); + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + +void common_params_tools::choice(std::string choice) { + try { + if (choice == "auto" || choice == "required" || choice == "none") { + tool_choice_ = std::move(choice); + + } else { + auto choice_ptr = std::make_shared(json::parse(choice)); + tool_choice_ = choice_ptr; + if (! choice_ptr->is_object()) { + throw std::invalid_argument( + "tool choice must be a valid JSON object, \"auto\", \"required\", or \"none\""); + } + } + + } catch (const json::exception & err) { + throw std::invalid_argument(err.what()); + } +} + bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -1798,7 +1831,7 @@ std::string common_chat_apply_template( const std::vector & msgs, bool add_ass, bool use_jinja, - std::string tools_json_arr) + const common_params_tools & tools) { if (use_jinja) { common_chat_inputs inputs; @@ -1807,17 +1840,19 @@ std::string common_chat_apply_template( for (const auto & msg : msgs) { messages.push_back({{"role", msg.role}, {"content", msg.content}}); } + if (tools.tools() != nullptr) { + inputs.tools = *tools.tools(); + } + auto choice = tools.choice(); + if (std::holds_alternative(choice)) { + inputs.tool_choice = std::get(choice); - if (! tools_json_arr.empty()) { - try { - inputs.tools = tools_json_arr; - - } catch (const json::exception & err) { - LOG_WRN("Failed to parse tools JSON array \"%s\": \"%s\". Ignoring tools...\n", - tools_json_arr.c_str(), err.what()); + } else { + auto choice_ptr = std::get(choice); + if (choice_ptr != nullptr) { + inputs.tool_choice = *choice_ptr; } } - inputs.messages = messages; inputs.add_generation_prompt = add_ass; return common_chat_params_init(tmpl, inputs).prompt; @@ -1858,11 +1893,11 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - std::string tools_json_arr) + const common_params_tools & tools) { std::ostringstream ss; auto fmt_past_msg = past_msg.empty() ? "" - : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools_json_arr); + : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version @@ -1871,7 +1906,7 @@ std::string common_chat_format_single( }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools_json_arr); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); diff --git a/common/common.h b/common/common.h index 5058d654c..a1925d774 100644 --- a/common/common.h +++ b/common/common.h @@ -8,6 +8,9 @@ #include #include #include +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "json.hpp" #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -202,6 +205,31 @@ struct common_params_vocoder { bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; +class common_params_tools { +public: + using json = nlohmann::ordered_json; + using json_ptr = std::shared_ptr; + using tool_choice_t = std::variant; + + common_params_tools(std::string tools = "", + std::string choice = "auto"); + + common_params_tools(const common_params_tools & other) = default; + common_params_tools(common_params_tools && other) noexcept = default; + common_params_tools & operator=(const common_params_tools & other) = default; + common_params_tools & operator=(common_params_tools && other) noexcept = default; + + void tools(std::string tools); + const json * tools() const { return tools_.get(); } + + void choice(std::string choice); + const tool_choice_t & choice() const { return tool_choice_; } + +private: + json_ptr tools_; + tool_choice_t tool_choice_; +}; + struct common_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size @@ -346,7 +374,8 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; - std::string jinja_tools = ""; + common_params_tools jinja_tools; + std::vector api_keys; std::string ssl_file_key = ""; // NOLINT @@ -649,7 +678,7 @@ std::string common_chat_apply_template( const std::vector & chat, bool add_ass, bool use_jinja, - std::string tools_json_arr = std::string()); + const common_params_tools & tools = common_params_tools()); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( @@ -658,7 +687,7 @@ std::string common_chat_format_single( const common_chat_msg & new_msg, bool add_ass, bool use_jinja, - std::string tools_json_arr = std::string()); + const common_params_tools & tools = common_params_tools()); // Returns an example of formatted chat std::string common_chat_format_example( diff --git a/examples/main/main.cpp b/examples/main/main.cpp index dd0f65bfe..d19a24331 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -263,14 +263,14 @@ int main(int argc, char ** argv) { std::vector embd_inp; - auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, - const std::string & content, - const std::string & tools = std::string()) { + auto chat_add_and_format = [&chat_msgs, &chat_templates]( + const std::string & role, const std::string & content, + const common_params_tools & tools = common_params_tools()) + { common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(*chat_templates.template_default, - chat_msgs, new_msg, - role == "user", - g_params->use_jinja, tools); + auto formatted = common_chat_format_single( + *chat_templates.template_default, chat_msgs, new_msg, role == "user", + g_params->use_jinja, tools); chat_msgs.push_back({role, content, {}}); LOG_DBG("formatted: '%s'\n", formatted.c_str());