json: check parsing in test + fix value & string refs

This commit is contained in:
ochafik 2024-03-17 22:47:20 +00:00
parent 84e383c1d7
commit 3e1bf44e5e
7 changed files with 1582 additions and 1517 deletions

View file

@ -664,8 +664,6 @@ grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
json-schema-to-grammar.o: examples/server/json-schema-to-grammar.cpp examples/server/json-schema-to-grammar.h json-schema-to-grammar.o: examples/server/json-schema-to-grammar.cpp examples/server/json-schema-to-grammar.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
# $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) -DLLAMA_BUILD_JSON_SCHEMA_CONVERTER=1
# $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
train.o: common/train.cpp common/train.h train.o: common/train.cpp common/train.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
@ -862,7 +860,7 @@ tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp examples/server/json-schema-to-grammar.cpp examples/server/json-schema-to-grammar.h tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp ggml.o llama.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View file

@ -24,6 +24,7 @@ PRIMITIVE_RULES = {
)* "\"" space''', )* "\"" space''',
'null': '"null" space', 'null': '"null" space',
} }
OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value']
# TODO: support "uri", "email" string formats # TODO: support "uri", "email" string formats
DATE_RULES = { DATE_RULES = {
@ -384,11 +385,10 @@ class SchemaConverter:
elif schema_type in (None, 'string') and 'pattern' in schema: elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name) return self._visit_pattern(schema['pattern'], rule_name)
elif schema_type == 'object' and len(schema) == 1 or schema_type is None and len(schema) == 0: elif (schema_type == 'object' and len(schema) == 1) or (len(schema) == 0):
# This depends on all primitive types for n in OBJECT_RULE_NAMES:
for t, r in PRIMITIVE_RULES.items(): self._add_rule(n, PRIMITIVE_RULES[n])
self._add_rule(t, r) return self._add_rule(rule_name, 'object')
return 'object'
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
return self._add_rule( return self._add_rule(
@ -427,7 +427,10 @@ class SchemaConverter:
if additional_properties: if additional_properties:
sub_name = f'{name}{"-" if name else ""}additional' sub_name = f'{name}{"-" if name else ""}additional'
value_rule = self.visit(additional_properties, f'{sub_name}-value') value_rule = self.visit(additional_properties, f'{sub_name}-value')
prop_kv_rule_names["*"] = self._add_rule(f'{sub_name}-kv', f'string ":" space {value_rule}') prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
)
optional_props.append("*") optional_props.append("*")
rule = '"{" space ' rule = '"{" space '

View file

@ -33,7 +33,7 @@ unordered_map<string, string> PRIMITIVE_RULES = {
" )* \"\\\"\" space"}, " )* \"\\\"\" space"},
{"null", "\"null\" space"} {"null", "\"null\" space"}
}; };
vector<string> OBJECT_RULE_NAMES = {"object", "array", "string", "number", "boolean", "null"}; vector<string> OBJECT_RULE_NAMES = {"object", "array", "string", "number", "boolean", "null", "value"};
unordered_map<string, string> DATE_RULES = { unordered_map<string, string> DATE_RULES = {
{"date", "[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( [0-2] [0-9] | \"3\" [0-1] )"}, {"date", "[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( [0-2] [0-9] | \"3\" [0-1] )"},
@ -416,7 +416,7 @@ private:
if (additional_properties.is_object()) { if (additional_properties.is_object()) {
string sub_name = name + (name.empty() ? "" : "-") + "additional"; string sub_name = name + (name.empty() ? "" : "-") + "additional";
string value_rule = visit(additional_properties, sub_name + "-value"); string value_rule = visit(additional_properties, sub_name + "-value");
string kv_rule = _add_rule(sub_name + "-kv", "string \":\" space " + value_rule); string kv_rule = _add_rule(sub_name + "-kv", _add_rule("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
prop_kv_rule_names["*"] = kv_rule; prop_kv_rule_names["*"] = kv_rule;
optional_props.push_back("*"); optional_props.push_back("*");
} }

File diff suppressed because it is too large Load diff

File diff suppressed because one or more lines are too long

View file

@ -15,7 +15,7 @@ const PRIMITIVE_RULES = {
)* "\\"" space`, )* "\\"" space`,
null: '"null" space', null: '"null" space',
}; };
const OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null']; const OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value'];
// TODO: support "uri", "email" string formats // TODO: support "uri", "email" string formats
const DATE_RULES = { const DATE_RULES = {
@ -467,7 +467,9 @@ export class SchemaConverter {
if (typeof additionalProperties === 'object') { if (typeof additionalProperties === 'object') {
const subName = `${name ?? ''}${name ? '-' : ''}additional`; const subName = `${name ?? ''}${name ? '-' : ''}additional`;
const valueRule = this.visit(additionalProperties, `${subName}-value`); const valueRule = this.visit(additionalProperties, `${subName}-value`);
propKvRuleNames['*'] = this._addRule(`${subName}-kv`, `string ":" space ${valueRule}`); propKvRuleNames['*'] = this._addRule(
`${subName}-kv`,
`${this._addRule('string', PRIMITIVE_RULES['string'])} ":" space ${valueRule}`);
optionalProps.push('*'); optionalProps.push('*');
} }

View file

@ -11,6 +11,8 @@
#include "../examples/server/json-schema-to-grammar.h" #include "../examples/server/json-schema-to-grammar.h"
#include "../examples/server/json-schema-to-grammar.cpp" #include "../examples/server/json-schema-to-grammar.cpp"
#include "grammar-parser.h"
using namespace std; using namespace std;
static std::string trim(const std::string & source) { static std::string trim(const std::string & source) {
@ -30,7 +32,7 @@ struct TestCase {
string schema; string schema;
string expected; string expected;
void verify(const string& actual) const { void verify(const string& actual, bool parse = false) const {
if (trim(actual) != trim(expected)) { if (trim(actual) != trim(expected)) {
cerr << "#" << endl; cerr << "#" << endl;
cerr << "# Test '" << name.c_str() << "' failed." << endl; cerr << "# Test '" << name.c_str() << "' failed." << endl;
@ -40,6 +42,10 @@ struct TestCase {
cerr << "# ACTUAL:\n" << actual.c_str() << endl; cerr << "# ACTUAL:\n" << actual.c_str() << endl;
assert(false); assert(false);
} }
if (parse) {
auto state = grammar_parser::parse(actual.c_str());
assert(state.symbol_ids.find("root") != state.symbol_ids.end());
}
} }
void verify_status(TestCaseStatus status) const { void verify_status(TestCaseStatus status) const {
if (status != expected_status) { if (status != expected_status) {
@ -108,6 +114,7 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
[^"\\] | [^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space )* "\"" space
value ::= object | array | string | number | boolean
)""" )"""
}); });
@ -453,6 +460,10 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space (additional-kvs )? "}" space root ::= "{" space (additional-kvs )? "}" space
space ::= " "? space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)""" )"""
}); });
@ -499,6 +510,10 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space
space ::= " "? space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)""" )"""
}); });
@ -523,6 +538,10 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space
space ::= " "? space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)""" )"""
}); });
@ -606,6 +625,16 @@ static void test_all(const string& lang, std::function<void(const TestCase&)> ru
} }
int main() { int main() {
test_all("C++", [](const TestCase& tc) {
try {
// We only try and parse the grammar output in the C++ test.
tc.verify(json_schema_to_grammar(nlohmann::json::parse(tc.schema)), /* parse= */ true);
tc.verify_status(SUCCESS);
} catch (const runtime_error& ex) {
cerr << "Error: " << ex.what() << endl;
tc.verify_status(FAILURE);
}
});
test_all("Python", [](const TestCase& tc) { test_all("Python", [](const TestCase& tc) {
write("test-json-schema-input.tmp", tc.schema); write("test-json-schema-input.tmp", tc.schema);
tc.verify_status(std::system( tc.verify_status(std::system(
@ -618,13 +647,4 @@ int main() {
"node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE); "node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
tc.verify(read("test-grammar-output.tmp")); tc.verify(read("test-grammar-output.tmp"));
}); });
test_all("C++", [](const TestCase& tc) {
try {
tc.verify(json_schema_to_grammar(nlohmann::json::parse(tc.schema)));
tc.verify_status(SUCCESS);
} catch (const runtime_error& ex) {
cerr << "Error: " << ex.what() << endl;
tc.verify_status(FAILURE);
}
});
} }