Add grammar options + rename builder to common_grammar_builder

Olivier Chafik 2025-01-22 18:36:04 +00:00
parent cdfa8b9d4f
commit a46de6a03a
3 changed files with 33 additions and 23 deletions

@@ -343,7 +343,7 @@ static std::string format_literal(const std::string & literal) {
class SchemaConverter {
private:
- friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
+ friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
@@ -764,10 +764,11 @@ private:
public:
SchemaConverter(
const std::function<json(const std::string &)> & fetch_json,
- bool dotall)
+ bool dotall,
+ bool compact_spaces)
: _fetch_json(fetch_json), _dotall(dotall)
{
- _rules["space"] = SPACE_RULE;
+ _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
}
void resolve_refs(json & schema, const std::string & url) {
@@ -991,16 +992,16 @@ public:
};
std::string json_schema_to_grammar(const json & schema) {
- return build_grammar([&](const llama_grammar_builder & callbacks) {
+ return build_grammar([&](const common_grammar_builder & callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);
callbacks.add_schema("", copy);
});
}
- std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
- SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
- llama_grammar_builder builder {
+ std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
+ SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
+ common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);
},
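Taken together, the two hunks above thread the new compact_spaces flag through SchemaConverter: when it is set, the generated "space" rule collapses to "\" \"?" (at most one literal space between JSON tokens) instead of the more permissive SPACE_RULE. A minimal sketch of how build_grammar() now wires the options into the converter, paraphrasing the diff rather than copying the file (json is the file's nlohmann::ordered_json alias, and the empty fetch_json lambda is the same placeholder the file uses):

    // Sketch based on the build_grammar() change above, not an exact copy of the file.
    common_grammar_options options;
    options.dotall         = false;  // keep '.' in pattern schemas non-dotall, matching the previously hard-coded false
    options.compact_spaces = true;   // "space" rule becomes "\" \"?" instead of SPACE_RULE

    SchemaConverter converter(
        /* fetch_json = */ [](const std::string &) { return json(); },  // no external $ref fetching
        options.dotall,
        options.compact_spaces);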

@@ -7,10 +7,15 @@
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
- struct llama_grammar_builder {
+ struct common_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
std::function<void(nlohmann::ordered_json &)> resolve_refs;
};
- std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
+ struct common_grammar_options {
+ bool dotall = false;
+ bool compact_spaces = false;
+ };
+ std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
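With the declarations above, callers can pass options to build_grammar() or omit them entirely; the defaulted options argument keeps the previous behaviour (dotall = false, compact_spaces = false). A minimal usage sketch, assuming this header keeps its usual include name (the include and function name here are illustrative, not part of the commit):

    #include <nlohmann/json.hpp>
    #include "json-schema-to-grammar.h"  // assumed name of the header shown above

    using json = nlohmann::ordered_json;

    static std::string schema_to_tool_grammar(const json & schema) {
        common_grammar_options opts;
        opts.compact_spaces = true;  // single optional space between tokens in the emitted grammar

        return build_grammar([&](const common_grammar_builder & builder) {
            auto copy = schema;
            builder.resolve_refs(copy);        // inline $ref before conversion
            builder.add_schema("root", copy);  // convert the schema into the "root" rule
        }, opts);
    }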

@@ -412,6 +412,10 @@ common_tool_call_handler common_tool_call_handler_init(
const nlohmann::ordered_json & tools,
const nlohmann::ordered_json & json_schema)
{
+ common_grammar_options grammar_options {
+ /* .dotall = */ false,
+ /* .compact_spaces = */ true,
+ };
common_tool_call_handler handler;
auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
@@ -489,9 +493,9 @@ common_tool_call_handler common_tool_call_handler_init(
})}
}
: tool_call;
- handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+ handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_schema("root", schema);
- });
+ }, grammar_options);
// TODO: add schema to system prompt.
auto tweaked_messages = add_system(
messages,
@@ -501,7 +505,7 @@ common_tool_call_handler common_tool_call_handler_init(
}
case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
auto actual_tools = normalize_tools(tools);
- handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+ handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
@@ -533,7 +537,7 @@ common_tool_call_handler common_tool_call_handler_init(
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
- });
+ }, grammar_options);
if (allow_content) {
handler.grammar_triggers.push_back("[TOOL_CALLS]");
}
@@ -542,7 +546,7 @@ common_tool_call_handler common_tool_call_handler_init(
}
case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
auto actual_tools = normalize_tools(tools);
- handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+ handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
@@ -567,7 +571,7 @@ common_tool_call_handler common_tool_call_handler_init(
schema["maxItems"] = 1;
}
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
- });
+ }, grammar_options);
if (allow_content) {
handler.grammar_triggers.push_back(" functools[");
}
@@ -596,7 +600,7 @@ common_tool_call_handler common_tool_call_handler_init(
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
- handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+ handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : actual_tools) {
@@ -638,7 +642,7 @@ common_tool_call_handler common_tool_call_handler_init(
}
builder.add_rule("root", string_join(tool_rules, " | "));
- });
+ }, grammar_options);
handler.additional_stops.push_back("<|eom_id|>");
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
{"builtin_tools", builtin_tools},
@@ -649,7 +653,7 @@ common_tool_call_handler common_tool_call_handler_init(
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
auto actual_tools = normalize_tools(tools);
- handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+ handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
for (const auto & tool : actual_tools) {
@@ -671,7 +675,7 @@ common_tool_call_handler common_tool_call_handler_init(
} else {
builder.add_rule("root", first_rule);
}
- });
+ }, grammar_options);
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls;
break;
@@ -681,7 +685,7 @@ common_tool_call_handler common_tool_call_handler_init(
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// TODO: handle tool {type: code_interpreter} as python
auto actual_tools = normalize_tools(tools);
- handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+ handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
@@ -701,7 +705,7 @@ common_tool_call_handler common_tool_call_handler_init(
if (allow_content) {
handler.grammar_triggers.push_back("<function=");
}
- });
+ }, grammar_options);
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls;
break;
@@ -710,7 +714,7 @@ common_tool_call_handler common_tool_call_handler_init(
// NousResearchHermesPro_2
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
auto actual_tools = normalize_tools(tools);
- handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+ handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : actual_tools) {
const auto & function = tool["function"];
@@ -732,7 +736,7 @@ common_tool_call_handler common_tool_call_handler_init(
if (allow_content) {
handler.grammar_triggers.push_back("<tool_call>");
}
- });
+ }, grammar_options);
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
break;
}
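Every tool-call style in this file now routes through the same grammar_options instance declared at the top of common_tool_call_handler_init, so all generated tool-call grammars use the compact "space" rule, while json_schema_to_grammar() in the first file keeps the default options. The repeated edit follows one pattern, condensed here for reference (the per-style rules and schemas are elided; this is an illustration of the pattern above, not new API):

    // Condensed form of the pattern repeated across the cases above (illustrative only).
    handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
        // ... style-specific rules and schemas built via builder.add_rule / builder.add_schema ...
        builder.add_schema("root", schema);
    }, grammar_options);  // shared options: dotall = false, compact_spaces = true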