server: ultra basic tools, tool_choice, tool_calls support

parent 793f4ff3f5
commit a1c4aac384

2 changed files with 91 additions and 15 deletions
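What this adds, at the API level: the server's OpenAI-compatible chat completions endpoint now accepts `tools` and `tool_choice` in the request body and can return `tool_calls` in the response message. A minimal sketch of such a request; the host/port, route, model name and tool schema below are illustrative, not taken from this commit:

    # Sketch of an OpenAI-style tools request against a locally running server.
    # Assumptions (not from this commit): host/port, the /v1/chat/completions route,
    # the model name and the example tool schema.
    import requests

    payload = {
        "model": "gpt-3.5-turbo-0613",
        "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }],
        # per this commit, tool_choice is a plain string: "auto", "required",
        # or the name of a specific tool
        "tool_choice": "auto",
    }

    resp = requests.post("http://localhost:8080/v1/chat/completions", json=payload)
    message = resp.json()["choices"][0]["message"]
    print(message.get("tool_calls"), message.get("content"))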
@@ -66,6 +66,8 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_seed = None
     context.user_api_key = None
     context.response_format = None
+    context.tools = None
+    context.tool_choice = None
     context.temperature = None

     context.tasks_result = []
@@ -337,6 +339,13 @@ def step_max_tokens(context, max_tokens):
 def step_response_format(context, response_format):
     context.response_format = json.loads(response_format)

+@step('tools {tools}')
+def step_tools(context, tools):
+    context.tools = json.loads(tools)
+
+@step('tool choice {tool_choice}')
+def step_tool_choice(context, tool_choice):
+    context.tool_choice = tool_choice

 @step('{temperature:f} temperature')
 def step_temperature(context, temperature):
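The new steps take raw strings from the feature file: `step_tools` parses its argument with `json.loads`, while `step_tool_choice` stores the string verbatim. A sketch of the kind of values a scenario might pass (the tool schema is illustrative):

    import json

    # Illustrative argument for the new `tools {tools}` step; step_tools() just
    # runs json.loads on it and stores the result on the behave context.
    tools_literal = ('[{"type": "function", "function": {"name": "get_weather", '
                     '"parameters": {"type": "object", '
                     '"properties": {"city": {"type": "string"}}}}}]')
    tools = json.loads(tools_literal)
    assert tools[0]["function"]["name"] == "get_weather"

    # The `tool choice {tool_choice}` step keeps the raw string:
    # "auto", "required", or a specific function name such as "get_weather".
    tool_choice = "required"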
@@ -471,6 +480,11 @@ async def step_oai_chat_completions(context, api_error):
         response_format=context.response_format
         if hasattr(context, 'response_format') else None,
+        tools=context.tools
+        if hasattr(context, 'tools') else None,
+        tool_choice=context.tool_choice,
         user_api_key=context.user_api_key
         if hasattr(context, 'user_api_key') else None,
@@ -541,6 +555,9 @@ async def step_oai_chat_completions(context):
         if hasattr(context, 'enable_streaming') else None,
         response_format=context.response_format
         if hasattr(context, 'response_format') else None,
+        tools=context.tools
+        if hasattr(context, 'tools') else None,
+        tool_choice=context.tool_choice,
         user_api_key=context.user_api_key
         if hasattr(context, 'user_api_key') else None)
@@ -554,16 +571,18 @@ async def step_oai_chat_completions(context):
         context.base_url,
         '/chat/completions',
         True,  # async_client
-        model=context.model
-        if hasattr(context, 'model') else None,
-        n_predict=context.n_predict
-        if hasattr(context, 'n_predict') else None,
+        model=context.model,
+        # if hasattr(context, 'model') else None,
+        n_predict=context.n_predict,
+        # if hasattr(context, 'n_predict') else None,
         enable_streaming=context.enable_streaming
         if hasattr(context, 'enable_streaming') else None,
-        response_format=context.response_format
-        if hasattr(context, 'response_format') else None,
-        user_api_key=context.user_api_key
-        if hasattr(context, 'user_api_key') else None)
+        response_format=context.response_format,
+        # if hasattr(context, 'response_format') else None,
+        tools=context.tools,  # if hasattr(context, 'tools') else None,
+        tool_choice=context.tool_choice,  # if hasattr(context, 'tool_choice') else None,
+        user_api_key=context.user_api_key)
+        # if hasattr(context, 'user_api_key') else None)


 @step('all prompts are predicted')
@@ -908,6 +927,8 @@ async def oai_chat_completions(user_prompt,
         n_predict=None,
         enable_streaming=None,
         response_format=None,
+        tools=None,
+        tool_choice=None,
         user_api_key=None,
         expect_api_error=None):
     if debug:
@@ -935,6 +956,10 @@ async def oai_chat_completions(user_prompt,
     }
     if response_format is not None:
         payload['response_format'] = response_format
+    if tools is not None:
+        payload['tools'] = tools
+    if tool_choice is not None:
+        payload['tool_choice'] = tool_choice
     completion_response = {
         'content': '',
         'timings': {
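With these two additions, a test that sets tools and a tool choice ends up posting a request body along these lines (a sketch; the prompt, model and tool schema are placeholders, and other fields such as the seed are omitted):

    # Illustrative request body built by oai_chat_completions() when tools and
    # tool_choice are set; values are placeholders, not from the test suite.
    payload = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the weather in Paris?"},
        ],
        "model": "gpt-3.5-turbo-0613",
        "max_tokens": 64,
        "stream": False,
        "temperature": 0.8,
        "tools": [{"type": "function",
                   "function": {"name": "get_weather",
                                "parameters": {"type": "object",
                                               "properties": {"city": {"type": "string"}}}}}],
        "tool_choice": "get_weather",
    }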
@@ -996,6 +1021,8 @@ async def oai_chat_completions(user_prompt,
         max_tokens=n_predict,
         stream=enable_streaming,
         response_format=payload.get('response_format'),
+        tools=payload.get('tools'),
+        tool_choice=payload.get('tool_choice'),
         seed=seed,
         temperature=payload['temperature']
     )
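In the test client this maps onto the same keyword arguments of the OpenAI Python bindings. For reference, a sketch of the equivalent call with the openai>=1.0 client; the base_url, port and api_key are assumptions for a locally running server:

    # Sketch: the same parameters via the openai Python client (openai>=1.0 API),
    # pointed at a locally running server; base_url and api_key are assumptions.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo-0613",
        messages=[{"role": "user", "content": "What is the weather in Paris?"}],
        tools=[{"type": "function",
                "function": {"name": "get_weather",
                             "parameters": {"type": "object",
                                            "properties": {"city": {"type": "string"}}}}}],
        tool_choice="auto",
    )
    print(completion.choices[0].message.tool_calls)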
@@ -12,6 +12,7 @@
 #include <vector>
 #include <sstream>
 #include <random>
+#include <regex>

 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"

@@ -410,19 +411,40 @@ static json oaicompat_completion_params_parse(
         }
     } else if (body.contains("tools") && body["tools"].is_array()) {
         const auto & tools = body["tools"];
-        llama_params["grammar"] = tool_call_grammar(tools);
+        bool built_grammar = false;
+        bool allow_parallel_calls = false;
+        bool allow_content = true;
+        if (body.contains("tool_choice") && body["tool_choice"].is_string() && body["tool_choice"] != "auto") {
+            std::string tool_choice = body["tool_choice"];
+            if (tool_choice == "required") {
+                allow_content = false;
+            } else {
+                for (const auto & tool : tools) {
+                    if (tool["name"] == tool_choice) {
+                        llama_params["grammar"] = tool_call_grammar(json::array({ tool }), allow_parallel_calls, /* allow_content= */ false);
+                        built_grammar = true;
+                        break;
+                    }
+                }
+            }
+        }
+        if (!built_grammar) {
+            llama_params["grammar"] = tool_call_grammar(tools, allow_parallel_calls, allow_content);
+        }
+        // TODO: pass a template file.
         extra_system_message = (std::stringstream()
             << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. "
             << "You may call one or more functions to assist with the user query. "
-            << "Don't make assumptions about what values to plug into functions. "
+            // << "Don't make assumptions about what values to plug into functions. "
             << "Here are the available tools: <tools>"
-            << tools.dump().c_str()
+            << tools.dump(2).c_str()
             << "</tools>\n"
+            // << "To call a tool give a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:"
             << "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:"
             << "<tool_call>"
-            << "{\"arguments\": <args-dict>, \"name\": <function-name>}"
+            << "{\"name\": <function-name>, \"arguments\": <args-dict>}"
             << "</tool_call>"
+            << "Don't explain which tools you're going to call, just call them."
         ).str();
     }
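Net effect of the tool_choice handling above: no tool_choice (or "auto") keeps free-form content allowed alongside tool calls, "required" forbids plain content, and a specific function name constrains the grammar to that single tool. A rough Python rendering of the same decision, where tool_call_grammar stands in for the server-side grammar builder (assumed here; its implementation is not part of this diff):

    # Rough Python rendering of the C++ tool_choice branch above.
    # tool_call_grammar() is a stand-in for the server-side grammar builder.
    def pick_grammar(body, tools, tool_call_grammar):
        allow_parallel_calls = False
        allow_content = True
        tool_choice = body.get("tool_choice")
        if isinstance(tool_choice, str) and tool_choice != "auto":
            if tool_choice == "required":
                allow_content = False  # the model must emit a tool call
            else:
                for tool in tools:
                    if tool.get("name") == tool_choice:
                        # constrain the grammar to exactly this tool, no free-form content
                        return tool_call_grammar([tool], allow_parallel_calls, False)
        return tool_call_grammar(tools, allow_parallel_calls, allow_content)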
@@ -451,7 +473,7 @@ static json oaicompat_completion_params_parse(
     }

     // Params supported by OAI but unsupported by llama.cpp
-    static const std::vector<std::string> unsupported_params { "tool_choice" };
+    static const std::vector<std::string> unsupported_params; // { "tool_choice" };
     for (auto & param : unsupported_params) {
         if (body.contains(param)) {
             throw std::runtime_error("Unsupported param: " + param);
@@ -478,10 +500,36 @@ static json format_final_response_oaicompat(const json & request, json result, c
     int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
     std::string content = json_value(result, "content", std::string(""));

     std::string finish_reason = "length";
     if (stopped_word || stopped_eos) {
         finish_reason = "stop";
     }
+    json tool_calls;
+    json message_content;
+    if (request.contains("tools")) {
+        std::regex pattern("<tool_call>(.*?)</tool_call>");
+        std::sregex_iterator iter(content.begin(), content.end(), pattern);
+        std::sregex_iterator end;
+        while (iter != end) {
+            std::smatch match = *iter;
+            auto call = json::parse(match[1].str());
+            if (tool_calls.is_null()) {
+                tool_calls = json::array();
+            }
+            tool_calls.push_back({
+                {"function", {
+                    {"name", call["name"]},
+                    {"arguments", call["arguments"].dump()},
+                }},
+            });
+            finish_reason = "tool_calls";
+            ++iter;
+        }
+    }
+    if (tool_calls.is_null()) {
+        message_content = content;
+    }

     json choices =
         streaming ? json::array({json{{"finish_reason", finish_reason},
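The loop above scans the generated text for <tool_call>…</tool_call> spans, parses each span as JSON, and converts it into an OpenAI-style tool_calls entry, flipping finish_reason to "tool_calls". A rough Python equivalent of that extraction (the regex and JSON shapes mirror the C++ above; as with std::regex, '.' does not match newlines, so each call is expected on a single line):

    import json
    import re

    # Rough Python equivalent of the server-side <tool_call> extraction above.
    def extract_tool_calls(content):
        tool_calls = []
        for match in re.finditer(r"<tool_call>(.*?)</tool_call>", content):
            call = json.loads(match.group(1))
            tool_calls.append({
                "function": {
                    "name": call["name"],
                    # arguments are re-serialized as a JSON string, OpenAI-style
                    "arguments": json.dumps(call["arguments"]),
                },
            })
        finish_reason = "tool_calls" if tool_calls else None
        return tool_calls, finish_reason

    calls, reason = extract_tool_calls(
        '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>')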
@@ -489,7 +537,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
                                 {"delta", json::object()}}})
                   : json::array({json{{"finish_reason", finish_reason},
                                 {"index", 0},
-                                {"message", json{{"content", content},
+                                {"message", json{{"content", message_content},
+                                                 {"tool_calls", tool_calls},
                                                  {"role", "assistant"}}}}});

     std::time_t t = std::time(0);
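Put together, a non-streaming response whose generation triggered a tool call would look roughly like this (a sketch only; ids, usage fields and the model name are placeholders, and "content" is null because tool_calls is set):

    # Illustrative shape of the final OAI-compatible response with tool_calls;
    # surrounding fields (id, usage, created, ...) are omitted or placeholders.
    response = {
        "choices": [{
            "finish_reason": "tool_calls",
            "index": 0,
            "message": {
                "content": None,
                "tool_calls": [{
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"city\": \"Paris\"}",
                    },
                }],
                "role": "assistant",
            },
        }],
        "model": "gpt-3.5-turbo-0613",
        "object": "chat.completion",
    }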