json: check parsing in test + fix value & string refs

This commit is contained in:
ochafik 2024-03-17 22:47:20 +00:00
parent 84e383c1d7
commit 3e1bf44e5e
7 changed files with 1582 additions and 1517 deletions

View file

@ -664,8 +664,6 @@ grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
json-schema-to-grammar.o: examples/server/json-schema-to-grammar.cpp examples/server/json-schema-to-grammar.h
$(CXX) $(CXXFLAGS) -c $< -o $@
# $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) -DLLAMA_BUILD_JSON_SCHEMA_CONVERTER=1
# $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
train.o: common/train.cpp common/train.h
$(CXX) $(CXXFLAGS) -c $< -o $@
@ -862,7 +860,7 @@ tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp examples/server/json-schema-to-grammar.cpp examples/server/json-schema-to-grammar.h
tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp ggml.o llama.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View file

@ -24,6 +24,7 @@ PRIMITIVE_RULES = {
)* "\"" space''',
'null': '"null" space',
}
OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value']
# TODO: support "uri", "email" string formats
DATE_RULES = {
@ -384,11 +385,10 @@ class SchemaConverter:
elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name)
elif schema_type == 'object' and len(schema) == 1 or schema_type is None and len(schema) == 0:
# This depends on all primitive types
for t, r in PRIMITIVE_RULES.items():
self._add_rule(t, r)
return 'object'
elif (schema_type == 'object' and len(schema) == 1) or (len(schema) == 0):
for n in OBJECT_RULE_NAMES:
self._add_rule(n, PRIMITIVE_RULES[n])
return self._add_rule(rule_name, 'object')
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
return self._add_rule(
@ -427,7 +427,10 @@ class SchemaConverter:
if additional_properties:
sub_name = f'{name}{"-" if name else ""}additional'
value_rule = self.visit(additional_properties, f'{sub_name}-value')
prop_kv_rule_names["*"] = self._add_rule(f'{sub_name}-kv', f'string ":" space {value_rule}')
prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
)
optional_props.append("*")
rule = '"{" space '

View file

@ -33,7 +33,7 @@ unordered_map<string, string> PRIMITIVE_RULES = {
" )* \"\\\"\" space"},
{"null", "\"null\" space"}
};
vector<string> OBJECT_RULE_NAMES = {"object", "array", "string", "number", "boolean", "null"};
vector<string> OBJECT_RULE_NAMES = {"object", "array", "string", "number", "boolean", "null", "value"};
unordered_map<string, string> DATE_RULES = {
{"date", "[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( [0-2] [0-9] | \"3\" [0-1] )"},
@ -416,7 +416,7 @@ private:
if (additional_properties.is_object()) {
string sub_name = name + (name.empty() ? "" : "-") + "additional";
string value_rule = visit(additional_properties, sub_name + "-value");
string kv_rule = _add_rule(sub_name + "-kv", "string \":\" space " + value_rule);
string kv_rule = _add_rule(sub_name + "-kv", _add_rule("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
prop_kv_rule_names["*"] = kv_rule;
optional_props.push_back("*");
}

File diff suppressed because it is too large Load diff

File diff suppressed because one or more lines are too long

View file

@ -15,7 +15,7 @@ const PRIMITIVE_RULES = {
)* "\\"" space`,
null: '"null" space',
};
const OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null'];
const OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value'];
// TODO: support "uri", "email" string formats
const DATE_RULES = {
@ -467,7 +467,9 @@ export class SchemaConverter {
if (typeof additionalProperties === 'object') {
const subName = `${name ?? ''}${name ? '-' : ''}additional`;
const valueRule = this.visit(additionalProperties, `${subName}-value`);
propKvRuleNames['*'] = this._addRule(`${subName}-kv`, `string ":" space ${valueRule}`);
propKvRuleNames['*'] = this._addRule(
`${subName}-kv`,
`${this._addRule('string', PRIMITIVE_RULES['string'])} ":" space ${valueRule}`);
optionalProps.push('*');
}

View file

@ -11,6 +11,8 @@
#include "../examples/server/json-schema-to-grammar.h"
#include "../examples/server/json-schema-to-grammar.cpp"
#include "grammar-parser.h"
using namespace std;
static std::string trim(const std::string & source) {
@ -30,7 +32,7 @@ struct TestCase {
string schema;
string expected;
void verify(const string& actual) const {
void verify(const string& actual, bool parse = false) const {
if (trim(actual) != trim(expected)) {
cerr << "#" << endl;
cerr << "# Test '" << name.c_str() << "' failed." << endl;
@ -40,6 +42,10 @@ struct TestCase {
cerr << "# ACTUAL:\n" << actual.c_str() << endl;
assert(false);
}
if (parse) {
auto state = grammar_parser::parse(actual.c_str());
assert(state.symbol_ids.find("root") != state.symbol_ids.end());
}
}
void verify_status(TestCaseStatus status) const {
if (status != expected_status) {
@ -108,6 +114,7 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
value ::= object | array | string | number | boolean
)"""
});
@ -453,6 +460,10 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space (additional-kvs )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
@ -499,6 +510,10 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
@ -523,6 +538,10 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
@ -606,6 +625,16 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
}
int main() {
test_all("C++", [](const TestCase& tc) {
try {
// We only try and parse the grammar output in the C++ test.
tc.verify(json_schema_to_grammar(nlohmann::json::parse(tc.schema)), /* parse= */ true);
tc.verify_status(SUCCESS);
} catch (const runtime_error& ex) {
cerr << "Error: " << ex.what() << endl;
tc.verify_status(FAILURE);
}
});
test_all("Python", [](const TestCase& tc) {
write("test-json-schema-input.tmp", tc.schema);
tc.verify_status(std::system(
@ -618,13 +647,4 @@ int main() {
"node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
tc.verify(read("test-grammar-output.tmp"));
});
test_all("C++", [](const TestCase& tc) {
try {
tc.verify(json_schema_to_grammar(nlohmann::json::parse(tc.schema)));
tc.verify_status(SUCCESS);
} catch (const runtime_error& ex) {
cerr << "Error: " << ex.what() << endl;
tc.verify_status(FAILURE);
}
});
}