diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 91939e276..7422abfc2 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -7,28 +7,10 @@ #include -int main() -{ - grammar_parser::parse_state parsed_grammar; - - const char *grammar_bytes = R"""(root ::= (expr "=" term "\n")+ -expr ::= term ([-+*/] term)* -term ::= [0-9]+)"""; - - parsed_grammar = grammar_parser::parse(grammar_bytes); - - std::vector> expected = { - {"expr", 2}, - {"expr_5", 5}, - {"expr_6", 6}, - {"root", 0}, - {"root_1", 1}, - {"root_4", 4}, - {"term", 3}, - {"term_7", 7}, - }; +static void verify_parsing(const char *grammar_bytes, const std::vector> expected, const std::vector &expected_rules) { uint32_t index = 0; + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes); for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { std::string key = it->first; @@ -47,7 +29,47 @@ term ::= [0-9]+)"""; index++; } - std::vector expected_rules = { + + index = 0; + for (auto rule : parsed_grammar.rules) + { + // compare rule to expected rule + for (uint32_t i = 0; i < rule.size(); i++) + { + llama_grammar_element element = rule[i]; + llama_grammar_element expected_element = expected_rules[index]; + + // pretty print error message before asserting + if (expected_element.type != element.type || expected_element.value != element.value) + { + fprintf(stderr, "index: %u\n", index); + fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); + fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value); + fprintf(stderr, "expected_element != actual_element\n"); + } + + assert(expected_element.type == element.type && expected_element.value == element.value); + index++; + } + } +} + +int main() +{ + verify_parsing(R"""( + root ::= (expr "=" term "\n")+ + expr ::= term ([-+*/] term)* + term ::= [0-9]+ + )""", { + {"expr", 2}, + {"expr_5", 5}, + {"expr_6", 6}, + {"root", 0}, + {"root_1", 1}, + {"root_4", 4}, + {"term", 3}, + {"term_7", 7}, + }, { {LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_END, 0}, {LLAMA_GRETYPE_RULE_REF, 2}, @@ -82,43 +104,16 @@ term ::= [0-9]+)"""; {LLAMA_GRETYPE_CHAR, 48}, {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, {LLAMA_GRETYPE_END, 0}, - }; + }); - index = 0; - for (auto rule : parsed_grammar.rules) - { - // compare rule to expected rule - for (uint32_t i = 0; i < rule.size(); i++) - { - llama_grammar_element element = rule[i]; - llama_grammar_element expected_element = expected_rules[index]; - - // pretty print error message before asserting - if (expected_element.type != element.type || expected_element.value != element.value) - { - fprintf(stderr, "index: %u\n", index); - fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); - fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value); - fprintf(stderr, "expected_element != actual_element\n"); - } - - assert(expected_element.type == element.type && expected_element.value == element.value); - index++; - } - } - - const char *longer_grammar_bytes = R"""( - root ::= (expr "=" ws term "\n")+ - expr ::= term ([-+*/] term)* - term ::= ident | num | "(" ws expr ")" ws - ident ::= [a-z] [a-z0-9_]* ws - num ::= [0-9]+ ws - ws ::= [ \t\n]* - )"""; - - parsed_grammar = grammar_parser::parse(longer_grammar_bytes); - - expected = { + verify_parsing(R"""( + root ::= (expr "=" ws term "\n")+ + expr ::= term ([-+*/] term)* + term ::= ident | num | "(" ws expr ")" ws + ident ::= [a-z] [a-z0-9_]* ws + num ::= [0-9]+ ws + ws ::= [ \t\n]* + )""", { {"expr", 2}, {"expr_6", 6}, {"expr_7", 7}, @@ -132,28 +127,7 @@ term ::= [0-9]+)"""; {"term", 4}, {"ws", 3}, {"ws_12", 12}, - }; - - index = 0; - for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) - { - std::string key = it->first; - uint32_t value = it->second; - std::pair expected_pair = expected[index]; - - // pretty print error message before asserting - if (expected_pair.first != key || expected_pair.second != value) - { - fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second); - fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value); - fprintf(stderr, "expected_pair != actual_pair\n"); - } - - assert(expected_pair.first == key && expected_pair.second == value); - - index++; - } - expected_rules = { + }, { {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}, {LLAMA_GRETYPE_RULE_REF, 2}, @@ -221,30 +195,7 @@ term ::= [0-9]+)"""; {LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}, - }; - - index = 0; - for (auto rule : parsed_grammar.rules) - { - // compare rule to expected rule - for (uint32_t i = 0; i < rule.size(); i++) - { - llama_grammar_element element = rule[i]; - llama_grammar_element expected_element = expected_rules[index]; - - // pretty print error message before asserting - if (expected_element.type != element.type || expected_element.value != element.value) - { - fprintf(stderr, "index: %u\n", index); - fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); - fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value); - fprintf(stderr, "expected_element != actual_element\n"); - } - - assert(expected_element.type == element.type && expected_element.value == element.value); - index++; - } - } + }); return 0; }