Add grammar options + rename builder to common_grammar_builder
This commit is contained in:
parent
cdfa8b9d4f
commit
a46de6a03a
3 changed files with 33 additions and 23 deletions
|
@ -343,7 +343,7 @@ static std::string format_literal(const std::string & literal) {
|
|||
|
||||
class SchemaConverter {
|
||||
private:
|
||||
friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
|
||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||
std::function<json(const std::string &)> _fetch_json;
|
||||
bool _dotall;
|
||||
std::map<std::string, std::string> _rules;
|
||||
|
@ -764,10 +764,11 @@ private:
|
|||
public:
|
||||
SchemaConverter(
|
||||
const std::function<json(const std::string &)> & fetch_json,
|
||||
bool dotall)
|
||||
bool dotall,
|
||||
bool compact_spaces)
|
||||
: _fetch_json(fetch_json), _dotall(dotall)
|
||||
{
|
||||
_rules["space"] = SPACE_RULE;
|
||||
_rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
|
||||
}
|
||||
|
||||
void resolve_refs(json & schema, const std::string & url) {
|
||||
|
@ -991,16 +992,16 @@ public:
|
|||
};
|
||||
|
||||
std::string json_schema_to_grammar(const json & schema) {
|
||||
return build_grammar([&](const llama_grammar_builder & callbacks) {
|
||||
return build_grammar([&](const common_grammar_builder & callbacks) {
|
||||
auto copy = schema;
|
||||
callbacks.resolve_refs(copy);
|
||||
callbacks.add_schema("", copy);
|
||||
});
|
||||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
|
||||
llama_grammar_builder builder {
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
|
||||
common_grammar_builder builder {
|
||||
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
|
||||
return converter._add_rule(name, rule);
|
||||
},
|
||||
|
|
|
@ -7,10 +7,15 @@
|
|||
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
|
||||
|
||||
struct llama_grammar_builder {
|
||||
struct common_grammar_builder {
|
||||
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
||||
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
|
||||
std::function<void(nlohmann::ordered_json &)> resolve_refs;
|
||||
};
|
||||
|
||||
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
|
||||
struct common_grammar_options {
|
||||
bool dotall = false;
|
||||
bool compact_spaces = false;
|
||||
};
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||
|
|
|
@ -412,6 +412,10 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
const nlohmann::ordered_json & tools,
|
||||
const nlohmann::ordered_json & json_schema)
|
||||
{
|
||||
common_grammar_options grammar_options {
|
||||
/* .dotall = */ false,
|
||||
/* .compact_spaces = */ true,
|
||||
};
|
||||
common_tool_call_handler handler;
|
||||
auto parallel = parallel_tool_calls.is_null() ? tmpl.supports_parallel_tool_calls() : parallel_tool_calls.get<bool>();
|
||||
|
||||
|
@ -489,9 +493,9 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
})}
|
||||
}
|
||||
: tool_call;
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
builder.add_schema("root", schema);
|
||||
});
|
||||
}, grammar_options);
|
||||
// TODO: add schema to system prompt.
|
||||
auto tweaked_messages = add_system(
|
||||
messages,
|
||||
|
@ -501,7 +505,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
}
|
||||
case COMMON_TOOL_CALL_STYLE_MISTRAL_NEMO: {
|
||||
auto actual_tools = normalize_tools(tools);
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
for (const auto & tool : actual_tools) {
|
||||
const auto & function = tool["function"];
|
||||
|
@ -533,7 +537,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||
});
|
||||
}, grammar_options);
|
||||
if (allow_content) {
|
||||
handler.grammar_triggers.push_back("[TOOL_CALLS]");
|
||||
}
|
||||
|
@ -542,7 +546,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
}
|
||||
case COMMON_TOOL_CALL_STYLE_FIRE_FUNCTION_V2: {
|
||||
auto actual_tools = normalize_tools(tools);
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
for (const auto & tool : actual_tools) {
|
||||
const auto & function = tool["function"];
|
||||
|
@ -567,7 +571,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
||||
});
|
||||
}, grammar_options);
|
||||
if (allow_content) {
|
||||
handler.grammar_triggers.push_back(" functools[");
|
||||
}
|
||||
|
@ -596,7 +600,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
|
||||
auto eagerly_match_any_json = style == common_tool_call_style::COMMON_TOOL_CALL_STYLE_LLAMA_3_2;
|
||||
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
for (const auto & tool : actual_tools) {
|
||||
|
@ -638,7 +642,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
}
|
||||
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
});
|
||||
}, grammar_options);
|
||||
handler.additional_stops.push_back("<|eom_id|>");
|
||||
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
|
||||
{"builtin_tools", builtin_tools},
|
||||
|
@ -649,7 +653,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
auto actual_tools = normalize_tools(tools);
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> first_tool_rules;
|
||||
std::vector<std::string> subsequent_tool_rules;
|
||||
for (const auto & tool : actual_tools) {
|
||||
|
@ -671,7 +675,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
} else {
|
||||
builder.add_rule("root", first_rule);
|
||||
}
|
||||
});
|
||||
}, grammar_options);
|
||||
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
|
||||
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||
break;
|
||||
|
@ -681,7 +685,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
// TODO: handle tool {type: code_interpreter} as python
|
||||
auto actual_tools = normalize_tools(tools);
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
for (const auto & tool : actual_tools) {
|
||||
const auto & function = tool["function"];
|
||||
|
@ -701,7 +705,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
if (allow_content) {
|
||||
handler.grammar_triggers.push_back("<function=");
|
||||
}
|
||||
});
|
||||
}, grammar_options);
|
||||
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
|
||||
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||
break;
|
||||
|
@ -710,7 +714,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
// NousResearchHermesPro_2
|
||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||
auto actual_tools = normalize_tools(tools);
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
handler.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
for (const auto & tool : actual_tools) {
|
||||
const auto & function = tool["function"];
|
||||
|
@ -732,7 +736,7 @@ common_tool_call_handler common_tool_call_handler_init(
|
|||
if (allow_content) {
|
||||
handler.grammar_triggers.push_back("<tool_call>");
|
||||
}
|
||||
});
|
||||
}, grammar_options);
|
||||
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true);
|
||||
break;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue