server: ultra basic tools, tool_choice, tool_calls support

ochafik 2024-05-22 04:15:14 +01:00
parent 793f4ff3f5
commit a1c4aac384
2 changed files with 91 additions and 15 deletions
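
In practical terms, an OpenAI-compatible chat completion request to the server can now carry a tools array and a tool_choice, and the response can report tool_calls. A rough request sketch follows; the host and port, the endpoint path, the model name, and the flat tool schema (a top-level "name" field, which is what the C++ side compares tool_choice against) are illustrative assumptions rather than values fixed by this commit.

import json
import requests  # assumes the requests package is installed

payload = {
    # Illustrative; matches the server's DEFAULT_OAICOMPAT_MODEL.
    "model": "gpt-3.5-turbo-0613",
    "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
    ],
    # Hypothetical tool definition; the exact schema expected by the grammar
    # helper is an assumption here.
    "tools": [
        {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }
    ],
    # In this commit, tool_choice is a plain string: "auto", "required",
    # or the name of one of the tools.
    "tool_choice": "auto",
}

# Assumes a llama.cpp server listening locally; adjust host and port as needed.
resp = requests.post("http://localhost:8080/chat/completions", json=payload)
print(json.dumps(resp.json(), indent=2))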

@@ -66,6 +66,8 @@ def step_server_config(context, server_fqdn, server_port):
context.server_seed = None
context.user_api_key = None
context.response_format = None
context.tools = None
context.tool_choice = None
context.temperature = None
context.tasks_result = []
@@ -337,6 +339,13 @@ def step_max_tokens(context, max_tokens):
def step_response_format(context, response_format):
context.response_format = json.loads(response_format)
@step('tools {tools}')
def step_tools(context, tools):
context.tools = json.loads(tools)
@step('tool choice {tool_choice}')
def step_tool_choice(context, tool_choice):
context.tool_choice = tool_choice
@step('{temperature:f} temperature')
def step_temperature(context, temperature):
@@ -471,6 +480,11 @@ async def step_oai_chat_completions(context, api_error):
response_format=context.response_format
if hasattr(context, 'response_format') else None,
tools=context.tools
if hasattr(context, 'tools') else None,
tool_choice=context.tool_choice,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None,
@@ -541,6 +555,9 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
tools=context.tools
if hasattr(context, 'tools') else None,
tool_choice=context.tool_choice,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
@@ -554,16 +571,18 @@ async def step_oai_chat_completions(context):
context.base_url,
'/chat/completions',
True, # async_client
model=context.model
if hasattr(context, 'model') else None,
n_predict=context.n_predict
if hasattr(context, 'n_predict') else None,
model=context.model,
# if hasattr(context, 'model') else None,
n_predict=context.n_predict,
# if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
response_format=context.response_format,
# if hasattr(context, 'response_format') else None,
tools=context.tools,# if hasattr(context, 'tools') else None,
tool_choice=context.tool_choice, # if hasattr(context, 'tool_choice') else None,
user_api_key=context.user_api_key)
# if hasattr(context, 'user_api_key') else None)
@step('all prompts are predicted')
@@ -908,6 +927,8 @@ async def oai_chat_completions(user_prompt,
n_predict=None,
enable_streaming=None,
response_format=None,
tools=None,
tool_choice=None,
user_api_key=None,
expect_api_error=None):
if debug:
@@ -935,6 +956,10 @@ async def oai_chat_completions(user_prompt,
}
if response_format is not None:
payload['response_format'] = response_format
if tools is not None:
payload['tools'] = tools
if tool_choice is not None:
payload['tool_choice'] = tool_choice
completion_response = {
'content': '',
'timings': {
@@ -996,6 +1021,8 @@ async def oai_chat_completions(user_prompt,
max_tokens=n_predict,
stream=enable_streaming,
response_format=payload.get('response_format'),
tools=payload.get('tools'),
tool_choice=payload.get('tool_choice'),
seed=seed,
temperature=payload['temperature']
)
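
The same fields can also be exercised through an off-the-shelf OpenAI Python client pointed at the server, roughly as sketched below. The base URL, the dummy API key, the client version (openai >= 1.0), and the tool definition are assumptions for illustration.

from openai import OpenAI  # assumes openai >= 1.0 is installed

# Point the client at the local server's OpenAI-compatible API (assumed URL).
client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo-0613",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=[{
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
        },
    }],
    tool_choice="auto",
)

# If the model emitted <tool_call> blocks, they show up here as tool_calls;
# inspecting the raw JSON may be more informative while the format is evolving.
print(resp.choices[0].message.tool_calls)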

@@ -12,6 +12,7 @@
#include <vector>
#include <sstream>
#include <random>
#include <regex>
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
@@ -410,19 +411,40 @@ static json oaicompat_completion_params_parse(
}
} else if (body.contains("tools") && body["tools"].is_array()) {
const auto & tools = body["tools"];
llama_params["grammar"] = tool_call_grammar(tools);
bool built_grammar = false;
bool allow_parallel_calls = false;
bool allow_content = true;
if (body.contains("tool_choice") && body["tool_choice"].is_string() && body["tool_choice"] != "auto") {
std::string tool_choice = body["tool_choice"];
if (tool_choice == "required") {
allow_content = false;
} else {
for (const auto & tool : tools) {
if (tool["name"] == tool_choice) {
llama_params["grammar"] = tool_call_grammar(json::array({ tool }), allow_parallel_calls, /* allow_content= */ false);
built_grammar = true;
break;
}
}
}
}
if (!built_grammar) {
llama_params["grammar"] = tool_call_grammar(tools, allow_parallel_calls, allow_content);
}
// TODO: pass a template file.
extra_system_message = (std::stringstream()
<< "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. "
<< "You may call one or more functions to assist with the user query. "
<< "Don't make assumptions about what values to plug into functions. "
// << "Don't make assumptions about what values to plug into functions. "
<< "Here are the available tools: <tools>"
<< tools.dump().c_str()
<< tools.dump(2).c_str()
<< "</tools>\n"
// << "To call a tool give a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:"
<< "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:"
<< "<tool_call>"
<< "{\"arguments\": <args-dict>, \"name\": <function-name>}"
<< "{\"name\": <function-name>, \"arguments\": <args-dict>}"
<< "</tool_call>"
<< "Don't explain which tools you're going to call, just call them."
).str();
}
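
The tool_choice handling above can be paraphrased as follows (a Python sketch, not part of the commit; tool_call_grammar stands in for the C++ grammar helper):

def pick_grammar(tools, tool_choice, tool_call_grammar):
    # Parallel calls are disabled in this commit.
    allow_parallel_calls = False
    if isinstance(tool_choice, str) and tool_choice != "auto":
        if tool_choice == "required":
            # The model must call one of the tools; plain content is not allowed.
            return tool_call_grammar(tools, allow_parallel_calls, allow_content=False)
        for tool in tools:
            if tool.get("name") == tool_choice:
                # Constrain generation to the single named tool.
                return tool_call_grammar([tool], allow_parallel_calls, allow_content=False)
    # Default ("auto", absent, or an unknown name): any tool may be called and
    # plain content is still allowed.
    return tool_call_grammar(tools, allow_parallel_calls, allow_content=True)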
@@ -451,7 +473,7 @@ static json oaicompat_completion_params_parse(
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "tool_choice" };
static const std::vector<std::string> unsupported_params;// { "tool_choice" };
for (auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);
@@ -478,10 +500,36 @@ static json format_final_response_oaicompat(const json & request, json result, c
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
json tool_calls;
json message_content;
if (request.contains("tools")) {
std::regex pattern("<tool_call>(.*?)</tool_call>");
std::sregex_iterator iter(content.begin(), content.end(), pattern);
std::sregex_iterator end;
while (iter != end) {
std::smatch match = *iter;
auto call = json::parse(match[1].str());
if (tool_calls.is_null()) {
tool_calls = json::array();
}
tool_calls.push_back({
{"function", {
{"name", call["name"]},
{"arguments", call["arguments"].dump()},
}},
});
finish_reason = "tool_calls";
++iter;
}
}
if (tool_calls.is_null()) {
message_content = content;
}
json choices =
streaming ? json::array({json{{"finish_reason", finish_reason},
@@ -489,7 +537,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
{"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", content},
{"message", json{{"content", message_content},
{"tool_calls", tool_calls},
{"role", "assistant"}}}}});
std::time_t t = std::time(0);
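
On the response side, the new std::regex loop scans the generated content for <tool_call>...</tool_call> blocks, turns them into an OpenAI-style tool_calls array, leaves the message content null when a call was parsed, and sets finish_reason to "tool_calls". The sketch below mirrors that extraction in Python for illustration; the sample content string is made up.

import json
import re

content = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'

tool_calls = []
for match in re.finditer(r"<tool_call>(.*?)</tool_call>", content):
    call = json.loads(match.group(1))
    tool_calls.append({
        "function": {
            "name": call["name"],
            # Arguments are re-serialized to a JSON string, as OpenAI-style
            # clients expect.
            "arguments": json.dumps(call["arguments"]),
        },
    })

message = {
    "role": "assistant",
    # Content is nulled out once at least one tool call was parsed.
    "content": None if tool_calls else content,
    "tool_calls": tool_calls or None,
}
print(json.dumps(message, indent=2))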