support alternates in root rule

This commit is contained in:
Evan Jones 2023-06-14 23:53:12 -04:00
parent 3e78f0071a
commit 421c6e1ca1

View file

@ -1870,7 +1870,7 @@ static void llama_grammar_advance_stack(
new_stack.push_back(pos + 2); new_stack.push_back(pos + 2);
} }
if (subpos[1]) { if (subpos[1]) {
// if the referenced rule is nonempty, add that to the stack // if this alternate is nonempty, add that to the stack
new_stack.push_back(subpos + 1); new_stack.push_back(subpos + 1);
} }
llama_grammar_advance_stack(rules, new_stack, new_stacks); llama_grammar_advance_stack(rules, new_stack, new_stacks);
@ -1980,16 +1980,22 @@ struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_r
pos++; pos++;
} }
// TODO: handle if start rule has alternates
const uint16_t * start_rule = rules[start_rule_id]; const uint16_t * start_rule = rules[start_rule_id];
// rule starts with rule id and 1st alternate's size; skip that so initial LLAMA_ASSERT(*start_rule == start_rule_id);
// stack starts at 1st element in 1st alternate
LLAMA_ASSERT(start_rule[0] == start_rule_id && start_rule[1]);
const std::vector<const uint16_t *> stack = { start_rule + 2 };
// loop over alternates of start rule to build initial stacks
pos = start_rule + 1;
std::vector<std::vector<const uint16_t *>> stacks; std::vector<std::vector<const uint16_t *>> stacks;
while (*pos) {
std::vector<const uint16_t *> stack;
if (pos[1]) {
// if alernate is nonempty, add to stack
stack.push_back(pos + 1);
}
llama_grammar_advance_stack(rules, stack, stacks); llama_grammar_advance_stack(rules, stack, stacks);
pos += 1 + *pos;
}
return new llama_grammar{ rules, stacks }; return new llama_grammar{ rules, stacks };
} }