Add --tool-call argument
This commit is contained in:
parent
7e017cfbc8
commit
5f06d37baf
4 changed files with 23 additions and 3 deletions
|
@ -428,6 +428,7 @@ void gpt_params_parse_from_env(gpt_params & params) {
|
||||||
get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
|
get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
|
||||||
get_env("LLAMA_ARG_HOST", params.hostname);
|
get_env("LLAMA_ARG_HOST", params.hostname);
|
||||||
get_env("LLAMA_ARG_PORT", params.port);
|
get_env("LLAMA_ARG_PORT", params.port);
|
||||||
|
get_env("LLAMA_ARG_TOOL_CALLS", params.enable_tool_calls);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
|
@ -1046,6 +1047,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
params.lora_init_without_apply = true;
|
params.lora_init_without_apply = true;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "--tool-call" || arg == "--tool-calls") {
|
||||||
|
params.enable_tool_calls = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "--control-vector") {
|
if (arg == "--control-vector") {
|
||||||
CHECK_ARG
|
CHECK_ARG
|
||||||
params.control_vectors.push_back({ 1.0f, argv[i], });
|
params.control_vectors.push_back({ 1.0f, argv[i], });
|
||||||
|
@ -2036,6 +2041,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
|
||||||
options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
|
options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
|
||||||
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
|
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
|
||||||
options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
|
options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
|
||||||
|
options.push_back({ "server", " --tool-call(s)", "enable OAI tool calls for chat completion endpoint (default: %s)", params.enable_tool_calls ? "enabled" : "disabled"});
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
options.push_back({ "logging" });
|
options.push_back({ "logging" });
|
||||||
|
|
|
@ -221,6 +221,7 @@ struct gpt_params {
|
||||||
std::string chat_template = "";
|
std::string chat_template = "";
|
||||||
std::string system_prompt = "";
|
std::string system_prompt = "";
|
||||||
bool enable_chat_template = true;
|
bool enable_chat_template = true;
|
||||||
|
bool enable_tool_calls = false;
|
||||||
|
|
||||||
std::vector<std::string> api_keys;
|
std::vector<std::string> api_keys;
|
||||||
|
|
||||||
|
|
|
@ -3071,6 +3071,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if (body.contains("tools") && ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
|
if (body.contains("tools") && ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED) {
|
||||||
body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools"));
|
body["prompt"] = format_chat_with_tool(ctx_server.tool_format, body.at("messages"), body.at("tools"));
|
||||||
|
body.erase(body.find("tools"));
|
||||||
}
|
}
|
||||||
|
|
||||||
json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template);
|
json data = oaicompat_completion_params_parse(ctx_server.model, body, params.chat_template);
|
||||||
|
@ -3441,14 +3442,26 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// decide if we can enable tool calls
|
// decide if we can enable tool calls
|
||||||
|
bool tool_call_support = false;
|
||||||
|
if (ctx_server.params.enable_tool_calls) {
|
||||||
ctx_server.tool_format = get_tool_format(ctx_server.ctx);
|
ctx_server.tool_format = get_tool_format(ctx_server.ctx);
|
||||||
|
tool_call_support = ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED;
|
||||||
|
if (tool_call_support) {
|
||||||
|
LOG_WARNING("Tool call is EXPERIMENTAL and maybe unstable. Use with your own risk", {});
|
||||||
|
} else {
|
||||||
|
LOG_ERROR("Tool call is not supported for this model. Please remove --tool-call or use with a supported model", {});
|
||||||
|
clean_up();
|
||||||
|
t.join();
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// print sample chat example to make it clear which template is used
|
// print sample chat example to make it clear which template is used
|
||||||
{
|
{
|
||||||
LOG_INFO("chat template", {
|
LOG_INFO("chat template", {
|
||||||
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
|
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
|
||||||
{"built_in", params.chat_template.empty()},
|
{"built_in", params.chat_template.empty()},
|
||||||
{"tool_call_enabled", ctx_server.tool_format != LLAMA_TOOL_FORMAT_NOT_SUPPORTED },
|
{"tool_call_support", tool_call_support},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -397,7 +397,7 @@ static json oaicompat_completion_params_parse(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Params supported by OAI but unsupported by llama.cpp
|
// Params supported by OAI but unsupported by llama.cpp
|
||||||
static const std::vector<std::string> unsupported_params { "tool_choice" };
|
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
|
||||||
for (auto & param : unsupported_params) {
|
for (auto & param : unsupported_params) {
|
||||||
if (body.contains(param)) {
|
if (body.contains(param)) {
|
||||||
throw std::runtime_error("Unsupported param: " + param);
|
throw std::runtime_error("Unsupported param: " + param);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue