tool-calls: fix grammar regression

This commit is contained in:
Olivier Chafik 2024-10-22 11:49:50 +01:00
parent b53362a148
commit 7f2429e6b0

View file

@@ -378,7 +378,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
handler.grammar_trigger_words.push_back("{\n \"");
}
builder.add_rule("", join(tool_rules.begin(), tool_rules.end(), " | "));
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
});
handler.additional_stop_words.push_back("<|eom_id|>");
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, {
@@ -408,9 +408,9 @@ llama_tool_call_handler llama_tool_call_handler_init(
auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space";
if (parallel_tool_calls) {
auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space";
builder.add_rule("", first_rule + " (" + subsequent_rule + ")*");
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
} else {
builder.add_rule("", first_rule);
builder.add_rule("root", first_rule);
}
});
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
@@ -438,7 +438,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
}
}
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
builder.add_rule("", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) {
handler.grammar_trigger_words.push_back("<function=");
}
@@ -468,7 +468,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
}
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
builder.add_rule("", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) {
handler.grammar_trigger_words.push_back("<tool_call>");
}