From 5f06d37baf371626fd19bb79627c15b934ed0509 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 30 Aug 2024 21:40:49 +0200
Subject: [PATCH] add --tool-call argument

---
 common/common.cpp          |  6 ++++++
 common/common.h            |  1 +
 examples/server/server.cpp | 17 +++++++++++++++--
 examples/server/utils.hpp  |  2 +-
 4 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 28ec4f5fc..478787edf 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -428,6 +428,7 @@ void gpt_params_parse_from_env(gpt_params & params) {
     get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
     get_env("LLAMA_ARG_HOST", params.hostname);
     get_env("LLAMA_ARG_PORT", params.port);
+    get_env("LLAMA_ARG_TOOL_CALLS", params.enable_tool_calls);
 }
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
@@ -1046,6 +1047,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.lora_init_without_apply = true;
         return true;
     }
+    if (arg == "--tool-call" || arg == "--tool-calls") {
+        params.enable_tool_calls = true;
+        return true;
+    }
     if (arg == "--control-vector") {
         CHECK_ARG
         params.control_vectors.push_back({ 1.0f, argv[i], });
@@ -2036,6 +2041,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "server",     "-sps, --slot-prompt-similarity SIMILARITY",
                                       "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
     options.push_back({ "server",     "       --lora-init-without-apply",     "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
+    options.push_back({ "server",     "       --tool-call(s)",                "enable OAI tool calls for the chat completion endpoint (default: %s)", params.enable_tool_calls ? "enabled" : "disabled"});
 
 #ifndef LOG_DISABLE_LOGS
     options.push_back({ "logging" });
diff --git a/common/common.h b/common/common.h
index db0800432..fe5f485e5 100644
--- a/common/common.h
+++ b/common/common.h
@@ -221,6 +221,7 @@ struct gpt_params {
     std::string chat_template = "";
     std::string system_prompt = "";
    bool enable_chat_template = true;
+    bool enable_tool_calls = false;
 
     std::vector<std::string> api_keys;
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 73d008839..ccc45fb6d 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3071,6 +3071,7 @@ int main(int argc, char ** argv) {
 
             if (body.contains("tools") && ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
                 body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools"));
+                body.erase(body.find("tools"));
             }
 
             json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template);
@@ -3441,14 +3442,26 @@ int main(int argc, char ** argv) {
     }
 
     // decide if we can enable tool calls
-    ctx_server.tool_format = get_tool_format(ctx_server.ctx);
+    bool tool_call_support = false;
+    if (ctx_server.params.enable_tool_calls) {
+        ctx_server.tool_format = get_tool_format(ctx_server.ctx);
+        tool_call_support = ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
+        if (tool_call_support) {
+            LOG_WARNING("Tool calls are EXPERIMENTAL and may be unstable. Use at your own risk", {});
+        } else {
+            LOG_ERROR("Tool calls are not supported for this model. Please remove --tool-call or use a supported model", {});
+            clean_up();
+            t.join();
+            return 1;
+        }
+    }
 
     // print sample chat example to make it clear which template is used
     {
         LOG_INFO("chat template", {
             {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
             {"built_in", params.chat_template.empty()},
-            {"tool_call_enabled", ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED },
+            {"tool_call_support", tool_call_support},
         });
     }
 
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index e46f7b032..253f8d42f 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -397,7 +397,7 @@ static json oaicompat_completion_params_parse(
     }
 
     // Params supported by OAI but unsupported by llama.cpp
-    static const std::vector<std::string> unsupported_params { "tool_choice" };
+    static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
     for (auto & param : unsupported_params) {
         if (body.contains(param)) {
             throw std::runtime_error("Unsupported param: " + param);
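
---

Usage sketch, not part of the patch: a minimal way to exercise the new flag,
assuming a local llama-server build, a model whose chat template is recognized
by get_tool_format, and the default port 8080; the binary path, model file,
and the tool definition in the JSON body below are illustrative placeholders.

    # enable tool calls via the new CLI flag
    ./llama-server -m model.gguf --tool-call

    # equivalently, via the new environment variable
    LLAMA_ARG_TOOL_CALLS=1 ./llama-server -m model.gguf

    # send an OAI-style "tools" array to the chat completion endpoint; the
    # server folds it into the prompt via format_chat_with_tool() and erases
    # it from the body before oaicompat_completion_params_parse() runs
    curl http://localhost:8080/v1/chat/completions \
        -H "Content-Type: application/json" \
        -d '{
              "messages": [{"role": "user", "content": "What is the weather in Hanoi?"}],
              "tools": [{
                "type": "function",
                "function": {
                  "name": "get_weather",
                  "description": "Get the current weather in a given city",
                  "parameters": {
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                    "required": ["city"]
                  }
                }
              }]
            }'

Note that without --tool-call, a request containing "tools" is now rejected
with "Unsupported param: tools", since the diff in utils.hpp adds "tools" to
unsupported_params and server.cpp only consumes it when tool calls are enabled.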