grammars: refactor parser test

This commit is contained in:
Olivier Chafik 2024-04-12 17:00:31 +01:00
parent 6b5518c9da
commit 0ceb69afbc

View file

@ -7,28 +7,10 @@
#include <cassert> #include <cassert>
int main()
{
grammar_parser::parse_state parsed_grammar;
const char *grammar_bytes = R"""(root ::= (expr "=" term "\n")+
expr ::= term ([-+*/] term)*
term ::= [0-9]+)""";
parsed_grammar = grammar_parser::parse(grammar_bytes);
std::vector<std::pair<std::string, uint32_t>> expected = {
{"expr", 2},
{"expr_5", 5},
{"expr_6", 6},
{"root", 0},
{"root_1", 1},
{"root_4", 4},
{"term", 3},
{"term_7", 7},
};
static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
uint32_t index = 0; uint32_t index = 0;
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes);
for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
{ {
std::string key = it->first; std::string key = it->first;
@ -47,7 +29,47 @@ term ::= [0-9]+)""";
index++; index++;
} }
std::vector<llama_grammar_element> expected_rules = {
index = 0;
for (auto rule : parsed_grammar.rules)
{
// compare rule to expected rule
for (uint32_t i = 0; i < rule.size(); i++)
{
llama_grammar_element element = rule[i];
llama_grammar_element expected_element = expected_rules[index];
// pretty print error message before asserting
if (expected_element.type != element.type || expected_element.value != element.value)
{
fprintf(stderr, "index: %u\n", index);
fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value);
fprintf(stderr, "expected_element != actual_element\n");
}
assert(expected_element.type == element.type && expected_element.value == element.value);
index++;
}
}
}
int main()
{
verify_parsing(R"""(
root ::= (expr "=" term "\n")+
expr ::= term ([-+*/] term)*
term ::= [0-9]+
)""", {
{"expr", 2},
{"expr_5", 5},
{"expr_6", 6},
{"root", 0},
{"root_1", 1},
{"root_4", 4},
{"term", 3},
{"term_7", 7},
}, {
{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 4},
{LLAMA_GRETYPE_END, 0}, {LLAMA_GRETYPE_END, 0},
{LLAMA_GRETYPE_RULE_REF, 2}, {LLAMA_GRETYPE_RULE_REF, 2},
@ -82,43 +104,16 @@ term ::= [0-9]+)""";
{LLAMA_GRETYPE_CHAR, 48}, {LLAMA_GRETYPE_CHAR, 48},
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
{LLAMA_GRETYPE_END, 0}, {LLAMA_GRETYPE_END, 0},
}; });
index = 0; verify_parsing(R"""(
for (auto rule : parsed_grammar.rules) root ::= (expr "=" ws term "\n")+
{ expr ::= term ([-+*/] term)*
// compare rule to expected rule term ::= ident | num | "(" ws expr ")" ws
for (uint32_t i = 0; i < rule.size(); i++) ident ::= [a-z] [a-z0-9_]* ws
{ num ::= [0-9]+ ws
llama_grammar_element element = rule[i]; ws ::= [ \t\n]*
llama_grammar_element expected_element = expected_rules[index]; )""", {
// pretty print error message before asserting
if (expected_element.type != element.type || expected_element.value != element.value)
{
fprintf(stderr, "index: %u\n", index);
fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value);
fprintf(stderr, "expected_element != actual_element\n");
}
assert(expected_element.type == element.type && expected_element.value == element.value);
index++;
}
}
const char *longer_grammar_bytes = R"""(
root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
)""";
parsed_grammar = grammar_parser::parse(longer_grammar_bytes);
expected = {
{"expr", 2}, {"expr", 2},
{"expr_6", 6}, {"expr_6", 6},
{"expr_7", 7}, {"expr_7", 7},
@ -132,28 +127,7 @@ term ::= [0-9]+)""";
{"term", 4}, {"term", 4},
{"ws", 3}, {"ws", 3},
{"ws_12", 12}, {"ws_12", 12},
}; }, {
index = 0;
for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
{
std::string key = it->first;
uint32_t value = it->second;
std::pair<std::string, uint32_t> expected_pair = expected[index];
// pretty print error message before asserting
if (expected_pair.first != key || expected_pair.second != value)
{
fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second);
fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value);
fprintf(stderr, "expected_pair != actual_pair\n");
}
assert(expected_pair.first == key && expected_pair.second == value);
index++;
}
expected_rules = {
{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_RULE_REF, 5},
{LLAMA_GRETYPE_END, 0}, {LLAMA_GRETYPE_END, 0},
{LLAMA_GRETYPE_RULE_REF, 2}, {LLAMA_GRETYPE_RULE_REF, 2},
@ -221,30 +195,7 @@ term ::= [0-9]+)""";
{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_RULE_REF, 12},
{LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_END, 0}, {LLAMA_GRETYPE_END, 0},
}; });
index = 0;
for (auto rule : parsed_grammar.rules)
{
// compare rule to expected rule
for (uint32_t i = 0; i < rule.size(); i++)
{
llama_grammar_element element = rule[i];
llama_grammar_element expected_element = expected_rules[index];
// pretty print error message before asserting
if (expected_element.type != element.type || expected_element.value != element.value)
{
fprintf(stderr, "index: %u\n", index);
fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value);
fprintf(stderr, "expected_element != actual_element\n");
}
assert(expected_element.type == element.type && expected_element.value == element.value);
index++;
}
}
return 0; return 0;
} }