diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
index 4d426b6bd..1f47e313e 100644
--- a/common/json-schema-to-grammar.cpp
+++ b/common/json-schema-to-grammar.cpp
@@ -343,7 +343,7 @@ static std::string format_literal(const std::string & literal) {
 
 class SchemaConverter {
 private:
-    friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
+    friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
     std::function<json(const std::string &)> _fetch_json;
     bool _dotall;
     std::map<std::string, std::string> _rules;
@@ -764,10 +764,11 @@ private:
 
 public:
     SchemaConverter(
         const std::function<json(const std::string &)> & fetch_json,
-        bool dotall)
+        bool dotall,
+        bool compact_spaces)
       : _fetch_json(fetch_json), _dotall(dotall) {
-        _rules["space"] = SPACE_RULE;
+        _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
     }
 
     void resolve_refs(json & schema, const std::string & url) {
@@ -991,16 +992,16 @@ public:
 };
 
 std::string json_schema_to_grammar(const json & schema) {
-    return build_grammar([&](const llama_grammar_builder & callbacks) {
+    return build_grammar([&](const common_grammar_builder & callbacks) {
         auto copy = schema;
         callbacks.resolve_refs(copy);
         callbacks.add_schema("", copy);
     });
 }
 
-std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
-    SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
-    llama_grammar_builder builder {
+std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
+    SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
+    common_grammar_builder builder {
         /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
             return converter._add_rule(name, rule);
         },
diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h
index 4f43ab3a5..ba4112cb9 100644
--- a/common/json-schema-to-grammar.h
+++ b/common/json-schema-to-grammar.h
@@ -7,10 +7,15 @@
 
 std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
 
-struct llama_grammar_builder {
+struct common_grammar_builder {
     std::function<std::string(const std::string &, const std::string &)> add_rule;
     std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
     std::function<void(nlohmann::ordered_json &)> resolve_refs;
 };
 
-std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
+struct common_grammar_options {
+    bool dotall = false;
+    bool compact_spaces = false;
+};
+
+std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index a2704b5b8..01fce7e10 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -412,6 +412,10 @@ common_tool_call_handler common_tool_call_handler_init(
     const nlohmann::ordered_json & tools,
     const nlohmann::ordered_json & json_schema)
 {
+    common_grammar_options grammar_options {
+        /* .dotall = */ false,
+        /* .compact_spaces = */ true,
+    };
     common_tool_call_handler handler;
     auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
@@ -489,9 +493,9 @@ common_tool_call_handler common_tool_call_handler_init(
                 })}
             } : tool_call;
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 builder.add_schema("root", schema);
-            });
+            }, grammar_options);
 
             // TODO: add schema to system prompt.
             auto tweaked_messages = add_system(
                 messages,
@@ -501,7 +505,7 @@ common_tool_call_handler common_tool_call_handler_init(
         }
         case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 auto schemas = json::array();
                 for (const auto & tool : actual_tools) {
                     const auto & function = tool["function"];
@@ -533,7 +537,7 @@ common_tool_call_handler common_tool_call_handler_init(
                     schema["maxItems"] = 1;
                 }
                 builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
-            });
+            }, grammar_options);
             if (allow_content) {
                 handler.grammar_triggers.push_back("[TOOL_CALLS]");
             }
@@ -542,7 +546,7 @@ common_tool_call_handler common_tool_call_handler_init(
         }
         case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 auto schemas = json::array();
                 for (const auto & tool : actual_tools) {
                     const auto & function = tool["function"];
@@ -567,7 +571,7 @@ common_tool_call_handler common_tool_call_handler_init(
                     schema["maxItems"] = 1;
                 }
                 builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
-            });
+            }, grammar_options);
             if (allow_content) {
                 handler.grammar_triggers.push_back(" functools[");
             }
@@ -596,7 +600,7 @@ common_tool_call_handler common_tool_call_handler_init(
             // TODO: make this conditional on a very small model (e.g. 1B / 3B).
             auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 std::vector<std::string> tool_rules;
                 for (const auto & tool : actual_tools) {
@@ -638,7 +642,7 @@ common_tool_call_handler common_tool_call_handler_init(
                 }
                 builder.add_rule("root", string_join(tool_rules, " | "));
-            });
+            }, grammar_options);
             handler.additional_stops.push_back("<|eom_id|>");
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
                 {"builtin_tools", builtin_tools},
@@ -649,7 +653,7 @@ common_tool_call_handler common_tool_call_handler_init(
             // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
             // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 std::vector<std::string> first_tool_rules;
                 std::vector<std::string> subsequent_tool_rules;
                 for (const auto & tool : actual_tools) {
@@ -671,7 +675,7 @@ common_tool_call_handler common_tool_call_handler_init(
                 } else {
                     builder.add_rule("root", first_rule);
                 }
-            });
+            }, grammar_options);
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             // handler.parser = parse_functionary_3_2_tool_calls;
             break;
         }
@@ -681,7 +685,7 @@ common_tool_call_handler common_tool_call_handler_init(
             // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
             // TODO: handle tool {type: code_interpreter} as python
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 std::vector<std::string> tool_rules;
                 for (const auto & tool : actual_tools) {
                     const auto & function = tool["function"];
@@ -701,7 +705,7 @@ common_tool_call_handler common_tool_call_handler_init(
                 if (allow_content) {
                     handler.grammar_triggers.push_back("{"name": "foo", "arguments": {"a": 1}})*
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 std::vector<std::string> tool_rules;
                 for (const auto & tool : actual_tools) {
                     const auto & function = tool["function"];
@@ -732,7 +736,7 @@ common_tool_call_handler common_tool_call_handler_init(
                 if (allow_content) {
                     handler.grammar_triggers.push_back("<tool_call>");
                 }
-            });
+            }, grammar_options);
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }