openai: make content optional for tool call grammar gen

This commit is contained in:
ochafik 2024-05-22 03:51:20 +01:00
parent 6dadcd2519
commit c8458fa5f7
2 changed files with 15 additions and 10 deletions

View file

@ -829,7 +829,7 @@ std::string json_schema_to_grammar(const json & schema) {
return converter.format_grammar(); return converter.format_grammar();
} }
std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) { std::string tool_call_grammar(const json & tools, bool allow_parallel_calls, bool allow_content) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
@ -837,7 +837,7 @@ std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) {
for (const auto & tool : tools) { for (const auto & tool : tools) {
const auto & function = tool["function"]; const auto & function = tool["function"];
std::string name = function["name"]; std::string name = function["name"];
std::string description = function["description"]; std::string description = function.contains("description") ? function["description"] : "";
auto parameters_copy = function["parameters"]; auto parameters_copy = function["parameters"];
converter.resolve_refs(parameters_copy, name); converter.resolve_refs(parameters_copy, name);
@ -854,13 +854,18 @@ std::string tool_call_grammar(const json & tools, bool allow_parallel_calls) {
converter.add_rule( converter.add_rule(
"root", "root",
converter.not_literal("<tool_call>") + " " + (allow_content ? converter.not_literal("<tool_call>") + " | " : "") +
converter.add_rule( build_repetition(
"tool_call", converter.add_rule(
"\"<tool_call>\" (" "tool_call",
+ join(tool_rules.begin(), tool_rules.end(), " | ") "\"<tool_call>\" ("
+ ") \"</tool_call>\"" + join(tool_rules.begin(), tool_rules.end(), " | ")
) + (allow_parallel_calls ? "*" : "?")); + ") \"</tool_call>\""
),
allow_content ? 0 : 1,
allow_parallel_calls ? std::numeric_limits<int>::max() : 1,
" \"\\n\" "
));
converter.check_errors(); converter.check_errors();
return converter.format_grammar(); return converter.format_grammar();

View file

@ -5,5 +5,5 @@
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "json.hpp" #include "json.hpp"
std::string tool_call_grammar(const nlohmann::ordered_json & tools, bool allow_parallel_calls = false); std::string tool_call_grammar(const nlohmann::ordered_json & tools, bool allow_parallel_calls = false, bool allow_content = true);
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);