diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 26e39eefa..f6ec82587 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -153,6 +153,115 @@ ws ::= [ \t\n\r]?)"""; llama_grammar_free(grammar); } +static void test_quantifiers() { + // Populate test data with grammar strings and their associated collections of expected passing and failing strings + const std::vector< + std::tuple< + std::string, + std::vector, + std::vector>> + test_data = { + { + // Grammar + R"""(root ::= "a"*)""", + // Passing strings + { + "", + "a", + "aaaaa", + "aaaaaaaaaaaaaaaaaa", + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + }, + // Failing strings + { + "b", + "ab", + "aab", + "ba", + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" + } + }, + { + // Grammar + R"""(root ::= "a"+)""", + // Passing strings + { + "a", + "aaaaa", + "aaaaaaaaaaaaaaaaaa", + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + }, + // Failing strings + { + "", + "b", + "ab", + "aab", + "ba", + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" + } + }, + { + // Grammar + R"""(root ::= "a"?)""", + // Passing strings + { + "", + "a" + }, + // Failing strings + { + "b", + "ab", + "aa", + "ba", + } + } + }; + + for (const auto & test_datum : test_data) { + const auto & [grammar_str, passing_strings, failing_strings] = test_datum; + + auto grammar = build_grammar(grammar_str); + + // Save the original grammar stacks so that we can reset after every new string we want to test + auto original_stacks = grammar->stacks; + + // Passing strings + for (const auto & test_string : passing_strings) { + bool matched = match_string(test_string, grammar); + + if (!matched) { + fprintf(stderr, "Against grammar: %s\n", grammar_str.c_str()); + fprintf(stderr, "Failed to match string: %s\n", test_string.c_str()); + } + + assert(matched); + + // Reset the grammar stacks + grammar->stacks = original_stacks; + } + + // Failing strings + for (const auto & test_string : failing_strings) { + bool matched = match_string(test_string, grammar); + + if (matched) { + fprintf(stderr, "Against grammar: %s\n", grammar_str.c_str()); + fprintf(stderr, "Improperly matched string: %s\n", test_string.c_str()); + } + + assert(!matched); + + // Reset the grammar stacks + grammar->stacks = original_stacks; + } + + // Clean up allocated memory + llama_grammar_free(grammar); + } +} + static void test_failure_missing_root() { // Test case for a grammar that is missing a root rule const std::string grammar_str = R"""(rot ::= expr @@ -189,6 +298,7 @@ number ::= [0-9]+)"""; int main() { test_simple_grammar(); test_complex_grammar(); + test_quantifiers(); test_failure_missing_root(); test_failure_missing_reference(); fprintf(stdout, "All tests passed.\n");