json: integration test for schemas

This commit is contained in:
ochafik 2024-05-19 00:35:01 +01:00
parent f8db47814b
commit 5a86c6f0e2
3 changed files with 61 additions and 4 deletions

View file

@ -987,7 +987,7 @@ tests/test-grammar-parser: tests/test-grammar-parser.cpp ggml.o llama.o grammar-
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o llama.o grammar-parser.o $(OBJS) tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o llama.o grammar-parser.o json-schema-to-grammar.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View file

@ -337,7 +337,7 @@ static void _generate_min_max_int(int min_value, int max_value, std::stringstrea
return; return;
} }
assert(false); throw std::runtime_error("At least one of min_value or max_value must be set");
} }
class SchemaConverter { class SchemaConverter {

View file

@ -11,6 +11,9 @@
#include <cassert> #include <cassert>
#include <string> #include <string>
#include <vector> #include <vector>
#include <json-schema-to-grammar.h>
using json = nlohmann::ordered_json;
static llama_grammar* build_grammar(const std::string & grammar_str) { static llama_grammar* build_grammar(const std::string & grammar_str) {
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
@ -65,8 +68,8 @@ static bool match_string(const std::string & input, llama_grammar* grammar) {
return false; return false;
} }
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) { static void test(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str()); fprintf(stderr, "⚫ Testing %s\n", test_desc.c_str(), grammar_str.c_str());
fflush(stderr); fflush(stderr);
auto grammar = build_grammar(grammar_str); auto grammar = build_grammar(grammar_str);
@ -118,8 +121,62 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
// Clean up allocated memory // Clean up allocated memory
llama_grammar_free(grammar); llama_grammar_free(grammar);
} }
// Runs the shared `test` driver on a grammar given directly as text.
//
// The grammar text is appended to the description so the label passed to
// `test` identifies exactly which grammar was exercised.
//
// test_desc       - short human-readable name of the test case.
// grammar_str     - grammar source text, forwarded unchanged to `test`.
// passing_strings - inputs forwarded to `test` (presumably expected to match the grammar).
// failing_strings - inputs forwarded to `test` (presumably expected to be rejected).
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
    const std::string labelled_desc = test_desc + ". Grammar: " + grammar_str;
    test(labelled_desc, grammar_str, passing_strings, failing_strings);
}
// Runs the shared `test` driver on a grammar derived from a JSON schema.
//
// The schema text is parsed with `json::parse`, converted with
// `json_schema_to_grammar`, and the resulting grammar is handed to `test`.
// The schema text is appended to the description so the label passed to
// `test` shows which schema was exercised.
//
// test_desc       - short human-readable name of the test case.
// schema_str      - JSON schema source text.
// passing_strings - inputs forwarded to `test` (presumably expected to match the generated grammar).
// failing_strings - inputs forwarded to `test` (presumably expected to be rejected).
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
    const std::string labelled_desc = test_desc + ". Schema: " + schema_str;
    const auto parsed_schema = json::parse(schema_str);
    const auto generated_grammar = json_schema_to_grammar(parsed_schema);
    test(labelled_desc, generated_grammar, passing_strings, failing_strings);
}
static void test_simple_grammar() { static void test_simple_grammar() {
test_schema(
"simple min 0",
R"""({
"type": "integer",
"minimum": 0
})""",
// Passing strings
{
"0",
"10",
"10000",
},
// Failing strings
{
"-1",
"-10",
"-10000",
"-100000000000000000000000000000000",
"100000000000000000000000000000000",
}
);
test_schema(
"simple min -123 max 42",
R"""({
"type": "integer",
"minimum": -123,
"maximum": 42
})""",
// Passing strings
{
"-123",
"-122",
"-11",
"-1",
"0",
"1",
"10",
"39",
"42",
},
// Failing strings
{
"-124",
"43",
"123",
}
);
// Test case for a simple grammar // Test case for a simple grammar
test_grammar( test_grammar(
"simple grammar", "simple grammar",