Add left recursion check: quit early instead of going into an infinite loop (#7083)

* Add left recursion check: quit early instead of going into an infinite loop

* Remove the custom enum; rename the left recursion check and move it to the "grammar internal" section; add handling for the edge case where a leftmost nonterminal may be empty

* Remove unnecessary declaration
This commit is contained in:
Haggai Nuchi 2024-05-13 22:25:56 -07:00 committed by GitHub
parent 27f65d6267
commit e0f556186b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 114 additions and 0 deletions

View file

@ -28,6 +28,19 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
return grammar;
}
// Attempts to build a grammar from `grammar_str` and reports whether the build was rejected.
//
// Returns true when build_grammar() throws a std::exception (the expected outcome for an
// invalid grammar) and false when the build unexpectedly succeeds.
//
// Every diagnostic is written to stderr so that the header line and its verdict stay
// together even when stdout is redirected or buffered differently.
static bool test_build_grammar_fails(const std::string & grammar_str) {
    fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
    try {
        // NOTE(review): the pointer returned on success is discarded here; if build_grammar()
        // transfers ownership, the unexpected-success path leaks it — confirm and free it.
        build_grammar(grammar_str);
    } catch (const std::exception &) {
        // The build was rejected, which is exactly what this helper is checking for.
        fprintf(stderr, " ✅︎\n");
        return true;
    }
    fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
    return false;
}
static bool match_string(const std::string & input, llama_grammar* grammar) {
auto decoded = decode_utf8(input, {});
@ -320,6 +333,38 @@ number ::= [0-9]+)""";
fprintf(stderr, " ✅︎ Passed\n");
}
// Verifies that grammars containing left recursion are rejected when they are built.
//
// Four grammars of increasing complexity are checked: direct recursion on the root rule,
// direct recursion on a referenced rule, indirect recursion through a second rule, and
// indirect recursion hidden behind a leftmost nonterminal that has an empty alternative.
static void test_failure_left_recursion() {
    fprintf(stderr, "⚫ Testing left recursion detection:\n");

    // Run the check outside of assert(): an expression inside assert() is not evaluated at
    // all when NDEBUG is defined, which would silently turn this whole test into a no-op.
    const auto expect_build_failure = [](const std::string & grammar_str) {
        const bool build_failed = test_build_grammar_fails(grammar_str);
        assert(build_failed);
        static_cast<void>(build_failed); // avoids an unused-variable warning under NDEBUG
    };

    // Test simple left recursion detection
    const std::string simple_str = R"""(root ::= "a" | root "a")""";
    expect_build_failure(simple_str);

    // Test more complicated left recursion detection
    const std::string medium_str = R"""(
root ::= asdf
asdf ::= "a" | asdf "a"
)""";
    expect_build_failure(medium_str);

    // Test even more complicated left recursion detection
    const std::string hard_str = R"""(
root ::= asdf
asdf ::= "a" | foo "b"
foo ::= "c" | asdf "d" | "e")""";
    expect_build_failure(hard_str);

    // Test yet even more complicated left recursion detection:
    // `empty` has an empty alternative, so `asdf` can still end up leftmost in `foo`.
    const std::string hardest_str = R"""(
root ::= asdf
asdf ::= "a" | foo "b"
foo ::= "c" | empty asdf "d" | "e"
empty ::= "blah" | )""";
    expect_build_failure(hardest_str);

    fprintf(stderr, " ✅︎ Passed\n");
}
int main() {
fprintf(stdout, "Running grammar integration tests...\n");
test_simple_grammar();
@ -327,6 +372,7 @@ int main() {
test_quantifiers();
test_failure_missing_root();
test_failure_missing_reference();
test_failure_left_recursion();
fprintf(stdout, "All tests passed.\n");
return 0;
}