Add grammar options + rename builder to common_grammar_builder
commit a46de6a03a (parent cdfa8b9d4f)
3 changed files with 33 additions and 23 deletions
@@ -343,7 +343,7 @@ static std::string format_literal(const std::string & literal) {
 
 class SchemaConverter {
 private:
-    friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
+    friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
     std::function<json(const std::string &)> _fetch_json;
     bool _dotall;
     std::map<std::string, std::string> _rules;
@@ -764,10 +764,11 @@ private:
 public:
     SchemaConverter(
         const std::function<json(const std::string &)> & fetch_json,
-        bool dotall)
+        bool dotall,
+        bool compact_spaces)
         : _fetch_json(fetch_json), _dotall(dotall)
     {
-        _rules["space"] = SPACE_RULE;
+        _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
     }
 
     void resolve_refs(json & schema, const std::string & url) {
@@ -991,16 +992,16 @@ public:
 };
 
 std::string json_schema_to_grammar(const json & schema) {
-    return build_grammar([&](const llama_grammar_builder & callbacks) {
+    return build_grammar([&](const common_grammar_builder & callbacks) {
         auto copy = schema;
         callbacks.resolve_refs(copy);
         callbacks.add_schema("", copy);
     });
 }
 
-std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
-    SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
-    llama_grammar_builder builder {
+std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
+    SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
+    common_grammar_builder builder {
         /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
             return converter._add_rule(name, rule);
         },
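The behavioural change in this file is how the "space" rule is seeded: with compact_spaces the converter emits a single optional blank (" "?) instead of the default SPACE_RULE. A minimal standalone sketch of that logic, for illustration only; SPACE_RULE's real definition lives elsewhere in this file and is stubbed below:

// Sketch, not part of the commit: reproduces the constructor's new seeding logic.
#include <iostream>
#include <string>

static const std::string SPACE_RULE = "<default whitespace rule, defined in the real file>";

static std::string space_rule(bool compact_spaces) {
    // compact grammars allow at most one optional space wherever "space" is referenced
    return compact_spaces ? "\" \"?" : SPACE_RULE;
}

int main() {
    std::cout << "compact: space ::= " << space_rule(true)  << "\n"; // prints: space ::= " "?
    std::cout << "default: space ::= " << space_rule(false) << "\n";
}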
@@ -7,10 +7,15 @@
 
 std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
 
-struct llama_grammar_builder {
+struct common_grammar_builder {
     std::function<std::string(const std::string &, const std::string &)> add_rule;
     std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
     std::function<void(nlohmann::ordered_json &)> resolve_refs;
 };
 
-std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
+struct common_grammar_options {
+    bool dotall = false;
+    bool compact_spaces = false;
+};
+
+std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
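Because options defaults to {}, existing callers that pass only a callback keep compiling and keep the old behaviour (dotall = false, compact_spaces = false). A hypothetical caller sketch of the new API follows; the include path and the "root" rule body are illustrative assumptions, not part of the commit:

#include <string>
#include "json-schema-to-grammar.h" // assumed location of the declarations above

std::string make_compact_grammar() {
    common_grammar_options opts;
    opts.dotall = false;        // regex '.' in converted patterns does not match newlines
    opts.compact_spaces = true; // seed the "space" rule with a single optional blank

    return build_grammar([&](const common_grammar_builder & builder) {
        // add_rule / add_schema return a rule name that can be spliced into other rules,
        // as the tool-call code below does with builder.add_schema("tool_calls", schema).
        builder.add_rule("root", "\"hello\" space \"world\"");
    }, opts);
}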
@@ -412,6 +412,10 @@ common_tool_call_handler common_tool_call_handler_init(
     const nlohmann::ordered_json & tools,
     const nlohmann::ordered_json & json_schema)
 {
+    common_grammar_options grammar_options {
+        /* .dotall = */ false,
+        /* .compact_spaces = */ true,
+    };
     common_tool_call_handler handler;
     auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
 
@@ -489,9 +493,9 @@ common_tool_call_handler common_tool_call_handler_init(
                     })}
                 }
                 : tool_call;
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 builder.add_schema("root", schema);
-            });
+            }, grammar_options);
             // TODO: add schema to system prompt.
             auto tweaked_messages = add_system(
                 messages,
@@ -501,7 +505,7 @@ common_tool_call_handler common_tool_call_handler_init(
         }
         case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 auto schemas = json::array();
                 for (const auto & tool : actual_tools) {
                     const auto & function = tool["function"];
@@ -533,7 +537,7 @@ common_tool_call_handler common_tool_call_handler_init(
                     schema["maxItems"] = 1;
                 }
                 builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
-            });
+            }, grammar_options);
             if (allow_content) {
                 handler.grammar_triggers.push_back("[TOOL_CALLS]");
             }
@@ -542,7 +546,7 @@ common_tool_call_handler common_tool_call_handler_init(
         }
         case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 auto schemas = json::array();
                 for (const auto & tool : actual_tools) {
                     const auto & function = tool["function"];
@@ -567,7 +571,7 @@ common_tool_call_handler common_tool_call_handler_init(
                     schema["maxItems"] = 1;
                 }
                 builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
-            });
+            }, grammar_options);
             if (allow_content) {
                 handler.grammar_triggers.push_back(" functools[");
             }
@@ -596,7 +600,7 @@ common_tool_call_handler common_tool_call_handler_init(
             // TODO: make this conditional on a very small model (e.g. 1B / 3B).
             auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
 
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 std::vector<std::string> tool_rules;
 
                 for (const auto & tool : actual_tools) {
@@ -638,7 +642,7 @@ common_tool_call_handler common_tool_call_handler_init(
                 }
 
                 builder.add_rule("root", string_join(tool_rules, " | "));
-            });
+            }, grammar_options);
             handler.additional_stops.push_back("<|eom_id|>");
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
                 {"builtin_tools", builtin_tools},
@@ -649,7 +653,7 @@ common_tool_call_handler common_tool_call_handler_init(
             // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
             // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 std::vector<std::string> first_tool_rules;
                 std::vector<std::string> subsequent_tool_rules;
                 for (const auto & tool : actual_tools) {
@@ -671,7 +675,7 @@ common_tool_call_handler common_tool_call_handler_init(
                 } else {
                     builder.add_rule("root", first_rule);
                 }
-            });
+            }, grammar_options);
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             // handler.parser = parse_functionary_3_2_tool_calls;
             break;
@@ -681,7 +685,7 @@ common_tool_call_handler common_tool_call_handler_init(
             // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
             // TODO: handle tool {type: code_interpreter} as python
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 std::vector<std::string> tool_rules;
                 for (const auto & tool : actual_tools) {
                     const auto & function = tool["function"];
@@ -701,7 +705,7 @@ common_tool_call_handler common_tool_call_handler_init(
                 if (allow_content) {
                     handler.grammar_triggers.push_back("<function=");
                 }
-            });
+            }, grammar_options);
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             // handler.parser = parse_functionary_3_2_tool_calls;
             break;
@@ -710,7 +714,7 @@ common_tool_call_handler common_tool_call_handler_init(
             // NousResearchHermesPro_2
             // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
             auto actual_tools = normalize_tools(tools);
-            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
                 std::vector<std::string> tool_rules;
                 for (const auto & tool : actual_tools) {
                     const auto & function = tool["function"];
@@ -732,7 +736,7 @@ common_tool_call_handler common_tool_call_handler_init(
                 if (allow_content) {
                     handler.grammar_triggers.push_back("<tool_call>");
                 }
-            });
+            }, grammar_options);
             handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
             break;
         }