Cleaning up integration tests to share code between tests and make it simpler to add new tests.

This commit is contained in:
Clint Herron 2024-04-12 16:17:07 -04:00
parent 5539e6fdd1
commit 9cd07c2f9d

View file

@ -11,14 +11,8 @@
#include <cassert>
#include <string>
static void test_simple_grammar() {
// Test case for a simple grammar
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
static llama_grammar* build_grammar(const std::string & grammar_str) {
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
@ -30,8 +24,10 @@ number ::= [0-9]+)""";
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
std::string input = "123+456";
return grammar;
}
static bool match_string(const std::string & input, llama_grammar* grammar) {
auto decoded = decode_utf8(input, {});
const auto & code_points = decoded.first;
@ -39,19 +35,34 @@ number ::= [0-9]+)""";
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty());
if (grammar->stacks.empty()) {
// no stacks means that the grammar failed to match at this point
return false;
}
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
// An empty stack means that the grammar has been completed
return true;
}
}
assert(completed_grammar);
return false;
}
static void test_simple_grammar() {
// Test case for a simple grammar
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";
auto grammar = build_grammar(grammar_str);
bool matched = match_string("123+456", grammar);
assert(matched);
// Clean up allocated memory
llama_grammar_free(grammar);
@ -68,17 +79,7 @@ variable ::= [a-zA-Z_][a-zA-Z0-9_]*
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
ws ::= [ \t\n\r]?)""";
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
// Ensure we have a root node
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
auto grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test
auto original_stacks = grammar->stacks;
@ -130,35 +131,9 @@ ws ::= [ \t\n\r]?)""";
// Passing strings
for (const auto & test_string : test_strings_pass) {
auto decoded = decode_utf8(test_string, {});
bool matched = match_string(test_string, grammar);
const auto & code_points = decoded.first;
int pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
fprintf(stdout, "Error at position %d\n", pos);
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
}
assert(!grammar->stacks.empty());
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
assert(completed_grammar);
assert(matched);
// Reset the grammar stacks
grammar->stacks = original_stacks;
@ -166,32 +141,9 @@ ws ::= [ \t\n\r]?)""";
// Failing strings
for (const auto & test_string : test_strings_fail) {
auto decoded = decode_utf8(test_string, {});
bool matched = match_string(test_string, grammar);
const auto & code_points = decoded.first;
bool parse_failed = false;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
}
assert(!grammar->stacks.empty());
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
// Ensure that the grammar is not completed, or that each string failed to match as-expected
assert((!completed_grammar) || parse_failed);
assert(!matched);
// Reset the grammar stacks
grammar->stacks = original_stacks;
@ -231,7 +183,7 @@ number ::= [0-9]+)""";
// Ensure we did NOT parsed correctly
assert(parsed_grammar.rules.empty());
fprintf(stderr, "End of expected error. Test successful.\n");
fprintf(stderr, "End of expected error.\n");
}
int main() {
@ -239,5 +191,6 @@ int main() {
test_complex_grammar();
test_failure_missing_root();
test_failure_missing_reference();
fprintf(stdout, "All tests passed.\n");
return 0;
}