diff --git a/llama.cpp b/llama.cpp index b0c2270c6..44147d935 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1870,7 +1870,7 @@ static void llama_grammar_advance_stack( new_stack.push_back(pos + 2); } if (subpos[1]) { - // if the referenced rule is nonempty, add that to the stack + // if this alternate is nonempty, add that to the stack new_stack.push_back(subpos + 1); } llama_grammar_advance_stack(rules, new_stack, new_stacks); @@ -1980,16 +1980,22 @@ struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_r pos++; } - // TODO: handle if start rule has alternates const uint16_t * start_rule = rules[start_rule_id]; - // rule starts with rule id and 1st alternate's size; skip that so initial - // stack starts at 1st element in 1st alternate - LLAMA_ASSERT(start_rule[0] == start_rule_id && start_rule[1]); - const std::vector stack = { start_rule + 2 }; + LLAMA_ASSERT(*start_rule == start_rule_id); + // loop over alternates of start rule to build initial stacks + pos = start_rule + 1; std::vector> stacks; - llama_grammar_advance_stack(rules, stack, stacks); + while (*pos) { + std::vector stack; + if (pos[1]) { + // if alternate is nonempty, add to stack + stack.push_back(pos + 1); + } + llama_grammar_advance_stack(rules, stack, stacks); + pos += 1 + *pos; + } return new llama_grammar{ rules, stacks }; }