Add grammar options + rename builder to common_grammar_builder

This commit is contained in:
Olivier Chafik 2025-01-22 18:36:04 +00:00
parent cdfa8b9d4f
commit a46de6a03a
3 changed files with 33 additions and 23 deletions

View file

@ -343,7 +343,7 @@ static std::string format_literal(const std::string & literal) {
class SchemaConverter { class SchemaConverter {
private: private:
friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb); friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
std::function<json(const std::string &)> _fetch_json; std::function<json(const std::string &)> _fetch_json;
bool _dotall; bool _dotall;
std::map<std::string, std::string> _rules; std::map<std::string, std::string> _rules;
@ -764,10 +764,11 @@ private:
public: public:
SchemaConverter( SchemaConverter(
const std::function<json(const std::string &)> & fetch_json, const std::function<json(const std::string &)> & fetch_json,
bool dotall) bool dotall,
bool compact_spaces)
: _fetch_json(fetch_json), _dotall(dotall) : _fetch_json(fetch_json), _dotall(dotall)
{ {
_rules["space"] = SPACE_RULE; _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
} }
void resolve_refs(json & schema, const std::string & url) { void resolve_refs(json & schema, const std::string & url) {
@ -991,16 +992,16 @@ public:
}; };
std::string json_schema_to_grammar(const json & schema) { std::string json_schema_to_grammar(const json & schema) {
return build_grammar([&](const llama_grammar_builder & callbacks) { return build_grammar([&](const common_grammar_builder & callbacks) {
auto copy = schema; auto copy = schema;
callbacks.resolve_refs(copy); callbacks.resolve_refs(copy);
callbacks.add_schema("", copy); callbacks.add_schema("", copy);
}); });
} }
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) { std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false); SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
llama_grammar_builder builder { common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) { /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule); return converter._add_rule(name, rule);
}, },

View file

@ -7,10 +7,15 @@
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
struct llama_grammar_builder { struct common_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule; std::function<std::string(const std::string &, const std::string &)> add_rule;
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema; std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
std::function<void(nlohmann::ordered_json &)> resolve_refs; std::function<void(nlohmann::ordered_json &)> resolve_refs;
}; };
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb); struct common_grammar_options {
bool dotall = false;
bool compact_spaces = false;
};
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});

View file

@ -412,6 +412,10 @@ common_tool_call_handler common_tool_call_handler_init(
const nlohmann::ordered_json & tools, const nlohmann::ordered_json & tools,
const nlohmann::ordered_json & json_schema) const nlohmann::ordered_json & json_schema)
{ {
common_grammar_options grammar_options {
/* .dotall = */ false,
/* .compact_spaces = */ true,
};
common_tool_call_handler handler; common_tool_call_handler handler;
auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>(); auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
@ -489,9 +493,9 @@ common_tool_call_handler common_tool_call_handler_init(
})} })}
} }
: tool_call; : tool_call;
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_schema("root", schema); builder.add_schema("root", schema);
}); }, grammar_options);
// TODO: add schema to system prompt. // TODO: add schema to system prompt.
auto tweaked_messages = add_system( auto tweaked_messages = add_system(
messages, messages,
@ -501,7 +505,7 @@ common_tool_call_handler common_tool_call_handler_init(
} }
case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: { case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
for (const auto & tool : actual_tools) { for (const auto & tool : actual_tools) {
const auto & function = tool["function"]; const auto & function = tool["function"];
@ -533,7 +537,7 @@ common_tool_call_handler common_tool_call_handler_init(
schema["maxItems"] = 1; schema["maxItems"] = 1;
} }
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
}); }, grammar_options);
if (allow_content) { if (allow_content) {
handler.grammar_triggers.push_back("[TOOL_CALLS]"); handler.grammar_triggers.push_back("[TOOL_CALLS]");
} }
@ -542,7 +546,7 @@ common_tool_call_handler common_tool_call_handler_init(
} }
case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: { case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
for (const auto & tool : actual_tools) { for (const auto & tool : actual_tools) {
const auto & function = tool["function"]; const auto & function = tool["function"];
@ -567,7 +571,7 @@ common_tool_call_handler common_tool_call_handler_init(
schema["maxItems"] = 1; schema["maxItems"] = 1;
} }
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
}); }, grammar_options);
if (allow_content) { if (allow_content) {
handler.grammar_triggers.push_back(" functools["); handler.grammar_triggers.push_back(" functools[");
} }
@ -596,7 +600,7 @@ common_tool_call_handler common_tool_call_handler_init(
// TODO: make this conditional on a very small model (e.g. 1B / 3B). // TODO: make this conditional on a very small model (e.g. 1B / 3B).
auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2; auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
for (const auto & tool : actual_tools) { for (const auto & tool : actual_tools) {
@ -638,7 +642,7 @@ common_tool_call_handler common_tool_call_handler_init(
} }
builder.add_rule("root", string_join(tool_rules, " | ")); builder.add_rule("root", string_join(tool_rules, " | "));
}); }, grammar_options);
handler.additional_stops.push_back("<|eom_id|>"); handler.additional_stops.push_back("<|eom_id|>");
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, { handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
{"builtin_tools", builtin_tools}, {"builtin_tools", builtin_tools},
@ -649,7 +653,7 @@ common_tool_call_handler common_tool_call_handler_init(
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules; std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules; std::vector<std::string> subsequent_tool_rules;
for (const auto & tool : actual_tools) { for (const auto & tool : actual_tools) {
@ -671,7 +675,7 @@ common_tool_call_handler common_tool_call_handler_init(
} else { } else {
builder.add_rule("root", first_rule); builder.add_rule("root", first_rule);
} }
}); }, grammar_options);
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls; // handler.parser = parse_functionary_3_2_tool_calls;
break; break;
@ -681,7 +685,7 @@ common_tool_call_handler common_tool_call_handler_init(
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// TODO: handle tool {type: code_interpreter} as python // TODO: handle tool {type: code_interpreter} as python
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
for (const auto & tool : actual_tools) { for (const auto & tool : actual_tools) {
const auto & function = tool["function"]; const auto & function = tool["function"];
@ -701,7 +705,7 @@ common_tool_call_handler common_tool_call_handler_init(
if (allow_content) { if (allow_content) {
handler.grammar_triggers.push_back("<function="); handler.grammar_triggers.push_back("<function=");
} }
}); }, grammar_options);
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
// handler.parser = parse_functionary_3_2_tool_calls; // handler.parser = parse_functionary_3_2_tool_calls;
break; break;
@ -710,7 +714,7 @@ common_tool_call_handler common_tool_call_handler_init(
// NousResearchHermesPro_2 // NousResearchHermesPro_2
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)* // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
auto actual_tools = normalize_tools(tools); auto actual_tools = normalize_tools(tools);
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
for (const auto & tool : actual_tools) { for (const auto & tool : actual_tools) {
const auto & function = tool["function"]; const auto & function = tool["function"];
@ -732,7 +736,7 @@ common_tool_call_handler common_tool_call_handler_init(
if (allow_content) { if (allow_content) {
handler.grammar_triggers.push_back("<tool_call>"); handler.grammar_triggers.push_back("<tool_call>");
} }
}); }, grammar_options);
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true); handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
break; break;
} }