From 75764871e6c92484e30db673650eb8f76d895c2f Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 22 Oct 2024 10:50:52 +0100 Subject: [PATCH] `tool-call`: fix grammar roots --- common/tool-call.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/tool-call.cpp b/common/tool-call.cpp index 0880a610f..08cd57b1c 100644 --- a/common/tool-call.cpp +++ b/common/tool-call.cpp @@ -293,7 +293,7 @@ llama_tool_call_handler llama_tool_call_handler_init( handler.grammar_trigger_words.push_back("{\n \""); } - builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); + builder.add_rule("", join(tool_rules.begin(), tool_rules.end(), " | ")); }); handler.additional_stop_words.push_back("<|eom_id|>"); handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, { @@ -323,9 +323,9 @@ llama_tool_call_handler llama_tool_call_handler_init( auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space"; if (parallel_tool_calls) { auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space"; - builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); + builder.add_rule("", first_rule + " (" + subsequent_rule + ")*"); } else { - builder.add_rule("root", first_rule); + builder.add_rule("", first_rule); } }); handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); @@ -383,7 +383,7 @@ llama_tool_call_handler llama_tool_call_handler_init( } auto tool_call = "\"\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"\" space"; - builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + builder.add_rule("", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); if (allow_content) { handler.grammar_trigger_words.push_back(""); }