Merging improved schema test methods added by @ochafik in #7797

This commit is contained in:
HanClinto 2024-06-12 10:42:37 -07:00 committed by Clint Herron
parent acd3c468af
commit d4a63b0538

View file

@ -13,6 +13,8 @@
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
static llama_grammar* build_grammar(const std::string & grammar_str) {
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
@ -66,8 +68,8 @@ static bool match_string(const std::string & input, llama_grammar* grammar) {
return false;
}
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
static void test(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
fflush(stderr);
auto grammar = build_grammar(grammar_str);
@ -136,6 +138,12 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
// Clean up allocated memory
llama_grammar_free(grammar);
}
// Runs the shared `test` driver on a grammar supplied directly as text.
// The grammar text is appended to the description (". Grammar: <text>") and
// is also passed through unchanged as the grammar to test against.
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
    const std::string labelled_desc = test_desc + ". Grammar: " + grammar_str;
    test(labelled_desc, grammar_str, passing_strings, failing_strings);
}
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings);
}
static void test_simple_grammar() {
// Test case for a simple grammar
@ -491,14 +499,12 @@ static void test_json_schema() {
// but we convert each json schema to a grammar before parsing.
// Otherwise, this test structure is the same.
test_grammar(
test_schema(
"empty schema (object)",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{}
)"""
)),
)""",
// Passing strings
{
"{}",
@ -510,10 +516,9 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"exotic formats (list)",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"items": [
@ -523,8 +528,7 @@ static void test_json_schema() {
{ "format": "date-time" }
]
}
)"""
)),
)""",
// Passing strings
{
// "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
@ -540,16 +544,14 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"string",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "string"
}
)"""
)),
)""",
// Passing strings
{
"\"foo\"",
@ -563,17 +565,15 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"string w/ min length 1",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "string",
"minLength": 1
}
)"""
)),
)""",
// Passing strings
{
"\"foo\"",
@ -587,17 +587,15 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"string w/ min length 3",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "string",
"minLength": 3
}
)"""
)),
)""",
// Passing strings
{
"\"foo\"",
@ -612,17 +610,15 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"string w/ max length",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "string",
"maxLength": 3
}
)"""
)),
)""",
// Passing strings
{
"\"foo\"",
@ -637,18 +633,16 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"string w/ min & max length",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "string",
"minLength": 1,
"maxLength": 4
}
)"""
)),
)""",
// Passing strings
{
"\"foo\"",
@ -664,16 +658,14 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"boolean",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "boolean"
}
)"""
)),
)""",
// Passing strings
{
"true",
@ -688,16 +680,14 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"integer",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "integer"
}
)"""
)),
)""",
// Passing strings
{
"0",
@ -713,16 +703,14 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"string const",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"const": "foo"
}
)"""
)),
)""",
// Passing strings
{
"\"foo\"",
@ -734,16 +722,14 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"non-string const",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"const": true
}
)"""
)),
)""",
// Passing strings
{
"true",
@ -756,16 +742,14 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"non-string const",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"enum": ["red", "amber", "green", null, 42, ["foo"]]
}
)"""
)),
)""",
// Passing strings
{
"\"red\"",
@ -783,10 +767,9 @@ static void test_json_schema() {
);
test_grammar(
test_schema(
"min+max items",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"items": {
@ -795,8 +778,7 @@ static void test_json_schema() {
"minItems": 3,
"maxItems": 5
}
)"""
)),
)""",
// Passing strings
{
"[1, 2, 3]",
@ -812,10 +794,9 @@ static void test_json_schema() {
);
// Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
test_grammar(
test_schema(
"object properties",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "object",
@ -825,8 +806,7 @@ static void test_json_schema() {
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
}
}
)"""
)),
)""",
// Passing strings
{
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
@ -854,10 +834,9 @@ static void test_json_schema() {
// Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
test_grammar(
test_schema(
"object properties, additionalProperties: true",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "object",
@ -868,8 +847,7 @@ static void test_json_schema() {
},
"additionalProperties": true
}
)"""
)),
)""",
// Passing strings
{
// TODO: Following line should pass and doesn't
@ -897,10 +875,9 @@ static void test_json_schema() {
);
// Additional properties: false
test_grammar(
test_schema(
"required + optional props each in original order",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"type": "object",
@ -911,8 +888,7 @@ static void test_json_schema() {
},
"additionalProperties": false
}
)"""
)),
)""",
// Passing strings
{
R"""({ "street_name": "Pennsylvania" })""",
@ -931,10 +907,9 @@ static void test_json_schema() {
}
);
test_grammar(
test_schema(
"required + optional props each in original order",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"properties": {
@ -946,8 +921,7 @@ static void test_json_schema() {
"required": ["a", "b"],
"additionalProperties": false
}
)"""
)),
)""",
// Passing strings
{
"{\"b\": \"foo\", \"a\": \"bar\"}",
@ -964,10 +938,9 @@ static void test_json_schema() {
);
// NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
test_grammar(
test_schema(
"required props",
// Grammar
json_schema_to_grammar(nlohmann::ordered_json::parse(
// Schema
R"""(
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
@ -1016,8 +989,7 @@ static void test_json_schema() {
},
"required": [ "productId", "productName", "price" ]
}
)"""
)),
)""",
// Passing strings
{
"{\"productId\": 1, \"productName\": \"A green door\", \"price\": 12.50}",
@ -1040,8 +1012,6 @@ static void test_json_schema() {
}
);
}
int main() {