diff --git a/common/arg.cpp b/common/arg.cpp
index f5e9b294f..b916dcfcc 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1993,6 +1993,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                 std::back_inserter(params.chat_template));
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
+    add_opt(common_arg(
+        {"--tools"}, "JINJA_TOOLS",
+        // plain literal: no format arguments, so string_format() is not needed here
+        "set to a JSON array of tool definitions used for assistant function-calling "
+        "(requires --jinja)",
+        [](common_params & params, const std::string & value) {
+            params.jinja_tools = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA_TOOLS"));
     add_opt(common_arg(
         {"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
         string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
diff --git a/common/common.cpp b/common/common.cpp
index 6c81d18f9..de4a52905 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1797,13 +1797,27 @@ std::string common_chat_apply_template(
     const common_chat_template & tmpl,
     const std::vector<common_chat_msg> & msgs,
     bool add_ass,
-    bool use_jinja) {
+    bool use_jinja,
+    const std::string & tools_json_arr) {
     if (use_jinja) {
+        common_chat_inputs inputs;
+
         auto messages = json::array();
         for (const auto & msg : msgs) {
             messages.push_back({{"role", msg.role}, {"content", msg.content}});
         }
-        common_chat_inputs inputs;
+
+        // tools_json_arr is optional; a malformed JSON array is reported and
+        // ignored so that chat formatting still succeeds without tools
+        if (!tools_json_arr.empty()) {
+            try {
+                inputs.tools = json::parse(tools_json_arr);
+            } catch (const json::exception & err) {
+                LOG_WRN("Failed to parse tools JSON array \"%s\": \"%s\". Ignoring tools...\n",
+                        tools_json_arr.c_str(), err.what());
+            }
+        }
+
         inputs.messages = messages;
         inputs.add_generation_prompt = add_ass;
         return common_chat_params_init(tmpl, inputs).prompt;
@@ -1843,9 +1857,12 @@ std::string common_chat_format_single(
     const std::vector<common_chat_msg> & past_msg,
     const common_chat_msg & new_msg,
     bool add_ass,
-    bool use_jinja) {
+    bool use_jinja,
+    const std::string & tools_json_arr) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
+    auto fmt_past_msg = past_msg.empty()
+        ? ""
+        : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools_json_arr);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
diff --git a/common/common.h b/common/common.h
--- a/common/common.h
+++ b/common/common.h
@@ -326,5 +326,6 @@ struct common_params {
     std::vector<std::string> api_keys;
+    std::string jinja_tools  = ""; // NOLINT, JSON array of tool definitions (set via --tools, requires --jinja)
     std::string ssl_file_key = ""; // NOLINT
@@ -645,7 +645,8 @@ std::string common_chat_apply_template(
     const common_chat_template & tmpl,
     const std::vector<common_chat_msg> & chat,
     bool add_ass,
-    bool use_jinja);
+    bool use_jinja,
+    const std::string & tools_json_arr = "");
 
 // Format single message, while taking into account the position of that message in chat history
 std::string common_chat_format_single(
@@ -653,7 +654,8 @@ std::string common_chat_format_single(
     const std::vector<common_chat_msg> & past_msg,
     const common_chat_msg & new_msg,
     bool add_ass,
-    bool use_jinja);
+    bool use_jinja,
+    const std::string & tools_json_arr = "");
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index e654d3542..dd0f65bfe 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -263,20 +263,27 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
+    auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role,
+                                                             const std::string & content,
+                                                             const std::string & tools = "") {
         common_chat_msg new_msg{role, content, {}};
-        auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        auto formatted = common_chat_format_single(*chat_templates.template_default,
+                                                   chat_msgs, new_msg,
+                                                   role == "user",
+                                                   g_params->use_jinja, tools);
         chat_msgs.push_back({role, content, {}});
         LOG_DBG("formatted: '%s'\n", formatted.c_str());
         return formatted;
     };
 
     {
-        auto prompt = (params.conversation_mode && params.enable_chat_template)
-            // format the system prompt in conversation mode (fallback to default if empty)
-            ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
-            // otherwise use the prompt as is
-            : params.prompt;
+        // format the system prompt in conversation mode (fallback to the default
+        // system message if empty), otherwise use the prompt as is
+        const std::string system_prompt = params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt;
+        const bool use_conversation_prompt = params.conversation_mode && params.enable_chat_template;
+        auto prompt = use_conversation_prompt
+            ? chat_add_and_format("system", system_prompt, params.jinja_tools)
+            : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
             embd_inp = common_tokenize(ctx, prompt, true, true);