server: update tool calling, introduce system prompt for json schema

parent 2b2127c2a3
commit e41b6ceee9

3 changed files with 110 additions and 59 deletions
@@ -200,28 +200,28 @@ static std::string format_literal(const std::string & literal) {
    return "\"" + escaped + "\"";
}

/*
    not_literal('a')   -> '[^a]'
    not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
*/
static std::string not_literal(const std::string & literal, bool dotall = true) {
    assert(literal.size() > 0);
    std::stringstream out;
    std::function<void(size_t)> recurse = [&](size_t i) {
        const auto & c = literal[i];
        out << "[^" << c << "]";
        if (i < literal.size() - 1) {
            out << " | " << format_literal(literal.substr(i, 1)) << " (";
            recurse(i + 1);
            out << ")?";
        }
    };
    out << "(";
    recurse(0);
    out << ")" << (dotall ? DOTALL : DOT) << "*";
    return out.str();
}

// static std::string not_literal(const std::string & literal, bool dotall = true) {
//     assert(literal.size() > 0);
//     std::stringstream out;
//     std::function<void(int)> recurse = [&](size_t i) {
//         const char & c = literal[i];
//         out << "[^" << c << "]";
//         out << " " << (dotall ? DOTALL : DOT) << "*";
//         if (i < literal.size() - 1) {
//             out << " | " << format_literal(literal.substr(i, 1)) << " (";
//             recurse(i + 1);
//             out << ")?";
//         }
//     };
//     out << "(";
//     recurse(0);
//     out << ")";
//     return out.str();
// }
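For intuition: not_literal builds a pattern that consumes any text that never completes the given literal. A minimal standalone sketch of what the active helper emits, assuming DOT is stubbed as a plain "." stand-in (the real DOT/DOTALL constants are defined elsewhere in this file):

#include <cassert>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>

static const std::string DOT = ".";  // stand-in for the file's DOT constant

static std::string not_literal_demo(const std::string & literal) {
    assert(!literal.empty());
    std::stringstream out;
    // At position i: either reject literal[i], or accept it and recurse on i+1.
    std::function<void(size_t)> recurse = [&](size_t i) {
        out << "[^" << literal[i] << "]";
        if (i < literal.size() - 1) {
            out << " | \"" << literal[i] << "\" (";
            recurse(i + 1);
            out << ")?";
        }
    };
    out << "(";
    recurse(0);
    out << ")" << DOT << "*";
    return out.str();
}

int main() {
    // Prints: ([^a] | "a" ([^b] | "b" ([^c])?)?).*
    std::cout << not_literal_demo("abc") << "\n";
}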
@@ -625,17 +625,57 @@ public:
     visit_refs(schema);
 }

-std::string add_rule(const std::string & name, const std::string & rule) {
-    std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
-    if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
-        _rules[esc_name] = rule;
-        return esc_name;
-    } else {
-        int i = 0;
-        while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
-            i++;
-        }
-        std::string key = esc_name + std::to_string(i);
-        _rules[key] = rule;
-        return key;
-    }
-}
+/*
+    reply ::= prefix tool-call*
+
+    prefix ::= [^<] prefix
+             | "<" [^t] prefix
+             | "<t" [^o] prefix
+             | "<to" [^o] prefix
+             | "<too" [^l] prefix
+             | "<tool" [^_] prefix
+             | "<tool_" [^c] prefix
+             | "<tool_c" [^a] prefix
+             | "<tool_ca" [^l] prefix
+             | "<tool_cal" [^l] prefix
+             | "<tool_call" [^>] prefix
+*/
+std::string not_literal(const std::string & literal) {
+    auto rule_name = _find_rule_name("not" + literal, "!!!");
+    std::stringstream out;
+    for (size_t i = 0, n = literal.size(); i < n; i++) {
+        out << " | ";
+        if (i > 0) {
+            out << format_literal(literal.substr(0, i)) << " ";
+        }
+        out << "[^" << literal[i] << "] " << rule_name;
+    }
+    _rules[rule_name] = out.str();
+    return rule_name;
+}
+
+std::string _escape_name(const std::string & name) {
+    return regex_replace(name, INVALID_RULE_CHARS_RE, "-");
+}
+std::string _find_rule_name(const std::string & name, const std::string & rule) {
+    auto esc_name = _escape_name(name);
+    int i = 0;
+    while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
+        i++;
+    }
+    return esc_name + std::to_string(i);
+}
+std::string add_rule(const std::string & name, const std::string & rule) {
+    auto esc_name = _escape_name(name);
+    if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
+        _rules[esc_name] = rule;
+        return esc_name;
+    } else {
+        auto key = _find_rule_name(esc_name, rule);
+        _rules[key] = rule;
+        return key;
+    }
+}
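The name-collision handling above is easiest to see in isolation. A standalone sketch under simplifying assumptions (escaping dropped, _rules reduced to a plain map):

#include <iostream>
#include <map>
#include <string>

// _rules maps rule names to their GBNF bodies.
static std::map<std::string, std::string> rules;

static std::string add_rule_demo(const std::string & name, const std::string & rule) {
    if (rules.find(name) == rules.end() || rules[name] == rule) {
        rules[name] = rule;   // fresh name, or identical rule already registered
        return name;
    }
    // Collision with a different body: probe name0, name1, ... until we find
    // a free slot or a slot that already holds this exact rule.
    int i = 0;
    while (rules.find(name + std::to_string(i)) != rules.end() &&
           rules[name + std::to_string(i)] != rule) {
        i++;
    }
    std::string key = name + std::to_string(i);
    rules[key] = rule;
    return key;
}

int main() {
    std::cout << add_rule_demo("string", "\"a\"") << "\n";  // string
    std::cout << add_rule_demo("string", "\"b\"") << "\n";  // string0
    std::cout << add_rule_demo("string", "\"a\"") << "\n";  // string (deduplicated)
}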
@@ -789,7 +829,7 @@ std::string json_schema_to_grammar(const json & schema) {
     return converter.format_grammar();
 }

-std::string tool_call_grammar(const json & tools) {
+std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) {
     SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);

     std::vector<std::string> tool_rules;
@@ -814,12 +854,13 @@ std::string tool_call_grammar(const json & tools) {

     converter.add_rule(
         "root",
-        not_literal("<tool_call>") + " | "
-        + converter.add_rule(
+        converter.not_literal("<tool_call>") + " " +
+        converter.add_rule(
             "tool_call",
-            "\"<tool_call>\" "
+            "\"<tool_call>\" ("
             + join(tool_rules.begin(), tool_rules.end(), " | ")
-            + " \"</tool_call>\""));
+            + ") \"</tool_call>\""
+        ) + (allow_parallel_calls ? "*" : "?"));

     converter.check_errors();
     return converter.format_grammar();
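End to end, the new flag switches the root rule between one optional tool call and zero-or-more calls. A hypothetical usage sketch (the prototype is declared inline since the header's file name is not shown in this diff; the tools array follows the OpenAI shape assumed by the parsing code further below):

#include <iostream>
#include "json.hpp"

// Declaration from the header changed below; assumed visible at this point.
std::string tool_call_grammar(const nlohmann::ordered_json & tools, bool allow_parallel_calls = false);

int main() {
    // Hypothetical OpenAI-style tools array with a single function.
    auto tools = nlohmann::ordered_json::parse(R"([
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                    "required": ["city"]
                }
            }
        }
    ])");
    // Roughly: root ::= not-tool-call tool-call*   (vs. tool-call? without the flag)
    std::cout << tool_call_grammar(tools, /* allow_parallel_calls= */ true);
}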
@@ -1,5 +1,5 @@
 #pragma once
 #include "json.hpp"

-std::string tool_call_grammar(const nlohmann::ordered_json & tools);
+std::string tool_call_grammar(const nlohmann::ordered_json & tools, bool allow_parallel_calls = false);
 std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
@@ -123,7 +123,7 @@ inline bool verify_custom_template(const std::string & tmpl) {
 }

 // Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages, const std::string & tools_tag) {
+inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages, const std::string & extra_system_message) {
     size_t alloc_size = 0;
     // vector holding all allocated string to be passed to llama_chat_apply_template
     std::vector<std::string> str(messages.size() * 2);
@@ -138,18 +138,12 @@ inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages, const std::string & tools_tag) {
         chat[i].content = str[i*2 + 1].c_str();
     }

-    if (!tools_tag.empty()) {
-        alloc_size += tools_tag.size();
-        if (chat.empty()) {
-            str.resize(2);
-            str[0] = "user";
-            str[1] = tools_tag;
-            chat.push_back({str[0].c_str(), str[1].c_str()});
-        } else {
-            auto & content = str[str.size() - 1];
-            content += tools_tag;
-            chat[chat.size() - 1].content = content.c_str();
-        }
+    if (!extra_system_message.empty()) {
+        alloc_size += extra_system_message.size();
+
+        llama_chat_message msg { "system", extra_system_message.c_str() };
+        chat.insert(chat.begin(), msg);
+        // chat.push_back(msg);
     }

     const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
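Instead of appending the tools blob to the last message, the extra text is now prepended as its own system turn. A minimal sketch of that prepend (the struct mirrors llama.h's llama_chat_message; the message contents are hypothetical):

#include <string>
#include <vector>

// Mirrors llama.h's chat message struct for illustration.
struct llama_chat_message { const char * role; const char * content; };

int main() {
    std::string extra = "You are a helpful assistant that answers in JSON.";
    std::vector<llama_chat_message> chat = {
        { "user", "What's the weather like?" },
    };
    if (!extra.empty()) {
        // Prepend as a system turn. Note that only the pointer is stored,
        // so 'extra' must outlive 'chat' — same reason format_chat keeps
        // its strings alive in the 'str' vector.
        chat.insert(chat.begin(), llama_chat_message { "system", extra.c_str() });
    }
}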
@@ -387,22 +381,7 @@ static json oaicompat_completion_params_parse(
     llama_params["temperature"] = json_value(body, "temperature", 0.0);
     llama_params["top_p"] = json_value(body, "top_p", 1.0);

-    std::string tools_tag;
-    if (body.contains("tools") && body["tools"].is_array()) {
-        const auto & tools = body["tools"];
-        llama_params["grammar"] = tool_call_grammar(tools);
-        tools_tag = (std::stringstream() << "\n\n<tools>" << tools.dump(2) << "</tools>").str();
-    }
-
-    // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body["messages"], tools_tag);
-
-    // Handle "stop" field
-    if (body.contains("stop") && body["stop"].is_string()) {
-        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
-    } else {
-        llama_params["stop"] = json_value(body, "stop", json::array());
-    }
+    std::string extra_system_message;

     // Handle "response_format" field
     if (body.contains("response_format")) {
@@ -410,9 +389,40 @@ static json oaicompat_completion_params_parse(
         std::string response_type = json_value(response_format, "type", std::string());
         if (response_type == "json_object") {
             llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+            extra_system_message = (std::stringstream()
+                << "You are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:\n<schema>\n"
+                << llama_params["json_schema"].dump().c_str()
+                << "\n</schema>"
+            ).str();
         } else if (!response_type.empty() && response_type != "text") {
             throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
         }
-    }
+    } else if (body.contains("tools") && body["tools"].is_array()) {
+        const auto & tools = body["tools"];
+        llama_params["grammar"] = tool_call_grammar(tools);
+
+        extra_system_message = (std::stringstream()
+            << "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. "
+            << "You may call one or more functions to assist with the user query. "
+            << "Don't make assumptions about what values to plug into functions. "
+            << "Here are the available tools: <tools>"
+            << tools.dump().c_str()
+            << "</tools>\n"
+            << "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:"
+            << "<tool_call>"
+            << "{\"arguments\": <args-dict>, \"name\": <function-name>}"
+            << "</tool_call>"
+        ).str();
     }

+    // Apply chat template to the list of messages
+    llama_params["prompt"] = format_chat(model, chat_template, body["messages"], extra_system_message);
+
+    // Handle "stop" field
+    if (body.contains("stop") && body["stop"].is_string()) {
+        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
+    } else {
+        llama_params["stop"] = json_value(body, "stop", json::array());
+    }

     // Handle "n" field
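For reference, a hypothetical OpenAI-compatible request body that exercises the new json_object branch (field shapes taken from the parsing code above; note that because of the else-if, a "tools" array would be ignored whenever "response_format" is present):

#include <iostream>
#include "json.hpp"

int main() {
    // With "response_format" of type "json_object", the server sets
    // llama_params["json_schema"] and prepends the schema-bearing
    // system message built above.
    auto body = nlohmann::ordered_json::parse(R"({
        "messages": [ { "role": "user", "content": "List three primes." } ],
        "response_format": {
            "type": "json_object",
            "schema": { "type": "array", "items": { "type": "integer" } }
        }
    })");
    std::cout << body.dump(2) << std::endl;
}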