From 9cd07c2f9d575096cd68d3cfee5bc84b25803163 Mon Sep 17 00:00:00 2001 From: Clint Herron Date: Fri, 12 Apr 2024 16:17:07 -0400 Subject: [PATCH] Cleaning up integration tests to share code between tests and make it simpler to add new tests. --- tests/test-grammar-integration.cpp | 121 +++++++++-------------------- 1 file changed, 37 insertions(+), 84 deletions(-) diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 2d8f228e3..26e39eefa 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -11,14 +11,8 @@ #include #include -static void test_simple_grammar() { - // Test case for a simple grammar - const std::string grammar_str = R"""(root ::= expr -expr ::= term ("+" term)* -term ::= number -number ::= [0-9]+)"""; - - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); +static llama_grammar* build_grammar(const std::string & grammar_str) { + auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); // Ensure we parsed correctly assert(!parsed_grammar.rules.empty()); @@ -30,8 +24,10 @@ number ::= [0-9]+)"""; llama_grammar* grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - std::string input = "123+456"; + return grammar; +} +static bool match_string(const std::string & input, llama_grammar* grammar) { auto decoded = decode_utf8(input, {}); const auto & code_points = decoded.first; @@ -39,19 +35,34 @@ number ::= [0-9]+)"""; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { auto prev_stacks = grammar->stacks; llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks); - assert(!grammar->stacks.empty()); - } - - bool completed_grammar = false; - - for (const auto & stack : grammar->stacks) { - if (stack.empty()) { - completed_grammar = true; - break; + if (grammar->stacks.empty()) { + // no stacks means that the grammar failed to match at this point + return false; } } - assert(completed_grammar); + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + // An empty stack means that the grammar has been completed + return true; + } + } + + return false; +} + +static void test_simple_grammar() { + // Test case for a simple grammar + const std::string grammar_str = R"""(root ::= expr +expr ::= term ("+" term)* +term ::= number +number ::= [0-9]+)"""; + + auto grammar = build_grammar(grammar_str); + + bool matched = match_string("123+456", grammar); + + assert(matched); // Clean up allocated memory llama_grammar_free(grammar); @@ -68,17 +79,7 @@ variable ::= [a-zA-Z_][a-zA-Z0-9_]* function-call ::= variable ws "(" (expression ("," ws expression)*)? ")" ws ::= [ \t\n\r]?)"""; - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); - - // Ensure we parsed correctly - assert(!parsed_grammar.rules.empty()); - - // Ensure we have a root node - assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); - - std::vector grammar_rules(parsed_grammar.c_rules()); - llama_grammar* grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + auto grammar = build_grammar(grammar_str); // Save the original grammar stacks so that we can reset after every new string we want to test auto original_stacks = grammar->stacks; @@ -130,35 +131,9 @@ ws ::= [ \t\n\r]?)"""; // Passing strings for (const auto & test_string : test_strings_pass) { - auto decoded = decode_utf8(test_string, {}); + bool matched = match_string(test_string, grammar); - const auto & code_points = decoded.first; - - int pos = 0; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - ++pos; - auto prev_stacks = grammar->stacks; - llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks); - - // Expect that each code point will not cause the grammar to fail - if (grammar->stacks.empty()) { - fprintf(stdout, "Error at position %d\n", pos); - fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str()); - fprintf(stderr, "Input string is %s:\n", test_string.c_str()); - } - assert(!grammar->stacks.empty()); - } - - bool completed_grammar = false; - - for (const auto & stack : grammar->stacks) { - if (stack.empty()) { - completed_grammar = true; - break; - } - } - - assert(completed_grammar); + assert(matched); // Reset the grammar stacks grammar->stacks = original_stacks; @@ -166,32 +141,9 @@ ws ::= [ \t\n\r]?)"""; // Failing strings for (const auto & test_string : test_strings_fail) { - auto decoded = decode_utf8(test_string, {}); + bool matched = match_string(test_string, grammar); - const auto & code_points = decoded.first; - bool parse_failed = false; - - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - auto prev_stacks = grammar->stacks; - llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks); - if (grammar->stacks.empty()) { - parse_failed = true; - break; - } - assert(!grammar->stacks.empty()); - } - - bool completed_grammar = false; - - for (const auto & stack : grammar->stacks) { - if (stack.empty()) { - completed_grammar = true; - break; - } - } - - // Ensure that the grammar is not completed, or that each string failed to match as-expected - assert((!completed_grammar) || parse_failed); + assert(!matched); // Reset the grammar stacks grammar->stacks = original_stacks; @@ -231,7 +183,7 @@ number ::= [0-9]+)"""; // Ensure we did NOT parsed correctly assert(parsed_grammar.rules.empty()); - fprintf(stderr, "End of expected error. Test successful.\n"); + fprintf(stderr, "End of expected error.\n"); } int main() { @@ -239,5 +191,6 @@ int main() { test_complex_grammar(); test_failure_missing_root(); test_failure_missing_reference(); + fprintf(stdout, "All tests passed.\n"); return 0; }