Cleaning up integration tests to share code between tests and make it simpler to add new tests.
This commit is contained in:
parent
5539e6fdd1
commit
9cd07c2f9d
1 changed file with 37 additions and 84 deletions
|
@ -11,14 +11,8 @@
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
static void test_simple_grammar() {
|
static llama_grammar* build_grammar(const std::string & grammar_str) {
|
||||||
// Test case for a simple grammar
|
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||||
const std::string grammar_str = R"""(root ::= expr
|
|
||||||
expr ::= term ("+" term)*
|
|
||||||
term ::= number
|
|
||||||
number ::= [0-9]+)""";
|
|
||||||
|
|
||||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
|
||||||
|
|
||||||
// Ensure we parsed correctly
|
// Ensure we parsed correctly
|
||||||
assert(!parsed_grammar.rules.empty());
|
assert(!parsed_grammar.rules.empty());
|
||||||
|
@ -30,8 +24,10 @@ number ::= [0-9]+)""";
|
||||||
llama_grammar* grammar = llama_grammar_init(
|
llama_grammar* grammar = llama_grammar_init(
|
||||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||||
|
|
||||||
std::string input = "123+456";
|
return grammar;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool match_string(const std::string & input, llama_grammar* grammar) {
|
||||||
auto decoded = decode_utf8(input, {});
|
auto decoded = decode_utf8(input, {});
|
||||||
|
|
||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
@ -39,19 +35,34 @@ number ::= [0-9]+)""";
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
auto prev_stacks = grammar->stacks;
|
auto prev_stacks = grammar->stacks;
|
||||||
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
||||||
assert(!grammar->stacks.empty());
|
if (grammar->stacks.empty()) {
|
||||||
}
|
// no stacks means that the grammar failed to match at this point
|
||||||
|
return false;
|
||||||
bool completed_grammar = false;
|
|
||||||
|
|
||||||
for (const auto & stack : grammar->stacks) {
|
|
||||||
if (stack.empty()) {
|
|
||||||
completed_grammar = true;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(completed_grammar);
|
for (const auto & stack : grammar->stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
// An empty stack means that the grammar has been completed
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void test_simple_grammar() {
|
||||||
|
// Test case for a simple grammar
|
||||||
|
const std::string grammar_str = R"""(root ::= expr
|
||||||
|
expr ::= term ("+" term)*
|
||||||
|
term ::= number
|
||||||
|
number ::= [0-9]+)""";
|
||||||
|
|
||||||
|
auto grammar = build_grammar(grammar_str);
|
||||||
|
|
||||||
|
bool matched = match_string("123+456", grammar);
|
||||||
|
|
||||||
|
assert(matched);
|
||||||
|
|
||||||
// Clean up allocated memory
|
// Clean up allocated memory
|
||||||
llama_grammar_free(grammar);
|
llama_grammar_free(grammar);
|
||||||
|
@ -68,17 +79,7 @@ variable ::= [a-zA-Z_][a-zA-Z0-9_]*
|
||||||
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
|
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
|
||||||
ws ::= [ \t\n\r]?)""";
|
ws ::= [ \t\n\r]?)""";
|
||||||
|
|
||||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
auto grammar = build_grammar(grammar_str);
|
||||||
|
|
||||||
// Ensure we parsed correctly
|
|
||||||
assert(!parsed_grammar.rules.empty());
|
|
||||||
|
|
||||||
// Ensure we have a root node
|
|
||||||
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
|
|
||||||
|
|
||||||
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
|
|
||||||
llama_grammar* grammar = llama_grammar_init(
|
|
||||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
|
||||||
|
|
||||||
// Save the original grammar stacks so that we can reset after every new string we want to test
|
// Save the original grammar stacks so that we can reset after every new string we want to test
|
||||||
auto original_stacks = grammar->stacks;
|
auto original_stacks = grammar->stacks;
|
||||||
|
@ -130,35 +131,9 @@ ws ::= [ \t\n\r]?)""";
|
||||||
|
|
||||||
// Passing strings
|
// Passing strings
|
||||||
for (const auto & test_string : test_strings_pass) {
|
for (const auto & test_string : test_strings_pass) {
|
||||||
auto decoded = decode_utf8(test_string, {});
|
bool matched = match_string(test_string, grammar);
|
||||||
|
|
||||||
const auto & code_points = decoded.first;
|
assert(matched);
|
||||||
|
|
||||||
int pos = 0;
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
|
||||||
++pos;
|
|
||||||
auto prev_stacks = grammar->stacks;
|
|
||||||
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
|
||||||
|
|
||||||
// Expect that each code point will not cause the grammar to fail
|
|
||||||
if (grammar->stacks.empty()) {
|
|
||||||
fprintf(stdout, "Error at position %d\n", pos);
|
|
||||||
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
|
|
||||||
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
|
|
||||||
}
|
|
||||||
assert(!grammar->stacks.empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool completed_grammar = false;
|
|
||||||
|
|
||||||
for (const auto & stack : grammar->stacks) {
|
|
||||||
if (stack.empty()) {
|
|
||||||
completed_grammar = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(completed_grammar);
|
|
||||||
|
|
||||||
// Reset the grammar stacks
|
// Reset the grammar stacks
|
||||||
grammar->stacks = original_stacks;
|
grammar->stacks = original_stacks;
|
||||||
|
@ -166,32 +141,9 @@ ws ::= [ \t\n\r]?)""";
|
||||||
|
|
||||||
// Failing strings
|
// Failing strings
|
||||||
for (const auto & test_string : test_strings_fail) {
|
for (const auto & test_string : test_strings_fail) {
|
||||||
auto decoded = decode_utf8(test_string, {});
|
bool matched = match_string(test_string, grammar);
|
||||||
|
|
||||||
const auto & code_points = decoded.first;
|
assert(!matched);
|
||||||
bool parse_failed = false;
|
|
||||||
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
|
||||||
auto prev_stacks = grammar->stacks;
|
|
||||||
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
|
||||||
if (grammar->stacks.empty()) {
|
|
||||||
parse_failed = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
assert(!grammar->stacks.empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool completed_grammar = false;
|
|
||||||
|
|
||||||
for (const auto & stack : grammar->stacks) {
|
|
||||||
if (stack.empty()) {
|
|
||||||
completed_grammar = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that the grammar is not completed, or that each string failed to match as-expected
|
|
||||||
assert((!completed_grammar) || parse_failed);
|
|
||||||
|
|
||||||
// Reset the grammar stacks
|
// Reset the grammar stacks
|
||||||
grammar->stacks = original_stacks;
|
grammar->stacks = original_stacks;
|
||||||
|
@ -231,7 +183,7 @@ number ::= [0-9]+)""";
|
||||||
// Ensure we did NOT parse correctly
|
// Ensure we did NOT parse correctly
|
||||||
assert(parsed_grammar.rules.empty());
|
assert(parsed_grammar.rules.empty());
|
||||||
|
|
||||||
fprintf(stderr, "End of expected error. Test successful.\n");
|
fprintf(stderr, "End of expected error.\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
@ -239,5 +191,6 @@ int main() {
|
||||||
test_complex_grammar();
|
test_complex_grammar();
|
||||||
test_failure_missing_root();
|
test_failure_missing_root();
|
||||||
test_failure_missing_reference();
|
test_failure_missing_reference();
|
||||||
|
fprintf(stdout, "All tests passed.\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue