diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 26d9359d7..5c59d079b 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -66,6 +66,8 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_seed = None
     context.user_api_key = None
     context.response_format = None
+    context.tools = None
+    context.tool_choice = None
     context.temperature = None
 
     context.tasks_result = []
@@ -337,6 +339,13 @@ def step_max_tokens(context, max_tokens):
 def step_response_format(context, response_format):
     context.response_format = json.loads(response_format)
 
+@step('tools {tools}')
+def step_tools(context, tools):
+    context.tools = json.loads(tools)
+
+@step('tool choice {tool_choice}')
+def step_tool_choice(context, tool_choice):
+    context.tool_choice = tool_choice
 
 @step('{temperature:f} temperature')
 def step_temperature(context, temperature):
@@ -471,6 +480,11 @@ async def step_oai_chat_completions(context, api_error):
                                          response_format=context.response_format
                                          if hasattr(context, 'response_format') else None,
+                                         tools=context.tools
+                                         if hasattr(context, 'tools') else None,
+
+                                         tool_choice=context.tool_choice,
+
                                          user_api_key=context.user_api_key
                                          if hasattr(context, 'user_api_key') else None,
@@ -541,6 +555,9 @@ async def step_oai_chat_completions(context):
                                             if hasattr(context, 'enable_streaming') else None,
                                             response_format=context.response_format
                                             if hasattr(context, 'response_format') else None,
+                                            tools=context.tools
+                                            if hasattr(context, 'tools') else None,
+                                            tool_choice=context.tool_choice,
                                             user_api_key=context.user_api_key
                                             if hasattr(context, 'user_api_key') else None)
@@ -554,16 +571,18 @@ async def step_oai_chat_completions(context):
                                       context.base_url,
                                       '/chat/completions',
                                       True,  # async_client
-                                      model=context.model
-                                      if hasattr(context, 'model') else None,
-                                      n_predict=context.n_predict
-                                      if hasattr(context, 'n_predict') else None,
+                                      model=context.model,
+                                      # if hasattr(context, 'model') else None,
+                                      n_predict=context.n_predict,
+                                      # if hasattr(context, 'n_predict') else None,
                                       enable_streaming=context.enable_streaming
                                       if hasattr(context, 'enable_streaming') else None,
-                                      response_format=context.response_format
-                                      if hasattr(context, 'response_format') else None,
-                                      user_api_key=context.user_api_key
-                                      if hasattr(context, 'user_api_key') else None)
+                                      response_format=context.response_format,
+                                      # if hasattr(context, 'response_format') else None,
+                                      tools=context.tools,  # if hasattr(context, 'tools') else None,
+                                      tool_choice=context.tool_choice,  # if hasattr(context, 'tool_choice') else None,
+                                      user_api_key=context.user_api_key)
+                                      # if hasattr(context, 'user_api_key') else None)
 
 
 @step('all prompts are predicted')
@@ -908,6 +927,8 @@ async def oai_chat_completions(user_prompt,
                                n_predict=None,
                                enable_streaming=None,
                                response_format=None,
+                               tools=None,
+                               tool_choice=None,
                                user_api_key=None,
                                expect_api_error=None):
     if debug:
@@ -935,6 +956,10 @@ async def oai_chat_completions(user_prompt,
     }
     if response_format is not None:
         payload['response_format'] = response_format
+    if tools is not None:
+        payload['tools'] = tools
+    if tool_choice is not None:
+        payload['tool_choice'] = tool_choice
     completion_response = {
         'content': '',
         'timings': {
@@ -996,6 +1021,8 @@ async def oai_chat_completions(user_prompt,
             max_tokens=n_predict,
            stream=enable_streaming,
             response_format=payload.get('response_format'),
+            tools=payload.get('tools'),
+            tool_choice=payload.get('tool_choice'),
             seed=seed,
             temperature=payload['temperature']
         )
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 3cdbe0a09..5baf7b554 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -12,6 +12,7 @@
 #include <string>
 #include <vector>
 #include <sstream>
+#include <regex>
 
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
@@ -410,19 +411,40 @@ static json oaicompat_completion_params_parse(
         }
     } else if (body.contains("tools") && body["tools"].is_array()) {
         const auto & tools = body["tools"];
-        llama_params["grammar"] = tool_call_grammar(tools);
-
+        bool built_grammar = false;
+        bool allow_parallel_calls = false;
+        bool allow_content = true;
+        if (body.contains("tool_choice") && body["tool_choice"].is_string() && body["tool_choice"] != "auto") {
+            std::string tool_choice = body["tool_choice"];
+            if (tool_choice == "required") {
+                allow_content = false;
+            } else {
+                for (const auto & tool : tools) {
+                    if (tool["name"] == tool_choice) {
+                        llama_params["grammar"] = tool_call_grammar(json::array({ tool }), allow_parallel_calls, /* allow_content= */ false);
+                        built_grammar = true;
+                        break;
+                    }
+                }
+            }
+        }
+        if (!built_grammar) {
+            llama_params["grammar"] = tool_call_grammar(tools, allow_parallel_calls, allow_content);
+        }
+        // TODO: pass a template file.
         extra_system_message = (std::stringstream()
            << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. "
            << "You may call one or more functions to assist with the user query. "
-           << "Don't make assumptions about what values to plug into functions. "
+           // << "Don't make assumptions about what values to plug into functions. "
            << "Here are the available tools: "
-           << tools.dump().c_str()
+           << tools.dump(2).c_str() << "\n"
+           // << "To call a tool give a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:"
            << "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:"
            << "<tool_call>"
-           << "{\"arguments\": <args-dict>, \"name\": <function-name>}"
+           << "{\"name\": <function-name>, \"arguments\": <args-dict>}"
            << "</tool_call>"
+           << "Don't explain which tools you're going to call, just call them."
         ).str();
     }
 
@@ -451,7 +473,7 @@ static json oaicompat_completion_params_parse(
     }
 
     // Params supported by OAI but unsupported by llama.cpp
-    static const std::vector<std::string> unsupported_params { "tool_choice" };
+    static const std::vector<std::string> unsupported_params; // { "tool_choice" };
     for (auto & param : unsupported_params) {
         if (body.contains(param)) {
             throw std::runtime_error("Unsupported param: " + param);
@@ -478,10 +500,36 @@ static json format_final_response_oaicompat(const json & request, json result, c
     int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
     std::string content = json_value(result, "content", std::string(""));
+
     std::string finish_reason = "length";
     if (stopped_word || stopped_eos) {
         finish_reason = "stop";
     }
+    json tool_calls;
+    json message_content;
+    if (request.contains("tools")) {
+        std::regex pattern("<tool_call>(.*?)</tool_call>");
+        std::sregex_iterator iter(content.begin(), content.end(), pattern);
+        std::sregex_iterator end;
+        while (iter != end) {
+            std::smatch match = *iter;
+            auto call = json::parse(match[1].str());
+            if (tool_calls.is_null()) {
+                tool_calls = json::array();
+            }
+            tool_calls.push_back({
+                {"function", {
+                    {"name", call["name"]},
+                    {"arguments", call["arguments"].dump()},
+                }},
+            });
+            finish_reason = "tool_calls";
+            ++iter;
+        }
+    }
+    if (tool_calls.is_null()) {
+        message_content = content;
+    }
 
     json choices = streaming ?
         json::array({json{{"finish_reason", finish_reason},
@@ -489,7 +537,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
                            {"delta", json::object()}}}) :
         json::array({json{{"finish_reason", finish_reason},
                           {"index", 0},
-                          {"message", json{{"content", content},
+                          {"message", json{{"content", message_content},
+                                           {"tool_calls", tool_calls},
                                            {"role", "assistant"}}}}});
 
     std::time_t t = std::time(0);
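
A quick manual check of the new tools / tool_choice plumbing (not part of the diff above). The sketch below assumes a server built from this branch listening on http://localhost:8080, that the requests package is available, and that the flat {"name", "description", "parameters"} tool shape matches what tool_call_grammar expects (implied by the tool["name"] lookup above, but the grammar builder itself is not shown in this diff). With "tool_choice": "required" the grammar disallows plain content, and any <tool_call>...</tool_call> blocks in the model output come back as message["tool_calls"], with "arguments" serialized as a JSON string.

import json
import requests

payload = {
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "tools": [{
        # Assumed tool shape; adjust if tool_call_grammar expects the
        # OpenAI-style {"type": "function", "function": {...}} wrapper.
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }],
    "tool_choice": "required",  # maps to allow_content = false in the grammar path above
}

# Host and port are assumptions; use whatever the server was started with.
resp = requests.post("http://localhost:8080/chat/completions", json=payload)
message = resp.json()["choices"][0]["message"]

# Tool calls extracted from <tool_call>...</tool_call> tags; "arguments"
# is a JSON-encoded string, as produced by call["arguments"].dump() above.
for call in message.get("tool_calls") or []:
    fn = call["function"]
    print(fn["name"], json.loads(fn["arguments"]))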