Remove custom enum, rename left recursion check and move to "grammar internal" section, add handling for edge case where a leftmost nonterminal may be empty

This commit is contained in:
Haggai Nuchi 2024-05-11 16:28:59 -07:00
parent 65176e78ec
commit 5c5911d69d
2 changed files with 79 additions and 67 deletions

133
llama.cpp
View file

@ -13174,29 +13174,69 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
return rejects; return rejects;
} }
static bool llama_grammar_detect_left_recursion(
const std::vector<std::vector<llama_grammar_element>> & rules,
size_t rule_index,
std::vector<bool> * rules_visited,
std::vector<bool> * rules_in_progress,
std::vector<bool> * rules_may_be_empty);
static bool llama_grammar_detect_left_recursion(
const std::vector<std::vector<llama_grammar_element>> & rules,
size_t rule_index,
std::vector<bool> * rules_visited,
std::vector<bool> * rules_in_progress,
std::vector<bool> * rules_may_be_empty) {
if ((*rules_in_progress)[rule_index]) {
return true;
}
(*rules_in_progress)[rule_index] = true;
const std::vector<llama_grammar_element> & rule = rules[rule_index];
// First check if the rule might produce the empty string. This could be done combined with the second
// step but it's more readable as two steps.
bool at_rule_start = true;
for (size_t i = 0; i < rule.size(); i++) {
if (llama_grammar_is_end_of_sequence(&rule[i])) {
if (at_rule_start) {
(*rules_may_be_empty)[rule_index] = true;
break;
}
at_rule_start = true;
} else {
at_rule_start = false;
}
}
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
// be empty)
bool recurse_into_nonterminal = true;
for (size_t i = 0; i < rule.size(); i++) {
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
return true;
}
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
recurse_into_nonterminal = false;
}
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
recurse_into_nonterminal = true;
} else {
recurse_into_nonterminal = false;
}
}
(*rules_in_progress)[rule_index] = false;
(*rules_visited)[rule_index] = true;
return false;
}
// //
// grammar - external // grammar - external
// //
enum detect_left_recursion_status {
// haven't searched this nonterminal
LLAMA_LEFT_REC_NOT_SEARCHED = 0,
// searching this nonterminal in progress
LLAMA_LEFT_REC_IN_PROGRESS = 1,
// finished searching this nonterminal
LLAMA_LEFT_REC_FINISHED_SEARCH = 2,
// detected a cycle
LLAMA_LEFT_REC_FOUND_CYCLE = 3,
};
static void detect_left_recursion(
const std::vector<std::vector<llama_grammar_element>> & rules,
size_t rule_index,
std::vector<enum detect_left_recursion_status> * rules_visited);
struct llama_grammar * llama_grammar_init( struct llama_grammar * llama_grammar_init(
const llama_grammar_element ** rules, const llama_grammar_element ** rules,
size_t n_rules, size_t n_rules,
@ -13213,14 +13253,16 @@ struct llama_grammar * llama_grammar_init(
} }
// Check for left recursion // Check for left recursion
std::vector<enum detect_left_recursion_status> rules_visited(n_rules); std::vector<bool> rules_visited(n_rules);
std::vector<bool> rules_in_progress(n_rules);
std::vector<bool> rules_may_be_empty(n_rules);
for (size_t i = 0; i < n_rules; i++) { for (size_t i = 0; i < n_rules; i++) {
detect_left_recursion(vec_rules, i, &rules_visited); if (rules_visited[i]) {
} continue;
}
auto iter = std::find(rules_visited.begin(), rules_visited.end(), LLAMA_LEFT_REC_FOUND_CYCLE); if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
if (iter != rules_visited.end()) { throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %d", (int)(iter - rules_visited.begin()))); }
} }
// loop over alternates of start rule to build initial stacks // loop over alternates of start rule to build initial stacks
@ -13251,45 +13293,6 @@ struct llama_grammar * llama_grammar_init(
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
} }
static void detect_left_recursion(
const std::vector<std::vector<llama_grammar_element>> & rules,
size_t rule_index,
std::vector<enum detect_left_recursion_status> * rules_visited) {
int visit_status = (*rules_visited)[rule_index];
if (visit_status == LLAMA_LEFT_REC_IN_PROGRESS) {
// in progress -- we're in a cycle
(*rules_visited)[rule_index] = LLAMA_LEFT_REC_FOUND_CYCLE;
return;
} else if (visit_status == LLAMA_LEFT_REC_NOT_SEARCHED) {
// haven't visited yet. mark in progress, recurse, then mark complete.
// mark in progress
(*rules_visited)[rule_index] = LLAMA_LEFT_REC_IN_PROGRESS;
// recurse
const std::vector<llama_grammar_element> & rule = rules[rule_index];
size_t i = 0;
do {
if (rule[i].type == LLAMA_GRETYPE_RULE_REF) {
detect_left_recursion(rules, (size_t)rule[i].value, rules_visited);
}
while (!llama_grammar_is_end_of_sequence(&rule[i])) {
i++;
}
i++;
} while (i < rule.size());
// mark complete, but only if the recursive call didn't mark a cycle.
// that doesn't mean there's definitely no cycle *for this rule* -- the recursive call
// might have found a different cycle and stopped early.
if ((*rules_visited)[rule_index] == LLAMA_LEFT_REC_IN_PROGRESS) {
(*rules_visited)[rule_index] = LLAMA_LEFT_REC_FINISHED_SEARCH;
}
}
return;
}
void llama_grammar_free(struct llama_grammar * grammar) { void llama_grammar_free(struct llama_grammar * grammar) {
delete grammar; delete grammar;
} }

View file

@ -29,13 +29,14 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
} }
static bool test_build_grammar_fails(const std::string & grammar_str) { static bool test_build_grammar_fails(const std::string & grammar_str) {
fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
bool grammar_fails = false; bool grammar_fails = false;
try { try {
build_grammar(grammar_str); build_grammar(grammar_str);
fprintf(stderr, "❌ Expected build failure, but succeeded: %s\n", grammar_str.c_str()); fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
} catch (const std::exception & err) { } catch (const std::exception & err) {
grammar_fails = true; grammar_fails = true;
fprintf(stdout, "✅︎\n"); fprintf(stdout, " ✅︎\n");
} }
return grammar_fails; return grammar_fails;
} }
@ -353,6 +354,14 @@ asdf ::= "a" | foo "b"
foo ::= "c" | asdf "d" | "e")"""; foo ::= "c" | asdf "d" | "e")""";
assert(test_build_grammar_fails(hard_str)); assert(test_build_grammar_fails(hard_str));
// Test yet even more complicated left recursion detection
const std::string hardest_str = R"""(
root ::= asdf
asdf ::= "a" | foo "b"
foo ::= "c" | empty asdf "d" | "e"
empty ::= "blah" | )""";
assert(test_build_grammar_fails(hardest_str));
fprintf(stderr, " ✅︎ Passed\n"); fprintf(stderr, " ✅︎ Passed\n");
} }