server: tool call grammar-constraints

fix
This commit is contained in:
ochafik 2024-05-02 03:20:00 +01:00
parent 312e20b54a
commit ca1a640da2
4 changed files with 127 additions and 42 deletions

View file

@ -11,6 +11,9 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
const char * DOTALL = "[\\U00000000-\\U0010FFFF]";
const char * DOT = "[^\\x0A\\x0D]";
template <typename Iterator> template <typename Iterator>
static std::string join(Iterator begin, Iterator end, const std::string & separator); static std::string join(Iterator begin, Iterator end, const std::string & separator);
@ -198,6 +201,29 @@ static std::string format_literal(const std::string & literal) {
} }
/*
    Builds a grammar expression out of negated character classes for the
    characters of `literal`, followed by a repeated "any character" class:

        not_literal("a")   -> ([^a]) DOT*
        not_literal("abc") -> ([^a] | A ([^b] | B ([^c])?)?) DOT*

    where A / B stand for format_literal("a") / format_literal("b"), and DOT is
    DOTALL when `dotall` is true and the newline-excluding DOT class otherwise.

    `literal` must be non-empty (checked with assert).

    NOTE(review): the trailing DOT* follows every alternative, including those
    that consumed only a proper prefix of `literal`, so the expression appears
    to also accept text that begins with the full literal. Confirm whether the
    repetition should instead be attached to each negated class.
*/
static std::string not_literal(const std::string & literal, bool dotall = true) {
    assert(literal.size() > 0);
    std::stringstream out;

    // Escapes a character so that it is taken literally inside a [...] class.
    auto escape_in_class = [](char c) -> std::string {
        switch (c) {
            case '\\': return "\\\\";
            case ']':  return "\\]";
            case '\n': return "\\n";
            case '\r': return "\\r";
            case '\t': return "\\t";
            default:   return std::string(1, c);
        }
    };

    std::function<void(size_t)> recurse = [&](size_t i) {
        const char c = literal[i];
        out << "[^" << escape_in_class(c) << "]";
        // `i + 1 < size` avoids the unsigned `size() - 1` form.
        if (i + 1 < literal.size()) {
            // std::string(1, c) keeps the character itself; std::to_string(c)
            // would format its numeric code instead (e.g. "60" for '<').
            out << " | " << format_literal(std::string(1, c)) << " (";
            recurse(i + 1);
            out << ")?";
        }
    };

    out << "(";
    recurse(0);
    out << ")" << (dotall ? DOTALL : DOT) << "*";
    return out.str();
}
class SchemaConverter { class SchemaConverter {
private: private:
std::function<json(const std::string &)> _fetch_json; std::function<json(const std::string &)> _fetch_json;
@ -208,22 +234,6 @@ private:
std::vector<std::string> _errors; std::vector<std::string> _errors;
std::vector<std::string> _warnings; std::vector<std::string> _warnings;
// Registers `rule` under a sanitized version of `name` and returns the key
// actually used. Characters matching INVALID_RULE_CHARS_RE are replaced by "-".
// If the sanitized name already maps to a different rule, numeric suffixes
// 0, 1, 2, ... are tried until a key is found that is unused or already maps
// to the same rule.
std::string _add_rule(const std::string & name, const std::string & rule) {
    const std::string base = regex_replace(name, INVALID_RULE_CHARS_RE, "-");

    // Fast path: the plain name is free, or already bound to this very rule.
    const bool base_taken = _rules.find(base) != _rules.end();
    if (!base_taken || _rules[base] == rule) {
        _rules[base] = rule;
        return base;
    }

    // Otherwise probe suffixed names in increasing order.
    for (int suffix = 0; ; suffix++) {
        const std::string candidate = base + std::to_string(suffix);
        const bool candidate_taken = _rules.find(candidate) != _rules.end();
        if (!candidate_taken || _rules[candidate] == rule) {
            _rules[candidate] = rule;
            return candidate;
        }
    }
}
std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) { std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
std::vector<std::string> rules; std::vector<std::string> rules;
for (size_t i = 0; i < alt_schemas.size(); i++) { for (size_t i = 0; i < alt_schemas.size(); i++) {
@ -256,11 +266,11 @@ private:
auto get_dot = [&]() { auto get_dot = [&]() {
std::string rule; std::string rule;
if (_dotall) { if (_dotall) {
rule = "[\\U00000000-\\U0010FFFF]"; rule = DOTALL;
} else { } else {
rule = "[^\\x0A\\x0D]"; rule = DOT;
} }
return _add_rule("dot", rule); return add_rule("dot", rule);
}; };
// Joins the sequence, merging consecutive literals together. // Joins the sequence, merging consecutive literals together.
@ -377,7 +387,7 @@ private:
if (!sub_is_literal) { if (!sub_is_literal) {
std::string & sub_id = sub_rule_ids[sub]; std::string & sub_id = sub_rule_ids[sub];
if (sub_id.empty()) { if (sub_id.empty()) {
sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub); sub_id = add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
} }
sub = sub_id; sub = sub_id;
} }
@ -423,7 +433,7 @@ private:
} }
return join_seq(); return join_seq();
}; };
return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space"); return add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
} }
std::string _resolve_ref(const std::string & ref) { std::string _resolve_ref(const std::string & ref) {
@ -451,7 +461,7 @@ private:
const auto &prop_schema = kv.second; const auto &prop_schema = kv.second;
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name); std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
prop_kv_rule_names[prop_name] = _add_rule( prop_kv_rule_names[prop_name] = add_rule(
name + (name.empty() ? "" : "-") + prop_name + "-kv", name + (name.empty() ? "" : "-") + prop_name + "-kv",
format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
); );
@ -464,7 +474,7 @@ private:
if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) { if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) {
std::string sub_name = name + (name.empty() ? "" : "-") + "additional"; std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value"); std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value");
std::string kv_rule = _add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule); std::string kv_rule = add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
prop_kv_rule_names["*"] = kv_rule; prop_kv_rule_names["*"] = kv_rule;
optional_props.push_back("*"); optional_props.push_back("*");
} }
@ -491,7 +501,7 @@ private:
std::string k = ks[0]; std::string k = ks[0];
std::string kv_rule_name = prop_kv_rule_names[k]; std::string kv_rule_name = prop_kv_rule_names[k];
if (k == "*") { if (k == "*") {
res = _add_rule( res = add_rule(
name + (name.empty() ? "" : "-") + "additional-kvs", name + (name.empty() ? "" : "-") + "additional-kvs",
kv_rule_name + " ( \",\" space " + kv_rule_name + " )*" kv_rule_name + " ( \",\" space " + kv_rule_name + " )*"
); );
@ -501,7 +511,7 @@ private:
res = kv_rule_name; res = kv_rule_name;
} }
if (ks.size() > 1) { if (ks.size() > 1) {
res += " " + _add_rule( res += " " + add_rule(
name + (name.empty() ? "" : "-") + k + "-rest", name + (name.empty() ? "" : "-") + k + "-rest",
get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true) get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
); );
@ -527,7 +537,7 @@ private:
} }
std::string _add_primitive(const std::string & name, const BuiltinRule & rule) { std::string _add_primitive(const std::string & name, const BuiltinRule & rule) {
auto n = _add_rule(name, rule.content); auto n = add_rule(name, rule.content);
for (const auto & dep : rule.deps) { for (const auto & dep : rule.deps) {
BuiltinRule dep_rule; BuiltinRule dep_rule;
auto it = PRIMITIVE_RULES.find(dep); auto it = PRIMITIVE_RULES.find(dep);
@ -615,6 +625,22 @@ public:
visit_refs(schema); visit_refs(schema);
} }
// Stores `rule` in the rule table and returns the key it was stored under.
// The key is `name` with every match of INVALID_RULE_CHARS_RE replaced by "-".
// When that key is already bound to a different rule, the first key of the
// form <key>0, <key>1, ... that is either unbound or bound to an identical
// rule is used instead.
std::string add_rule(const std::string & name, const std::string & rule) {
    const std::string sanitized = regex_replace(name, INVALID_RULE_CHARS_RE, "-");

    // A key is usable when nothing is stored there yet, or the same rule is.
    auto usable = [&](const std::string & key) {
        return _rules.find(key) == _rules.end() || _rules[key] == rule;
    };

    if (usable(sanitized)) {
        _rules[sanitized] = rule;
        return sanitized;
    }

    int n = 0;
    std::string key = sanitized + std::to_string(n);
    while (!usable(key)) {
        n++;
        key = sanitized + std::to_string(n);
    }
    _rules[key] = rule;
    return key;
}
std::string _generate_constant_rule(const json & value) { std::string _generate_constant_rule(const json & value) {
return format_literal(value.dump()); return format_literal(value.dump());
} }
@ -625,24 +651,24 @@ public:
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name; std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
if (schema.contains("$ref")) { if (schema.contains("$ref")) {
return _add_rule(rule_name, _resolve_ref(schema["$ref"])); return add_rule(rule_name, _resolve_ref(schema["$ref"]));
} else if (schema.contains("oneOf") || schema.contains("anyOf")) { } else if (schema.contains("oneOf") || schema.contains("anyOf")) {
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>(); std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas)); return add_rule(rule_name, _generate_union_rule(name, alt_schemas));
} else if (schema_type.is_array()) { } else if (schema_type.is_array()) {
std::vector<json> schema_types; std::vector<json> schema_types;
for (const auto & t : schema_type) { for (const auto & t : schema_type) {
schema_types.push_back({{"type", t}}); schema_types.push_back({{"type", t}});
} }
return _add_rule(rule_name, _generate_union_rule(name, schema_types)); return add_rule(rule_name, _generate_union_rule(name, schema_types));
} else if (schema.contains("const")) { } else if (schema.contains("const")) {
return _add_rule(rule_name, _generate_constant_rule(schema["const"])); return add_rule(rule_name, _generate_constant_rule(schema["const"]));
} else if (schema.contains("enum")) { } else if (schema.contains("enum")) {
std::vector<std::string> enum_values; std::vector<std::string> enum_values;
for (const auto & v : schema["enum"]) { for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v)); enum_values.push_back(_generate_constant_rule(v));
} }
return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | ")); return add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | "));
} else if ((schema_type.is_null() || schema_type == "object") } else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") || && (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@ -660,7 +686,7 @@ public:
properties.emplace_back(prop.key(), prop.value()); properties.emplace_back(prop.key(), prop.value());
} }
} }
return _add_rule(rule_name, return add_rule(rule_name,
_build_object_rule( _build_object_rule(
properties, required, name, properties, required, name,
schema.contains("additionalProperties") ? schema["additionalProperties"] : json())); schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
@ -691,7 +717,7 @@ public:
add_component(t, true); add_component(t, true);
} }
} }
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); return add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
if (items.is_array()) { if (items.is_array()) {
@ -703,14 +729,14 @@ public:
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i)); rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
} }
rule += " \"]\" space"; rule += " \"]\" space";
return _add_rule(rule_name, rule); return add_rule(rule_name, rule);
} else { } else {
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item"); std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0; int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max(); int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space"); return add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
} }
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
return _visit_pattern(schema["pattern"], rule_name); return _visit_pattern(schema["pattern"], rule_name);
@ -718,14 +744,14 @@ public:
return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid")); return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
} else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) { } else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
auto prim_name = schema_format + "-string"; auto prim_name = schema_format + "-string";
return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name))); return add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
} else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) { } else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0; int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max(); int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); return add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
} else if (schema.empty() || schema_type == "object") { } else if (schema.empty() || schema_type == "object") {
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); return add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
} else { } else {
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) { if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
_errors.push_back("Unrecognized schema: " + schema.dump()); _errors.push_back("Unrecognized schema: " + schema.dump());
@ -762,3 +788,39 @@ std::string json_schema_to_grammar(const json & schema) {
converter.check_errors(); converter.check_errors();
return converter.format_grammar(); return converter.format_grammar();
} }
std::string tool_call_grammar(const json & tools) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
std::vector<std::string> tool_rules;
for (const auto & tool : tools) {
const auto & function = tool["function"];
std::string name = function["name"];
std::string description = function["description"];
auto parameters_copy = function["parameters"];
converter.resolve_refs(parameters_copy, name);
tool_rules.push_back(converter.visit(json {
{"type", "object"},
{"description", description},
{"properties", json {
{"name", json {{"const", name}}},
{"arguments", parameters_copy},
}},
{"required", json::array({"name", "arguments"})},
}, name + "-tool-call"));
}
converter.add_rule(
"root",
not_literal("<tool_call>") + " | "
+ converter.add_rule(
"tool_call",
"\"<tool_call>\" "
+ join(tool_rules.begin(), tool_rules.end(), " | ")
+ " \"</tool_call>\""));
converter.check_errors();
return converter.format_grammar();
}

View file

@ -1,4 +1,5 @@
#pragma once #pragma once
#include "json.hpp" #include "json.hpp"
std::string tool_call_grammar(const nlohmann::ordered_json & tools);
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);

View file

@ -3031,7 +3031,7 @@ int main(int argc, char ** argv) {
chat.push_back({{"role", "assistant"}, {"content", "Hi there"}}); chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
chat.push_back({{"role", "user"}, {"content", "How are you?"}}); chat.push_back({{"role", "user"}, {"content", "How are you?"}});
const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat); const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat, "");
LOG_INFO("chat template", { LOG_INFO("chat template", {
{"chat_example", chat_example}, {"chat_example", chat_example},

View file

@ -4,6 +4,7 @@
#include "common.h" #include "common.h"
#include "json.hpp" #include "json.hpp"
#include "json-schema-to-grammar.h"
#include <string> #include <string>
#include <vector> #include <vector>
@ -122,7 +123,7 @@ inline bool verify_custom_template(const std::string & tmpl) {
} }
// Format given chat. If tmpl is empty, we take the template from model metadata // Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) { inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages, const std::string & tools_tag) {
size_t alloc_size = 0; size_t alloc_size = 0;
// vector holding all allocated string to be passed to llama_chat_apply_template // vector holding all allocated string to be passed to llama_chat_apply_template
std::vector<std::string> str(messages.size() * 2); std::vector<std::string> str(messages.size() * 2);
@ -137,6 +138,20 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
chat[i].content = str[i*2 + 1].c_str(); chat[i].content = str[i*2 + 1].c_str();
} }
if (!tools_tag.empty()) {
alloc_size += tools_tag.size();
if (chat.empty()) {
str.resize(2);
str[0] = "user";
str[1] = tools_tag;
chat.push_back({str[0].c_str(), str[1].c_str()});
} else {
auto & content = str[str.size() - 1];
content += tools_tag;
chat[chat.size() - 1].content = content.c_str();
}
}
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
std::vector<char> buf(alloc_size * 2); std::vector<char> buf(alloc_size * 2);
@ -372,8 +387,15 @@ static json oaicompat_completion_params_parse(
llama_params["temperature"] = json_value(body, "temperature", 0.0); llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_p"] = json_value(body, "top_p", 1.0); llama_params["top_p"] = json_value(body, "top_p", 1.0);
std::string tools_tag;
if (body.contains("tools") && body["tools"].is_array()) {
const auto & tools = body["tools"];
llama_params["grammar"] = tool_call_grammar(tools);
tools_tag = (std::stringstream() << "\n\n<tools>" << tools.dump(2) << "</tools>").str();
}
// Apply chat template to the list of messages // Apply chat template to the list of messages
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]); llama_params["prompt"] = format_chat(model, chat_template, body["messages"], tools_tag);
// Handle "stop" field // Handle "stop" field
if (body.contains("stop") && body["stop"].is_string()) { if (body.contains("stop") && body["stop"].is_string()) {
@ -408,7 +430,7 @@ static json oaicompat_completion_params_parse(
} }
// Params supported by OAI but unsupported by llama.cpp // Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" }; static const std::vector<std::string> unsupported_params { "tool_choice" };
for (auto & param : unsupported_params) { for (auto & param : unsupported_params) {
if (body.contains(param)) { if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param); throw std::runtime_error("Unsupported param: " + param);