openai: make content optional for tool call grammar gen

This commit is contained in:
ochafik 2024-05-22 03:51:20 +01:00
parent 6dadcd2519
commit c8458fa5f7
2 changed files with 15 additions and 10 deletions

View file

@ -829,7 +829,7 @@ std::string json_schema_to_grammar(const json & schema) {
return converter.format_grammar();
}
std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) {
std::string tool_call_grammar(const json & tools, bool allow_parallel_calls, bool allow_content) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
std::vector<std::string> tool_rules;
@ -837,7 +837,7 @@ std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) {
for (const auto & tool : tools) {
const auto & function = tool["function"];
std::string name = function["name"];
std::string description = function["description"];
std::string description = function.contains("description") ? function["description"] : "";
auto parameters_copy = function["parameters"];
converter.resolve_refs(parameters_copy, name);
@ -854,13 +854,18 @@ std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) {
converter.add_rule(
"root",
converter.not_literal("<tool_call>") + " " +
converter.add_rule(
"tool_call",
"\"<tool_call>\" ("
+ join(tool_rules.begin(), tool_rules.end(), " | ")
+ ") \"</tool_call>\""
) + (allow_parallel_calls ? "*" : "?"));
(allow_content ? converter.not_literal("<tool_call>") + " | " : "") +
build_repetition(
converter.add_rule(
"tool_call",
"\"<tool_call>\" ("
+ join(tool_rules.begin(), tool_rules.end(), " | ")
+ ") \"</tool_call>\""
),
allow_content ? 0 : 1,
allow_parallel_calls ? std::numeric_limits<int>::max() : 1,
" \"\\n\" "
));
converter.check_errors();
return converter.format_grammar();

View file

@ -5,5 +5,5 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
std::string tool_call_grammar(const nlohmann::ordered_json & tools, bool allow_parallel_calls = false);
std::string tool_call_grammar(const nlohmann::ordered_json & tools, bool allow_parallel_calls = false, bool allow_content = true);
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);