server: ultra basic tools, tool_choice, tool_calls support

ochafik 2024-05-22 04:15:14 +01:00
parent 793f4ff3f5
commit a1c4aac384
2 changed files with 91 additions and 15 deletions


@@ -66,6 +66,8 @@ def step_server_config(context, server_fqdn, server_port):
    context.server_seed = None
    context.user_api_key = None
    context.response_format = None
    context.tools = None
    context.tool_choice = None
    context.temperature = None
    context.tasks_result = []
@@ -337,6 +339,13 @@ def step_max_tokens(context, max_tokens):
def step_response_format(context, response_format):
    context.response_format = json.loads(response_format)

@step('tools {tools}')
def step_tools(context, tools):
    context.tools = json.loads(tools)

@step('tool choice {tool_choice}')
def step_tool_choice(context, tool_choice):
    context.tool_choice = tool_choice

@step('{temperature:f} temperature')
def step_temperature(context, temperature):
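
The two new steps simply parse their argument and stash it on the behave context. A minimal Python sketch of the kind of values a feature file might supply; the tool schema below is an assumption, chosen to match the top-level `tool["name"]` lookup in the C++ hunk further down rather than OpenAI's nested `{"type": "function", "function": {...}}` wrapper:

    import json

    # Hypothetical feature-file values for the new `tools {tools}` and
    # `tool choice {tool_choice}` steps above (schema is an assumption).
    tools_json = '''
    [
      {
        "name": "get_current_weather",
        "description": "Get the current weather for a city",
        "parameters": {
          "type": "object",
          "properties": {"location": {"type": "string"}},
          "required": ["location"]
        }
      }
    ]
    '''

    context_tools = json.loads(tools_json)   # what step_tools stores on context.tools
    context_tool_choice = "required"         # what step_tool_choice stores on context.tool_choice
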
@@ -471,6 +480,11 @@ async def step_oai_chat_completions(context, api_error):
        response_format=context.response_format
        if hasattr(context, 'response_format') else None,
        tools=context.tools
        if hasattr(context, 'tools') else None,
        tool_choice=context.tool_choice,
        user_api_key=context.user_api_key
        if hasattr(context, 'user_api_key') else None,
@@ -541,6 +555,9 @@ async def step_oai_chat_completions(context):
        if hasattr(context, 'enable_streaming') else None,
        response_format=context.response_format
        if hasattr(context, 'response_format') else None,
        tools=context.tools
        if hasattr(context, 'tools') else None,
        tool_choice=context.tool_choice,
        user_api_key=context.user_api_key
        if hasattr(context, 'user_api_key') else None)
@@ -554,16 +571,18 @@ async def step_oai_chat_completions(context):
        context.base_url,
        '/chat/completions',
        True, # async_client
        model=context.model,
        # if hasattr(context, 'model') else None,
        n_predict=context.n_predict,
        # if hasattr(context, 'n_predict') else None,
        enable_streaming=context.enable_streaming
        if hasattr(context, 'enable_streaming') else None,
        response_format=context.response_format,
        # if hasattr(context, 'response_format') else None,
        tools=context.tools,# if hasattr(context, 'tools') else None,
        tool_choice=context.tool_choice, # if hasattr(context, 'tool_choice') else None,
        user_api_key=context.user_api_key)
        # if hasattr(context, 'user_api_key') else None)

@step('all prompts are predicted')
@@ -908,6 +927,8 @@ async def oai_chat_completions(user_prompt,
                               n_predict=None,
                               enable_streaming=None,
                               response_format=None,
                               tools=None,
                               tool_choice=None,
                               user_api_key=None,
                               expect_api_error=None):
    if debug:
@@ -935,6 +956,10 @@ async def oai_chat_completions(user_prompt,
    }
    if response_format is not None:
        payload['response_format'] = response_format
    if tools is not None:
        payload['tools'] = tools
    if tool_choice is not None:
        payload['tool_choice'] = tool_choice
    completion_response = {
        'content': '',
        'timings': {
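
For reference, a hedged sketch of the request body this ends up posting once tools and tool_choice are set; the message and sampling values are assumptions, and both new keys are only added when the corresponding argument is not None:

    # Sketch of the /chat/completions payload built by oai_chat_completions
    # (field values are placeholders; only tools/tool_choice are what this hunk adds).
    payload = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the weather in Paris today?"},
        ],
        "max_tokens": 64,
        "stream": False,
        "temperature": 0.8,
        "tools": [{"name": "get_current_weather",
                   "description": "Get the current weather for a city",
                   "parameters": {"type": "object",
                                  "properties": {"location": {"type": "string"}},
                                  "required": ["location"]}}],
        "tool_choice": "required",  # or "auto", or the name of one specific tool
    }
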
@@ -996,6 +1021,8 @@ async def oai_chat_completions(user_prompt,
            max_tokens=n_predict,
            stream=enable_streaming,
            response_format=payload.get('response_format'),
            tools=payload.get('tools'),
            tool_choice=payload.get('tool_choice'),
            seed=seed,
            temperature=payload['temperature']
        )


@@ -12,6 +12,7 @@
#include <vector>
#include <sstream>
#include <random>
#include <regex>

#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
@@ -410,19 +411,40 @@ static json oaicompat_completion_params_parse(
        }
    } else if (body.contains("tools") && body["tools"].is_array()) {
        const auto & tools = body["tools"];
        bool built_grammar = false;
        bool allow_parallel_calls = false;
        bool allow_content = true;
        if (body.contains("tool_choice") && body["tool_choice"].is_string() && body["tool_choice"] != "auto") {
            std::string tool_choice = body["tool_choice"];
            if (tool_choice == "required") {
                allow_content = false;
            } else {
                for (const auto & tool : tools) {
                    if (tool["name"] == tool_choice) {
                        llama_params["grammar"] = tool_call_grammar(json::array({ tool }), allow_parallel_calls, /* allow_content= */ false);
                        built_grammar = true;
                        break;
                    }
                }
            }
        }
        if (!built_grammar) {
            llama_params["grammar"] = tool_call_grammar(tools, allow_parallel_calls, allow_content);
        }

        // TODO: pass a template file.
        extra_system_message = (std::stringstream()
            << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. "
            << "You may call one or more functions to assist with the user query. "
            // << "Don't make assumptions about what values to plug into functions. "
            << "Here are the available tools: <tools>"
            << tools.dump(2).c_str()
            << "</tools>\n"
            // << "To call a tool give a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:"
            << "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:"
            << "<tool_call>"
            << "{\"name\": <function-name>, \"arguments\": <args-dict>}"
            << "</tool_call>"
            << "Don't explain which tools you're going to call, just call them."
        ).str();
    }
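
The tool_choice handling above has three cases: "auto" (or unset) keeps plain content as an alternative alongside the whole tool set, "required" drops the plain-content alternative so a tool call is mandatory, and a specific tool name narrows the grammar to that single tool with no free text. The rest of the hunk prepends a system message asking the model to wrap each call in <tool_call></tool_call> tags. A Python sketch of the same decision logic; the real GBNF-building helper tool_call_grammar lives elsewhere in the server sources and is only stubbed out here:

    # Stand-in for the server's tool_call_grammar helper (returns a stub, not real GBNF).
    def tool_call_grammar(tools, allow_parallel_calls, allow_content):
        return {"tools": [t["name"] for t in tools],
                "allow_parallel_calls": allow_parallel_calls,
                "allow_content": allow_content}

    def grammar_for(tools, tool_choice=None):
        allow_parallel_calls = False
        allow_content = True
        if isinstance(tool_choice, str) and tool_choice != "auto":
            if tool_choice == "required":
                allow_content = False  # a tool call is mandatory
            else:
                for tool in tools:
                    if tool["name"] == tool_choice:
                        # named tool: constrain output to that single tool, no free text
                        return tool_call_grammar([tool], allow_parallel_calls, False)
        # default ("auto"/unset): any of the tools, optionally plain content
        return tool_call_grammar(tools, allow_parallel_calls, allow_content)

    print(grammar_for([{"name": "get_current_weather"}], tool_choice="get_current_weather"))
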
@@ -451,7 +473,7 @@ static json oaicompat_completion_params_parse(
    }

    // Params supported by OAI but unsupported by llama.cpp
    static const std::vector<std::string> unsupported_params;// { "tool_choice" };
    for (auto & param : unsupported_params) {
        if (body.contains(param)) {
            throw std::runtime_error("Unsupported param: " + param);
@@ -478,10 +500,36 @@ static json format_final_response_oaicompat(const json & request, json result, c
    int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
    std::string content = json_value(result, "content", std::string(""));

    std::string finish_reason = "length";
    if (stopped_word || stopped_eos) {
        finish_reason = "stop";
    }

    json tool_calls;
    json message_content;
    if (request.contains("tools")) {
        std::regex pattern("<tool_call>(.*?)</tool_call>");
        std::sregex_iterator iter(content.begin(), content.end(), pattern);
        std::sregex_iterator end;
        while (iter != end) {
            std::smatch match = *iter;
            auto call = json::parse(match[1].str());
            if (tool_calls.is_null()) {
                tool_calls = json::array();
            }
            tool_calls.push_back({
                {"function", {
                    {"name", call["name"]},
                    {"arguments", call["arguments"].dump()},
                }},
            });
            finish_reason = "tool_calls";
            ++iter;
        }
    }
    if (tool_calls.is_null()) {
        message_content = content;
    }

    json choices =
        streaming ? json::array({json{{"finish_reason", finish_reason},
@@ -489,7 +537,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
                                     {"delta", json::object()}}})
                  : json::array({json{{"finish_reason", finish_reason},
                                      {"index", 0},
                                      {"message", json{{"content", message_content},
                                                       {"tool_calls", tool_calls},
                                                       {"role", "assistant"}}}}});

    std::time_t t = std::time(0);
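
On the way out, format_final_response_oaicompat scans the generated text for <tool_call>…</tool_call> spans, parses each one as JSON, reports them as OpenAI-style tool_calls entries (arguments re-serialized as a string), switches finish_reason to "tool_calls", and leaves the message content null whenever at least one call was found. A rough Python equivalent of that extraction, not the server code itself; any response fields beyond those visible in the hunk are assumptions:

    import json
    import re

    def extract_tool_calls(content):
        tool_calls = []
        finish_reason = "stop"
        # like the C++ std::regex above, '.' does not cross newlines here
        for body in re.findall(r"<tool_call>(.*?)</tool_call>", content):
            call = json.loads(body)
            tool_calls.append({
                "function": {
                    "name": call["name"],
                    "arguments": json.dumps(call["arguments"]),
                },
            })
            finish_reason = "tool_calls"
        message = {
            "role": "assistant",
            "content": None if tool_calls else content,
            "tool_calls": tool_calls or None,
        }
        return message, finish_reason

    message, finish_reason = extract_tool_calls(
        '<tool_call>{"name": "get_current_weather", "arguments": {"location": "Paris"}}</tool_call>')
    print(finish_reason)  # tool_calls
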