Merging improved schema test methods added by @ochafik in #7797

This commit is contained in:
HanClinto 2024-06-12 10:42:37 -07:00 committed by Clint Herron
parent acd3c468af
commit d4a63b0538

View file

@ -13,6 +13,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
using json = nlohmann::ordered_json;
static llama_grammar* build_grammar(const std::string & grammar_str) { static llama_grammar* build_grammar(const std::string & grammar_str) {
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
@ -66,8 +68,8 @@ static bool match_string(const std::string & input, llama_grammar* grammar) {
return false; return false;
} }
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) { static void test(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str()); fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
fflush(stderr); fflush(stderr);
auto grammar = build_grammar(grammar_str); auto grammar = build_grammar(grammar_str);
@ -136,6 +138,12 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
// Clean up allocated memory // Clean up allocated memory
llama_grammar_free(grammar); llama_grammar_free(grammar);
} }
// Runs the shared `test` driver on a raw grammar string.
// The description handed to the driver is `test_desc` followed by
// ". Grammar: " and the grammar text itself; the grammar, passing strings and
// failing strings are forwarded unchanged.
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
    std::string label = test_desc;
    label += ". Grammar: ";
    label += grammar_str;

    test(label, grammar_str, passing_strings, failing_strings);
}
// Runs the shared `test` driver on a JSON schema supplied as text.
// The schema text is parsed with `json::parse`, converted with
// `json_schema_to_grammar`, and the resulting grammar is passed to the driver.
// The description handed to the driver is `test_desc` followed by
// ". Schema: " and the original schema text.
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
    const std::string label = test_desc + ". Schema: " + schema_str;
    const auto parsed_schema = json::parse(schema_str);

    test(label, json_schema_to_grammar(parsed_schema), passing_strings, failing_strings);
}
static void test_simple_grammar() { static void test_simple_grammar() {
// Test case for a simple grammar // Test case for a simple grammar
@ -491,14 +499,12 @@ static void test_json_schema() {
// but we convert each json schema to a grammar before parsing. // but we convert each json schema to a grammar before parsing.
// Otherwise, this test structure is the same. // Otherwise, this test structure is the same.
test_grammar( test_schema(
"empty schema (object)", "empty schema (object)",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{} {}
)""" )""",
)),
// Passing strings // Passing strings
{ {
"{}", "{}",
@ -510,11 +516,10 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"exotic formats (list)", "exotic formats (list)",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"items": [ "items": [
{ "format": "date" }, { "format": "date" },
@ -523,8 +528,7 @@ static void test_json_schema() {
{ "format": "date-time" } { "format": "date-time" }
] ]
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
// "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it? // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
@ -540,16 +544,14 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"string", "string",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "string" "type": "string"
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"\"foo\"", "\"foo\"",
@ -563,17 +565,15 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"string w/ min length 1", "string w/ min length 1",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "string", "type": "string",
"minLength": 1 "minLength": 1
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"\"foo\"", "\"foo\"",
@ -587,17 +587,15 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"string w/ min length 3", "string w/ min length 3",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "string", "type": "string",
"minLength": 3 "minLength": 3
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"\"foo\"", "\"foo\"",
@ -612,17 +610,15 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"string w/ max length", "string w/ max length",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "string", "type": "string",
"maxLength": 3 "maxLength": 3
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"\"foo\"", "\"foo\"",
@ -637,18 +633,16 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"string w/ min & max length", "string w/ min & max length",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "string", "type": "string",
"minLength": 1, "minLength": 1,
"maxLength": 4 "maxLength": 4
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"\"foo\"", "\"foo\"",
@ -664,16 +658,14 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"boolean", "boolean",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "boolean" "type": "boolean"
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"true", "true",
@ -688,16 +680,14 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"integer", "integer",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "integer" "type": "integer"
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"0", "0",
@ -713,16 +703,14 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"string const", "string const",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"const": "foo" "const": "foo"
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"\"foo\"", "\"foo\"",
@ -734,16 +722,14 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"non-string const", "non-string const",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"const": true "const": true
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"true", "true",
@ -756,16 +742,14 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"non-string const", "non-string const",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"enum": ["red", "amber", "green", null, 42, ["foo"]] "enum": ["red", "amber", "green", null, 42, ["foo"]]
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"\"red\"", "\"red\"",
@ -783,11 +767,10 @@ static void test_json_schema() {
); );
test_grammar( test_schema(
"min+max items", "min+max items",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"items": { "items": {
"type": ["number", "integer"] "type": ["number", "integer"]
@ -795,8 +778,7 @@ static void test_json_schema() {
"minItems": 3, "minItems": 3,
"maxItems": 5 "maxItems": 5
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"[1, 2, 3]", "[1, 2, 3]",
@ -812,11 +794,10 @@ static void test_json_schema() {
); );
// Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties) // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
test_grammar( test_schema(
"object properties", "object properties",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "object", "type": "object",
"properties": { "properties": {
@ -825,8 +806,7 @@ static void test_json_schema() {
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] } "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
} }
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""", R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
@ -854,11 +834,10 @@ static void test_json_schema() {
// Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties) // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
test_grammar( test_schema(
"object properties, additionalProperties: true", "object properties, additionalProperties: true",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "object", "type": "object",
"properties": { "properties": {
@ -868,8 +847,7 @@ static void test_json_schema() {
}, },
"additionalProperties": true "additionalProperties": true
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
// TODO: Following line should pass and doesn't // TODO: Following line should pass and doesn't
@ -897,11 +875,10 @@ static void test_json_schema() {
); );
// Additional properties: false // Additional properties: false
test_grammar( test_schema(
"required + optional props each in original order", "required + optional props each in original order",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"type": "object", "type": "object",
"properties": { "properties": {
@ -911,8 +888,7 @@ static void test_json_schema() {
}, },
"additionalProperties": false "additionalProperties": false
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
R"""({ "street_name": "Pennsylvania" })""", R"""({ "street_name": "Pennsylvania" })""",
@ -931,11 +907,10 @@ static void test_json_schema() {
} }
); );
test_grammar( test_schema(
"required + optional props each in original order", "required + optional props each in original order",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"properties": { "properties": {
"b": {"type": "string"}, "b": {"type": "string"},
@ -946,8 +921,7 @@ static void test_json_schema() {
"required": ["a", "b"], "required": ["a", "b"],
"additionalProperties": false "additionalProperties": false
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"{\"b\": \"foo\", \"a\": \"bar\"}", "{\"b\": \"foo\", \"a\": \"bar\"}",
@ -964,11 +938,10 @@ static void test_json_schema() {
); );
// NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties // NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
test_grammar( test_schema(
"required props", "required props",
// Grammar // Schema
json_schema_to_grammar(nlohmann::ordered_json::parse( R"""(
R"""(
{ {
"$schema": "https://json-schema.org/draft/2020-12/schema", "$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "https://example.com/product.schema.json", "$id": "https://example.com/product.schema.json",
@ -1016,8 +989,7 @@ static void test_json_schema() {
}, },
"required": [ "productId", "productName", "price" ] "required": [ "productId", "productName", "price" ]
} }
)""" )""",
)),
// Passing strings // Passing strings
{ {
"{\"productId\": 1, \"productName\": \"A green door\", \"price\": 12.50}", "{\"productId\": 1, \"productName\": \"A green door\", \"price\": 12.50}",
@ -1040,8 +1012,6 @@ static void test_json_schema() {
} }
); );
} }
int main() { int main() {