diff --git a/llama.cpp b/llama.cpp index 2571bb098..e5c9bd2b3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13174,29 +13174,69 @@ static std::vector llama_grammar_reject_candidates( return rejects; } +static bool llama_grammar_detect_left_recursion( + const std::vector> & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty); + +static bool llama_grammar_detect_left_recursion( + const std::vector> & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty) { + if ((*rules_in_progress)[rule_index]) { + return true; + } + + (*rules_in_progress)[rule_index] = true; + + const std::vector & rule = rules[rule_index]; + + // First check if the rule might produce the empty string. This could be done combined with the second + // step but it's more readable as two steps. + bool at_rule_start = true; + for (size_t i = 0; i < rule.size(); i++) { + if (llama_grammar_is_end_of_sequence(&rule[i])) { + if (at_rule_start) { + (*rules_may_be_empty)[rule_index] = true; + break; + } + at_rule_start = true; + } else { + at_rule_start = false; + } + } + + // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may + // be empty) + bool recurse_into_nonterminal = true; + for (size_t i = 0; i < rule.size(); i++) { + if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { + if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { + return true; + } + if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { + recurse_into_nonterminal = false; + } + } else if (llama_grammar_is_end_of_sequence(&rule[i])) { + recurse_into_nonterminal = true; + } else { + recurse_into_nonterminal = false; + } + } + + (*rules_in_progress)[rule_index] = false; + (*rules_visited)[rule_index] = true; + return false; +} + // // grammar - external // -enum detect_left_recursion_status { - // haven't searched this nonterminal - LLAMA_LEFT_REC_NOT_SEARCHED = 0, - - // searching this nonterminal in progress - LLAMA_LEFT_REC_IN_PROGRESS = 1, - - // finished searching this nonterminal - LLAMA_LEFT_REC_FINISHED_SEARCH = 2, - - // detected a cycle - LLAMA_LEFT_REC_FOUND_CYCLE = 3, -}; - -static void detect_left_recursion( - const std::vector> & rules, - size_t rule_index, - std::vector * rules_visited); - struct llama_grammar * llama_grammar_init( const llama_grammar_element ** rules, size_t n_rules, @@ -13213,14 +13253,16 @@ struct llama_grammar * llama_grammar_init( } // Check for left recursion - std::vector rules_visited(n_rules); + std::vector rules_visited(n_rules); + std::vector rules_in_progress(n_rules); + std::vector rules_may_be_empty(n_rules); for (size_t i = 0; i < n_rules; i++) { - detect_left_recursion(vec_rules, i, &rules_visited); - } - - auto iter = std::find(rules_visited.begin(), rules_visited.end(), LLAMA_LEFT_REC_FOUND_CYCLE); - if (iter != rules_visited.end()) { - throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %d", (int)(iter - rules_visited.begin()))); + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i)); + } } // loop over alternates of start rule to build initial stacks @@ -13251,45 +13293,6 @@ struct llama_grammar * llama_grammar_init( return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; } -static void detect_left_recursion( - const std::vector> & rules, - size_t rule_index, - std::vector * rules_visited) { - - int visit_status = (*rules_visited)[rule_index]; - if (visit_status == LLAMA_LEFT_REC_IN_PROGRESS) { - // in progress -- we're in a cycle - (*rules_visited)[rule_index] = LLAMA_LEFT_REC_FOUND_CYCLE; - return; - } else if (visit_status == LLAMA_LEFT_REC_NOT_SEARCHED) { - // haven't visited yet. mark in progress, recurse, then mark complete. - // mark in progress - (*rules_visited)[rule_index] = LLAMA_LEFT_REC_IN_PROGRESS; - - // recurse - const std::vector & rule = rules[rule_index]; - size_t i = 0; - do { - if (rule[i].type == LLAMA_GRETYPE_RULE_REF) { - detect_left_recursion(rules, (size_t)rule[i].value, rules_visited); - } - while (!llama_grammar_is_end_of_sequence(&rule[i])) { - i++; - } - i++; - } while (i < rule.size()); - - // mark complete, but only if the recursive call didn't mark a cycle. - // that doesn't mean there's definitely no cycle *for this rule* -- the recursive call - // might have found a different cycle and stopped early. - if ((*rules_visited)[rule_index] == LLAMA_LEFT_REC_IN_PROGRESS) { - (*rules_visited)[rule_index] = LLAMA_LEFT_REC_FINISHED_SEARCH; - } - } - - return; -} - void llama_grammar_free(struct llama_grammar * grammar) { delete grammar; } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 0ae7b0972..01c5bb27a 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -29,13 +29,14 @@ static llama_grammar* build_grammar(const std::string & grammar_str) { } static bool test_build_grammar_fails(const std::string & grammar_str) { + fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str()); bool grammar_fails = false; try { build_grammar(grammar_str); - fprintf(stderr, "❌ Expected build failure, but succeeded: %s\n", grammar_str.c_str()); + fprintf(stderr, " ❌ Expected build failure, but succeeded\n"); } catch (const std::exception & err) { grammar_fails = true; - fprintf(stdout, "✅︎\n"); + fprintf(stdout, " ✅︎\n"); } return grammar_fails; } @@ -353,6 +354,14 @@ asdf ::= "a" | foo "b" foo ::= "c" | asdf "d" | "e")"""; assert(test_build_grammar_fails(hard_str)); + // Test yet even more complicated left recursion detection + const std::string hardest_str = R"""( +root ::= asdf +asdf ::= "a" | foo "b" +foo ::= "c" | empty asdf "d" | "e" +empty ::= "blah" | )"""; + assert(test_build_grammar_fails(hardest_str)); + fprintf(stderr, " ✅︎ Passed\n"); }