diff --git a/Makefile b/Makefile index 895c62f84..e81dc8240 100644 --- a/Makefile +++ b/Makefile @@ -987,7 +987,7 @@ tests/test-grammar-parser: tests/test-grammar-parser.cpp ggml.o llama.o grammar- $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o llama.o grammar-parser.o $(OBJS) +tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o llama.o grammar-parser.o json-schema-to-grammar.o $(OBJS) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 8f5e35d49..42faa8adb 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -337,7 +337,7 @@ static void _generate_min_max_int(int min_value, int max_value, std::stringstrea return; } - assert(false); + throw std::runtime_error("At least one of min_value or max_value must be set"); } class SchemaConverter { diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 9bdab05af..ca9b9c110 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -11,6 +11,9 @@ #include #include #include +#include + +using json = nlohmann::ordered_json; static llama_grammar* build_grammar(const std::string & grammar_str) { auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); @@ -65,8 +68,8 @@ static bool match_string(const std::string & input, llama_grammar* grammar) { return false; } -static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector & passing_strings, const std::vector & failing_strings) { - fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str()); +static void test(const std::string & test_desc, const std::string & grammar_str, const std::vector & passing_strings, const std::vector & failing_strings) { + fprintf(stderr, "⚫ Testing %s\n", test_desc.c_str(), grammar_str.c_str()); fflush(stderr); auto grammar = build_grammar(grammar_str); @@ -118,8 +121,62 @@ static void test_grammar(const std::string & test_desc, const std::string & gram // Clean up allocated memory llama_grammar_free(grammar); } +static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector & passing_strings, const std::vector & failing_strings) { + test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings); +} +static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector & passing_strings, const std::vector & failing_strings) { + test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings); +} static void test_simple_grammar() { + test_schema( + "simple min 0", + R"""({ + "type": "integer", + "minimum": 0 + })""", + // Passing strings + { + "0", + "10", + "10000", + }, + // Failing strings + { + "-1", + "-10", + "-10000", + "-100000000000000000000000000000000", + "100000000000000000000000000000000", + } + ); + test_schema( + "simple min -123 max 42", + R"""({ + "type": "integer", + "minimum": -123, + "maximum": 42 + })""", + // Passing strings + { + "-123", + "-122", + "-11", + "-1", + "0", + "1", + "10", + "39", + "42", + }, + // Failing strings + { + "-124", + "43", + "123", + } + ); + // Test case for a simple grammar test_grammar( "simple grammar",