From c8458fa5f71c87a9f565d6adafafcedad4575e88 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Wed, 22 May 2024 03:51:20 +0100
Subject: [PATCH] openai: make content optional for tool call grammar gen

---
 common/json-schema-to-grammar.cpp | 23 ++++++++++++++---------
 common/json-schema-to-grammar.h   |  2 +-
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
index 8ed8c85ec..b4a838c54 100644
--- a/common/json-schema-to-grammar.cpp
+++ b/common/json-schema-to-grammar.cpp
@@ -829,7 +829,7 @@ std::string json_schema_to_grammar(const json & schema) {
     return converter.format_grammar();
 }
 
-std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) {
+std::string tool_call_grammar(const json & tools, bool allow_parallel_calls, bool allow_content) {
     SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
 
     std::vector<std::string> tool_rules;
@@ -837,7 +837,7 @@ std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) {
     for (const auto & tool : tools) {
         const auto & function = tool["function"];
         std::string name = function["name"];
-        std::string description = function["description"];
+        std::string description = function.contains("description") ? function["description"] : "";
         auto parameters_copy = function["parameters"];
         converter.resolve_refs(parameters_copy, name);
 
@@ -854,13 +854,18 @@ std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) {
 
     converter.add_rule(
         "root",
-        converter.not_literal("<tool_call>") + " " +
-        converter.add_rule(
-            "tool_call",
-            "\"<tool_call>\" ("
-            + join(tool_rules.begin(), tool_rules.end(), " | ")
-            + ") \"</tool_call>\""
-        ) + (allow_parallel_calls ? "*" : "?"));
+        (allow_content ? converter.not_literal("<tool_call>") + " | " : "") +
+            build_repetition(
+                converter.add_rule(
+                    "tool_call",
+                    "\"<tool_call>\" ("
+                    + join(tool_rules.begin(), tool_rules.end(), " | ")
+                    + ") \"</tool_call>\""
+                ),
+                allow_content ? 0 : 1,
+                allow_parallel_calls ? std::numeric_limits<int>::max() : 1,
+                " \"\\n\" "
+            ));
 
     converter.check_errors();
     return converter.format_grammar();
diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h
index 77e66cb2c..e0219cece 100644
--- a/common/json-schema-to-grammar.h
+++ b/common/json-schema-to-grammar.h
@@ -5,5 +5,5 @@
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 
-std::string tool_call_grammar(const nlohmann::ordered_json & tools, bool allow_parallel_calls = false);
+std::string tool_call_grammar(const nlohmann::ordered_json & tools, bool allow_parallel_calls = false, bool allow_content = true);
 std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);