From 65176e78ec900416cd59053a790c49fced906189 Mon Sep 17 00:00:00 2001
From: Haggai Nuchi
Date: Sat, 4 May 2024 22:43:08 -0700
Subject: [PATCH] Add left recursion check: quit early instead of going into an infinite loop

---
Reviewer notes (revised patch):
- detect_left_recursion now also follows references that sit behind a
  leading nonterminal with an empty alternate, and skips rule references
  that point past the end of the rule list.
- test_build_grammar_fails frees the grammar when the build unexpectedly
  succeeds; a test for the empty-prefix case is added.
- Line breaks, indentation and template arguments were reconstructed from a
  flattened copy of this patch; the From line and the index hashes are
  carried over as found. Check context-line whitespace before applying.

 llama.cpp                          | 110 +++++++++++++++++++++++++++++
 tests/test-grammar-integration.cpp |  50 +++++++++++++
 2 files changed, 160 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index e91ad7285..2571bb098 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13178,6 +13178,27 @@ static std::vector llama_grammar_reject_candidates(
 // grammar - external
 //
 
+// Per-rule search state recorded by detect_left_recursion.
+enum detect_left_recursion_status {
+    // haven't searched this nonterminal
+    LLAMA_LEFT_REC_NOT_SEARCHED    = 0,
+
+    // searching this nonterminal in progress
+    LLAMA_LEFT_REC_IN_PROGRESS     = 1,
+
+    // finished searching this nonterminal
+    LLAMA_LEFT_REC_FINISHED_SEARCH = 2,
+
+    // detected a cycle
+    LLAMA_LEFT_REC_FOUND_CYCLE     = 3,
+};
+
+// Forward declaration; the definition follows llama_grammar_init.
+static void detect_left_recursion(
+    const std::vector<std::vector<llama_grammar_element>> & rules,
+    size_t rule_index,
+    std::vector<detect_left_recursion_status> * rules_visited);
+
 struct llama_grammar * llama_grammar_init(
             const llama_grammar_element ** rules,
                                  size_t    n_rules,
@@ -13193,6 +13214,18 @@ struct llama_grammar * llama_grammar_init(
         vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
     }
 
+    // Check for left recursion up front: such a grammar is rejected with an
+    // exception here instead of going into an infinite loop.
+    std::vector<detect_left_recursion_status> rules_visited(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        detect_left_recursion(vec_rules, i, &rules_visited);
+    }
+
+    auto iter = std::find(rules_visited.begin(), rules_visited.end(), LLAMA_LEFT_REC_FOUND_CYCLE);
+    if (iter != rules_visited.end()) {
+        throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %d", (int)(iter - rules_visited.begin())));
+    }
+
     // loop over alternates of start rule to build initial stacks
     std::vector<std::vector<const llama_grammar_element *>> stacks;
     pos = vec_rules[start_rule_index].data();
@@ -13215,9 +13248,86 @@ struct llama_grammar * llama_grammar_init(
         }
     } while (true);
 
+    // Important: vec_rules has to be moved here, not copied, because stacks contains
+    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+    // then the pointers would be invalidated when the local vec_rules goes out of scope.
     return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
 }
 
+// Returns true when `rule` contains an alternate with no elements, i.e. the
+// rule can match the empty string without consuming anything.
+static bool detect_left_recursion_has_empty_alternate(
+    const std::vector<llama_grammar_element> & rule) {
+    bool at_alternate_start = true;
+    for (const llama_grammar_element & elem : rule) {
+        if (!llama_grammar_is_end_of_sequence(&elem)) {
+            at_alternate_start = false;
+        } else if (at_alternate_start) {
+            // a terminator right at the start of an alternate: empty alternate
+            return true;
+        } else {
+            at_alternate_start = true;
+        }
+    }
+    return false;
+}
+
+// Depth-first search for left recursion starting at rules[rule_index].
+//
+// Only rule references at the left edge of an alternate are followed. A
+// reference is at the left edge when everything before it in the alternate
+// is a reference to a rule with an empty alternate, so recursion hidden
+// behind such a prefix is reported as well.
+//
+// The outcome is written to *rules_visited (one entry per rule): a rule that
+// is reached again while its own search is still in progress is marked
+// LLAMA_LEFT_REC_FOUND_CYCLE. Rules that were searched before return at once.
+static void detect_left_recursion(
+    const std::vector<std::vector<llama_grammar_element>> & rules,
+    size_t rule_index,
+    std::vector<detect_left_recursion_status> * rules_visited) {
+
+    detect_left_recursion_status & status = (*rules_visited)[rule_index];
+    if (status == LLAMA_LEFT_REC_IN_PROGRESS) {
+        // in progress -- we're in a cycle
+        status = LLAMA_LEFT_REC_FOUND_CYCLE;
+        return;
+    }
+    if (status != LLAMA_LEFT_REC_NOT_SEARCHED) {
+        // searched before (or already marked as part of a cycle)
+        return;
+    }
+
+    // haven't visited yet. mark in progress, recurse, then mark complete.
+    status = LLAMA_LEFT_REC_IN_PROGRESS;
+
+    const std::vector<llama_grammar_element> & rule = rules[rule_index];
+    bool at_left_edge = true;
+    for (const llama_grammar_element & elem : rule) {
+        if (llama_grammar_is_end_of_sequence(&elem)) {
+            // whatever follows starts a new alternate
+            at_left_edge = true;
+            continue;
+        }
+        const bool follow = at_left_edge && elem.type == LLAMA_GRETYPE_RULE_REF;
+        at_left_edge = false;
+        const size_t target = (size_t) elem.value;
+        // references past the end of rules are skipped rather than indexed
+        if (follow && target < rules.size()) {
+            detect_left_recursion(rules, target, rules_visited);
+            // the next element is still leftmost only if the target can be empty
+            at_left_edge = detect_left_recursion_has_empty_alternate(rules[target]);
+        }
+    }
+
+    // mark complete, but only if the recursive call didn't mark a cycle.
+    // that doesn't mean there's definitely no cycle *for this rule* -- the recursive call
+    // might have found a different cycle and stopped early.
+    if (status == LLAMA_LEFT_REC_IN_PROGRESS) {
+        status = LLAMA_LEFT_REC_FINISHED_SEARCH;
+    }
+}
+
 void llama_grammar_free(struct llama_grammar * grammar) {
     delete grammar;
 }
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
index 1a4004e2a..0ae7b0972 100644
--- a/tests/test-grammar-integration.cpp
+++ b/tests/test-grammar-integration.cpp
@@ -28,6 +28,22 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
     return grammar;
 }
 
+// Returns true when building `grammar_str` throws, i.e. the grammar is
+// rejected; prints a status line either way.
+static bool test_build_grammar_fails(const std::string & grammar_str) {
+    bool grammar_fails = false;
+    try {
+        llama_grammar * grammar = build_grammar(grammar_str);
+        fprintf(stderr, " ❌ Expected build failure, but succeeded: %s\n", grammar_str.c_str());
+        // don't leak the grammar that should not have been built
+        llama_grammar_free(grammar);
+    } catch (const std::exception &) {
+        grammar_fails = true;
+        fprintf(stdout, " ✅︎\n");
+    }
+    return grammar_fails;
+}
+
 static bool match_string(const std::string & input, llama_grammar* grammar) {
     auto decoded = decode_utf8(input, {});
 
@@ -320,6 +336,39 @@ number ::= [0-9]+)""";
     fprintf(stderr, " ✅︎ Passed\n");
 }
 
+// Checks that grammars containing left recursion are rejected at build time.
+static void test_failure_left_recursion() {
+    fprintf(stderr, "⚫ Testing left recursion detection:\n");
+
+    // Test simple left recursion detection
+    const std::string simple_str = R"""(root ::= "a" | root "a")""";
+    assert(test_build_grammar_fails(simple_str));
+
+    // Test more complicated left recursion detection
+    const std::string medium_str = R"""(
+root ::= asdf
+asdf ::= "a" | asdf "a"
+)""";
+    assert(test_build_grammar_fails(medium_str));
+
+    // Test even more complicated left recursion detection
+    const std::string hard_str = R"""(
+root ::= asdf
+asdf ::= "a" | foo "b"
+foo ::= "c" | asdf "d" | "e")""";
+    assert(test_build_grammar_fails(hard_str));
+
+    // Test left recursion hidden behind a nonterminal that can match the empty string
+    const std::string hardest_str = R"""(
+root ::= asdf
+asdf ::= "a" | foo "b"
+foo ::= "c" | empty asdf "d" | "e"
+empty ::= "blah" | )""";
+    assert(test_build_grammar_fails(hardest_str));
+
+    fprintf(stderr, " ✅︎ Passed\n");
+}
+
 int main() {
     fprintf(stdout, "Running grammar integration tests...\n");
     test_simple_grammar();
@@ -327,6 +376,7 @@ int main() {
     test_quantifiers();
     test_failure_missing_root();
     test_failure_missing_reference();
+    test_failure_left_recursion();
     fprintf(stdout, "All tests passed.\n");
     return 0;
 }