Added integration tests for GBNF parser to validate correctness of parsing, as well as correctness of string matching. Intended for use to pin behavior while working on performance improvements.
This commit is contained in:
parent
7a2c92637a
commit
345eae3021
3 changed files with 242 additions and 1 deletions
6
Makefile
6
Makefile
|
@ -10,7 +10,7 @@ TEST_TARGETS = \
|
|||
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
|
||||
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \
|
||||
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \
|
||||
tests/test-json-schema-to-grammar
|
||||
tests/test-json-schema-to-grammar tests/test-grammar-integration
|
||||
|
||||
# Code coverage output files
|
||||
COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
|
||||
|
@ -918,6 +918,10 @@ tests/test-grammar-parser: tests/test-grammar-parser.cpp ggml.o llama.o grammar-
|
|||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o grammar-parser.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
|
|
@ -59,6 +59,7 @@ llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt2 AR
|
|||
|
||||
llama_test(test-grammar-parser.cpp)
|
||||
llama_test(test-llama-grammar.cpp)
|
||||
llama_test(test-grammar-integration.cpp)
|
||||
llama_test(test-grad0.cpp)
|
||||
# llama_test(test-opt.cpp) # SLOW
|
||||
llama_test(test-backend-ops.cpp)
|
||||
|
|
236
tests/test-grammar-integration.cpp
Normal file
236
tests/test-grammar-integration.cpp
Normal file
|
@ -0,0 +1,236 @@
|
|||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include "llama.cpp" // TODO: not great
|
||||
#include "grammar-parser.h"
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
|
||||
static void test_failure_missing_root() {
|
||||
// Test case for a grammar that is missing a root rule
|
||||
const std::string grammar_str = R"""(rot ::= expr
|
||||
expr ::= term ("+" term)*
|
||||
term ::= number
|
||||
number ::= [0-9]+)""";
|
||||
|
||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
|
||||
// Ensure we parsed correctly
|
||||
assert(!parsed_grammar.rules.empty());
|
||||
|
||||
// Ensure we do NOT have a root node
|
||||
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
|
||||
}
|
||||
|
||||
static void test_failure_missing_reference() {
|
||||
// Test case for a grammar that is missing a referenced rule
|
||||
const std::string grammar_str = R"""(root ::= expr
|
||||
expr ::= term ("+" term)*
|
||||
term ::= numero
|
||||
number ::= [0-9]+)""";
|
||||
|
||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
|
||||
// Ensure we did NOT parsed correctly
|
||||
assert(parsed_grammar.rules.empty());
|
||||
|
||||
fprintf(stderr, "^ If previous line displays an error, then this test passed.\n");
|
||||
}
|
||||
|
||||
static void test_simple_grammar() {
|
||||
// Test case for a simple grammar
|
||||
const std::string grammar_str = R"""(root ::= expr
|
||||
expr ::= term ("+" term)*
|
||||
term ::= number
|
||||
number ::= [0-9]+)""";
|
||||
|
||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
|
||||
// Ensure we parsed correctly
|
||||
assert(!parsed_grammar.rules.empty());
|
||||
|
||||
// Ensure we have a root node
|
||||
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
|
||||
|
||||
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
|
||||
llama_grammar* grammar = llama_grammar_init(
|
||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
|
||||
std::string input = "123+456";
|
||||
|
||||
auto decoded = decode_utf8(input, {});
|
||||
|
||||
const auto & code_points = decoded.first;
|
||||
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
auto prev_stacks = grammar->stacks;
|
||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
||||
assert(!grammar->stacks.empty());
|
||||
}
|
||||
|
||||
bool completed_grammar = false;
|
||||
|
||||
for (const auto & stack : grammar->stacks) {
|
||||
if (stack.empty()) {
|
||||
completed_grammar = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert(completed_grammar);
|
||||
|
||||
// Clean up allocated memory
|
||||
llama_grammar_free(grammar);
|
||||
}
|
||||
|
||||
static void test_complex_grammar() {
|
||||
// Test case for a more complex grammar
|
||||
const std::string grammar_str = R"""(root ::= expression
|
||||
expression ::= term ws (("+"|"-") ws term)*
|
||||
term ::= factor ws (("*"|"/") ws factor)*
|
||||
factor ::= number | variable | "(" expression ")" | function-call
|
||||
number ::= [0-9]+
|
||||
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
|
||||
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
|
||||
ws ::= [ \t\n\r]?)""";
|
||||
|
||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
|
||||
// Ensure we parsed correctly
|
||||
assert(!parsed_grammar.rules.empty());
|
||||
|
||||
// Ensure we have a root node
|
||||
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
|
||||
|
||||
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
|
||||
llama_grammar* grammar = llama_grammar_init(
|
||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
|
||||
// Save the original grammar stacks so that we can reset after every new string we want to test
|
||||
auto original_stacks = grammar->stacks;
|
||||
|
||||
// Test a few strings
|
||||
std::vector<std::string> test_strings_pass = {
|
||||
"42",
|
||||
"1*2*3*4*5",
|
||||
"x",
|
||||
"x+10",
|
||||
"x1+y2",
|
||||
"(a+b)*(c-d)",
|
||||
"func()",
|
||||
"func(x,y+2)",
|
||||
"a*(b+c)-d/e",
|
||||
"f(g(x),h(y,z))",
|
||||
"x + 10",
|
||||
"x1 + y2",
|
||||
"(a + b) * (c - d)",
|
||||
"func()",
|
||||
"func(x, y + 2)",
|
||||
"a * (b + c) - d / e",
|
||||
"f(g(x), h(y, z))",
|
||||
"123+456",
|
||||
"123*456*789-123/456+789*123",
|
||||
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
|
||||
};
|
||||
|
||||
std::vector<std::string> test_strings_fail = {
|
||||
"+",
|
||||
"/ 3x",
|
||||
"x + + y",
|
||||
"a * / b",
|
||||
"func(,)",
|
||||
"func(x y)",
|
||||
"(a + b",
|
||||
"x + y)",
|
||||
"a + b * (c - d",
|
||||
"42 +",
|
||||
"x +",
|
||||
"x + 10 +",
|
||||
"(a + b) * (c - d",
|
||||
"func(",
|
||||
"func(x, y + 2",
|
||||
"a * (b + c) - d /",
|
||||
"f(g(x), h(y, z)",
|
||||
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
|
||||
};
|
||||
|
||||
for (const auto & test_string : test_strings_pass) {
|
||||
auto decoded = decode_utf8(test_string, {});
|
||||
|
||||
const auto & code_points = decoded.first;
|
||||
|
||||
int pos = 0;
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
++pos;
|
||||
auto prev_stacks = grammar->stacks;
|
||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
||||
|
||||
// Expect that each code point will not cause the grammar to fail
|
||||
if (grammar->stacks.empty()) {
|
||||
fprintf(stdout, "Error at position %d\n", pos);
|
||||
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
|
||||
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
|
||||
}
|
||||
assert(!grammar->stacks.empty());
|
||||
}
|
||||
|
||||
bool completed_grammar = false;
|
||||
|
||||
for (const auto & stack : grammar->stacks) {
|
||||
if (stack.empty()) {
|
||||
completed_grammar = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert(completed_grammar);
|
||||
|
||||
// Reset the grammar stacks
|
||||
grammar->stacks = original_stacks;
|
||||
}
|
||||
|
||||
for (const auto & test_string : test_strings_fail) {
|
||||
auto decoded = decode_utf8(test_string, {});
|
||||
|
||||
const auto & code_points = decoded.first;
|
||||
bool parse_failed = false;
|
||||
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
auto prev_stacks = grammar->stacks;
|
||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
||||
if (grammar->stacks.empty()) {
|
||||
parse_failed = true;
|
||||
break;
|
||||
}
|
||||
assert(!grammar->stacks.empty());
|
||||
}
|
||||
|
||||
bool completed_grammar = false;
|
||||
|
||||
for (const auto & stack : grammar->stacks) {
|
||||
if (stack.empty()) {
|
||||
completed_grammar = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that the grammar is not completed, or that each string failed to match as-expected
|
||||
assert((!completed_grammar) || parse_failed);
|
||||
|
||||
// Reset the grammar stacks
|
||||
grammar->stacks = original_stacks;
|
||||
}
|
||||
|
||||
// Clean up allocated memory
|
||||
llama_grammar_free(grammar);
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_simple_grammar();
|
||||
test_complex_grammar();
|
||||
test_failure_missing_root();
|
||||
test_failure_missing_reference();
|
||||
// Add more test cases as needed
|
||||
return 0;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue