tool-call: fix grammar roots

This commit is contained in:
ochafik 2024-10-22 10:50:52 +01:00
parent e753f15229
commit 75764871e6

View file

@ -293,7 +293,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
handler.grammar_trigger_words.push_back("{\n \""); handler.grammar_trigger_words.push_back("{\n \"");
} }
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); builder.add_rule("", join(tool_rules.begin(), tool_rules.end(), " | "));
}); });
handler.additional_stop_words.push_back("<|eom_id|>"); handler.additional_stop_words.push_back("<|eom_id|>");
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, { handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, {
@ -323,9 +323,9 @@ llama_tool_call_handler llama_tool_call_handler_init(
auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space"; auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space";
if (parallel_tool_calls) { if (parallel_tool_calls) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space"; auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space";
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); builder.add_rule("", first_rule + " (" + subsequent_rule + ")*");
} else { } else {
builder.add_rule("root", first_rule); builder.add_rule("", first_rule);
} }
}); });
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
@ -383,7 +383,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
} }
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space"; auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); builder.add_rule("", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) { if (allow_content) {
handler.grammar_trigger_words.push_back("<tool_call>"); handler.grammar_trigger_words.push_back("<tool_call>");
} }