From d63c953185b4e60f4487ed63f07328631a96dbb3 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 28 Jun 2024 00:55:35 +0100 Subject: [PATCH] `json`: fix nested `$ref`s & allow mix of properties & anyOf (https://github.com/ggerganov/llama.cpp/issues/8073) --- common/json-schema-to-grammar.cpp | 260 +++++++++--------- examples/json_schema_to_grammar.py | 174 ++++++------ examples/server/chat.mjs | 3 +- examples/server/public/index-new.html | 3 +- examples/server/public/index.html | 3 +- .../server/public/json-schema-to-grammar.mjs | 211 +++++++------- examples/server/themes/buttons-top/index.html | 3 +- examples/server/themes/wild/index.html | 3 +- grammars/README.md | 3 +- tests/run-json-schema-to-grammar.mjs | 3 +- tests/test-grammar-integration.cpp | 42 +++ tests/test-json-schema-to-grammar.cpp | 93 ++++--- 12 files changed, 427 insertions(+), 374 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 2f233e2e7..3e2a9772b 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -8,6 +8,7 @@ #include #include #include +#include using json = nlohmann::ordered_json; @@ -392,10 +393,10 @@ private: std::function _fetch_json; bool _dotall; std::map _rules; - std::unordered_map _refs; - std::unordered_set _refs_being_resolved; std::vector _errors; std::vector _warnings; + std::unordered_map _external_refs; + std::vector _ref_context; std::string _add_rule(const std::string & name, const std::string & rule) { std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-"); @@ -683,17 +684,6 @@ private: return out.str(); } - std::string _resolve_ref(const std::string & ref) { - std::string ref_name = ref.substr(ref.find_last_of('/') + 1); - if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { - _refs_being_resolved.insert(ref); - json resolved = _refs[ref]; - ref_name = visit(resolved, ref_name); - _refs_being_resolved.erase(ref); - } - return ref_name; - } - std::string _build_object_rule( const std::vector> & properties, const std::unordered_set & required, @@ -815,78 +805,79 @@ public: _rules["space"] = SPACE_RULE; } - void resolve_refs(json & schema, const std::string & url) { - /* - * Resolves all $ref fields in the given schema, fetching any remote schemas, - * replacing each $ref with absolute reference URL and populates _refs with the - * respective referenced (sub)schema dictionaries. - */ - std::function visit_refs = [&](json & n) { - if (n.is_array()) { - for (auto & x : n) { - visit_refs(x); - } - } else if (n.is_object()) { - if (n.contains("$ref")) { - std::string ref = n["$ref"]; - if (_refs.find(ref) == _refs.end()) { - json target; - if (ref.find("https://") == 0) { - std::string base_url = ref.substr(0, ref.find('#')); - auto it = _refs.find(base_url); - if (it != _refs.end()) { - target = it->second; - } else { - // Fetch the referenced schema and resolve its refs - auto referenced = _fetch_json(ref); - resolve_refs(referenced, base_url); - _refs[base_url] = referenced; - } - if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) { - return; - } - } else if (ref.find("#/") == 0) { - target = schema; - n["$ref"] = url + ref; - ref = url + ref; - } else { - _errors.push_back("Unsupported ref: " + ref); - return; - } - std::string pointer = ref.substr(ref.find('#') + 1); - std::vector tokens = split(pointer, "/"); - for (size_t i = 1; i < tokens.size(); ++i) { - std::string sel = tokens[i]; - if (target.is_null() || !target.contains(sel)) { - _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); - return; - } - target = target[sel]; - } - _refs[ref] = target; - } - } else { - for (auto & kv : n.items()) { - visit_refs(kv.value()); - } - } - } - }; - - visit_refs(schema); - } - std::string _generate_constant_rule(const json & value) { return format_literal(value.dump()); } + struct ResolvedRef { + json target; + std::string name; + bool is_local; + }; + + ResolvedRef _resolve_ref(const std::string & ref) { + auto parts = split(ref, "#"); + if (parts.size() != 2) { + _errors.push_back("Unsupported ref: " + ref); + return {json(), "", false}; + } + const auto & url = parts[0]; + json target; + bool is_local = url.empty(); + if (is_local) { + if (_ref_context.empty()) { + _errors.push_back("Error resolving ref " + ref + ": no context"); + return {json(), "", false}; + } + target = _ref_context.back(); + } else { + auto it = _external_refs.find(url); + if (it != _external_refs.end()) { + target = it->second; + } else { + // Fetch the referenced schema and resolve its refs + target = _fetch_json(url); + _external_refs[url] = target; + } + } + auto tokens = split(parts[1], "/"); + for (size_t i = 1; i < tokens.size(); ++i) { + const auto & sel = tokens[i]; + if (target.is_null() || !target.contains(sel)) { + _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); + return {json(), "", false}; + } + target = target[sel]; + } + return {target, tokens.empty() ? "" : tokens[tokens.size() - 1], is_local}; + } + std::string visit(const json & schema, const std::string & name) { json schema_type = schema.contains("type") ? schema["type"] : json(); std::string schema_format = schema.contains("format") ? schema["format"].get() : ""; std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name; - if (schema.contains("$ref")) { - return _add_rule(rule_name, _resolve_ref(schema["$ref"])); + if (_ref_context.empty()) { + _ref_context.push_back(schema); + auto ret = visit(schema, name); + _ref_context.pop_back(); + return ret; + } + + if (schema.contains("$ref") && schema["$ref"].is_string()) { + const auto & ref = schema["$ref"].get(); + auto resolved = _resolve_ref(ref); + if (resolved.target.is_null()) { + return ""; + } + if (!resolved.is_local) { + _ref_context.push_back(resolved.target); + } + auto ret = visit(resolved.target, (name.empty() || resolved.name.empty()) ? name : resolved.name); + if (!resolved.is_local) { + _ref_context.pop_back(); + } + return ret; } else if (schema.contains("oneOf") || schema.contains("anyOf")) { std::vector alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get>() : schema["anyOf"].get>(); return _add_rule(rule_name, _generate_union_rule(name, alt_schemas)); @@ -906,55 +897,6 @@ public: enum_values.push_back(_generate_constant_rule(v)); } return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); - } else if ((schema_type.is_null() || schema_type == "object") - && (schema.contains("properties") || - (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { - std::unordered_set required; - if (schema.contains("required") && schema["required"].is_array()) { - for (const auto & item : schema["required"]) { - if (item.is_string()) { - required.insert(item.get()); - } - } - } - std::vector> properties; - if (schema.contains("properties")) { - for (const auto & prop : schema["properties"].items()) { - properties.emplace_back(prop.key(), prop.value()); - } - } - return _add_rule(rule_name, - _build_object_rule( - properties, required, name, - schema.contains("additionalProperties") ? schema["additionalProperties"] : json())); - } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) { - std::unordered_set required; - std::vector> properties; - std::string hybrid_name = name; - std::function add_component = [&](const json & comp_schema, bool is_required) { - if (comp_schema.contains("$ref")) { - add_component(_refs[comp_schema["$ref"]], is_required); - } else if (comp_schema.contains("properties")) { - for (const auto & prop : comp_schema["properties"].items()) { - properties.emplace_back(prop.key(), prop.value()); - if (is_required) { - required.insert(prop.key()); - } - } - } else { - // todo warning - } - }; - for (auto & t : schema["allOf"]) { - if (t.contains("anyOf")) { - for (auto & tt : t["anyOf"]) { - add_component(tt, false); - } - } else { - add_component(t, true); - } - } - return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; if (items.is_array()) { @@ -1005,8 +947,71 @@ public: _build_min_max_int(min_value, max_value, out); out << ") space"; return _add_rule(rule_name, out.str()); - } else if (schema.empty() || schema_type == "object") { - return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); + } else if ((schema_type.is_null() || schema_type == "object")) { + std::unordered_set required; + std::vector> properties; + auto is_explicit_object = schema_type == "object"; + json additional_properties; + if (schema.contains("additionalProperties")) { + is_explicit_object = true; + additional_properties = schema["additionalProperties"]; + } + if (schema.contains("properties") && schema["properties"].is_object()) { + is_explicit_object = true; + for (const auto & prop : schema["properties"].items()) { + if (prop.value().is_object()) { + properties.emplace_back(prop.key(), prop.value()); + } + } + } + if (schema.contains("required") && schema["required"].is_array()) { + for (const auto & item : schema["required"]) { + if (item.is_string()) { + required.insert(item.get()); + } + } + } + if (schema.contains("allOf") && schema["allOf"].is_array()) { + std::function add_component = [&](const json & comp_schema, bool is_required) { + if (comp_schema.contains("$ref") && comp_schema["$ref"].is_string()) { + auto resolved = _resolve_ref(comp_schema["$ref"].get()); + add_component(resolved.target, is_required); + } else if (comp_schema.contains("properties")) { + for (const auto & prop : comp_schema["properties"].items()) { + properties.emplace_back(prop.key(), prop.value()); + if (is_required) { + required.insert(prop.key()); + } + } + if (comp_schema.contains("additionalProperties")) { + if (additional_properties.is_null()) { + additional_properties = comp_schema["additionalProperties"]; + } else if (additional_properties != comp_schema["additionalProperties"]) { + _warnings.push_back("Inconsistent additionalProperties in allOf"); + } + } + } else { + // todo warning + } + }; + for (auto & t : schema["allOf"]) { + if (t.contains("anyOf")) { + for (auto & tt : t["anyOf"]) { + add_component(tt, false); + } + } else { + add_component(t, true); + } + } + } + if (properties.empty() && (additional_properties == true || additional_properties.is_null())) { + return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); + } + auto default_additional_properties = is_explicit_object ? json() : json(false); + return _add_rule(rule_name, + _build_object_rule( + properties, required, name, + additional_properties.is_null() ? default_additional_properties : additional_properties)); } else { if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get()) == PRIMITIVE_RULES.end()) { _errors.push_back("Unrecognized schema: " + schema.dump()); @@ -1038,7 +1043,6 @@ public: std::string json_schema_to_grammar(const json & schema) { SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); auto copy = schema; - converter.resolve_refs(copy, "input"); converter.visit(copy, ""); converter.check_errors(); return converter.format_grammar(); diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 92f6e3d47..02c91d650 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -243,8 +243,8 @@ class SchemaConverter: self._rules = { 'space': SPACE_RULE, } - self._refs = {} - self._refs_being_resolved = set() + self._external_refs = {} + self._ref_context = [] def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( @@ -332,51 +332,6 @@ class SchemaConverter: self._rules[key] = rule return key - def resolve_refs(self, schema: dict, url: str): - ''' - Resolves all $ref fields in the given schema, fetching any remote schemas, - replacing $ref with absolute reference URL and populating self._refs with the - respective referenced (sub)schema dictionaries. - ''' - def visit(n: dict): - if isinstance(n, list): - return [visit(x) for x in n] - elif isinstance(n, dict): - ref = n.get('$ref') - if ref is not None and ref not in self._refs: - if ref.startswith('https://'): - assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' - import requests - - frag_split = ref.split('#') - base_url = frag_split[0] - - target = self._refs.get(base_url) - if target is None: - target = self.resolve_refs(requests.get(ref).json(), base_url) - self._refs[base_url] = target - - if len(frag_split) == 1 or frag_split[-1] == '': - return target - elif ref.startswith('#/'): - target = schema - ref = f'{url}{ref}' - n['$ref'] = ref - else: - raise ValueError(f'Unsupported ref {ref}') - - for sel in ref.split('#')[-1].split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' - target = target[sel] - - self._refs[ref] = target - else: - for v in n.values(): - visit(v) - - return n - return visit(schema) - def _generate_union_rule(self, name, alt_schemas): return ' | '.join(( self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') @@ -541,25 +496,59 @@ class SchemaConverter: else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") - def _resolve_ref(self, ref): - ref_name = ref.split('/')[-1] - if ref_name not in self._rules and ref not in self._refs_being_resolved: - self._refs_being_resolved.add(ref) - resolved = self._refs[ref] - ref_name = self.visit(resolved, ref_name) - self._refs_being_resolved.remove(ref) - return ref_name - def _generate_constant_rule(self, value): return self._format_literal(json.dumps(value)) + class ResolvedRef: + def __init__(self, target: Any, name: str, is_local: bool): + self.target = target + self.name = name + self.is_local = is_local + + def _resolve_ref(self, ref: str): + parts = ref.split('#') + assert len(parts) == 2, f'Unsupported ref: {ref}' + url = parts[0] + target = None + is_local = not url + if is_local: + assert self._ref_context, f'Error resolving ref {ref}: no context' + target = self._ref_context[-1] + else: + target = self._external_refs.get(url) + if target is None: + # Fetch the referenced schema and resolve its refs + target = self._fetch_json(url) + self._external_refs[url] = target + + tokens = parts[1].split('/') + for sel in tokens[1:]: + assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] + + return self.ResolvedRef(target, tokens[-1] if tokens else '', is_local) + def visit(self, schema, name): schema_type = schema.get('type') schema_format = schema.get('format') rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' + if not self._ref_context: + self._ref_context.append(schema) + try: + return self.visit(schema, name) + finally: + self._ref_context.pop() + if (ref := schema.get('$ref')) is not None: - return self._add_rule(rule_name, self._resolve_ref(ref)) + resolved = self._resolve_ref(ref) + if not resolved.is_local: + self._ref_context.append(resolved.target) + try: + return self.visit(resolved.target, name if name == '' or resolved.name == '' else resolved.name) + finally: + if not resolved.is_local: + self._ref_context.pop() elif 'oneOf' in schema or 'anyOf' in schema: return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) @@ -574,36 +563,6 @@ class SchemaConverter: rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space' return self._add_rule(rule_name, rule) - elif schema_type in (None, 'object') and \ - ('properties' in schema or \ - ('additionalProperties' in schema and schema['additionalProperties'] is not True)): - required = set(schema.get('required', [])) - properties = list(schema.get('properties', {}).items()) - return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) - - elif schema_type in (None, 'object') and 'allOf' in schema: - required = set() - properties = [] - hybrid_name = name - def add_component(comp_schema, is_required): - if (ref := comp_schema.get('$ref')) is not None: - comp_schema = self._refs[ref] - - if 'properties' in comp_schema: - for prop_name, prop_schema in comp_schema['properties'].items(): - properties.append((prop_name, prop_schema)) - if is_required: - required.add(prop_name) - - for t in schema['allOf']: - if 'anyOf' in t: - for tt in t['anyOf']: - add_component(tt, is_required=False) - else: - add_component(t, is_required=True) - - return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[])) - elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): items = schema.get('items') or schema['prefixItems'] if isinstance(items, list): @@ -658,8 +617,44 @@ class SchemaConverter: out.append(") space") return self._add_rule(rule_name, ''.join(out)) - elif (schema_type == 'object') or (len(schema) == 0): - return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) + elif (schema_type == 'object') or (schema_type is None): + required = set(schema.get('required', [])) + properties = list(schema.get('properties', {}).items()) + is_explicit_object = schema_type == 'object' or 'properties' in schema or 'additionalProperties' in schema + additional_properties = schema.get('additionalProperties') + + def add_component(comp_schema, is_required): + if (ref := comp_schema.get('$ref')) is not None: + resolved = self._resolve_ref(ref) + comp_schema = resolved.target + + if 'properties' in comp_schema: + for prop_name, prop_schema in comp_schema['properties'].items(): + properties.append((prop_name, prop_schema)) + if is_required: + required.add(prop_name) + if 'additionalProperties' in comp_schema: + if additional_properties is None: + additional_properties = comp_schema['additionalProperties'] + elif additional_properties != comp_schema['additionalProperties']: + raise ValueError('Inconsistent additionalProperties in allOf') + + for t in schema.get('allOf', []): + if 'anyOf' in t: + for tt in t['anyOf']: + add_component(tt, is_required=False) + else: + add_component(t, is_required=True) + + if not properties and (additional_properties == True or additional_properties is None): + return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) + + default_additional_properties = None if is_explicit_object else False + return self._add_rule( + rule_name, + self._build_object_rule( + properties, required, name, + additional_properties if additional_properties is not None else default_additional_properties)) else: assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' @@ -800,7 +795,6 @@ def main(args_in = None): allow_fetch=args.allow_fetch, dotall=args.dotall, raw_pattern=args.raw_pattern) - schema = converter.resolve_refs(schema, url) converter.visit(schema, '') print(converter.format_grammar()) diff --git a/examples/server/chat.mjs b/examples/server/chat.mjs index a79c8a3cd..e7d15442c 100644 --- a/examples/server/chat.mjs +++ b/examples/server/chat.mjs @@ -26,9 +26,8 @@ const propOrder = grammarJsonSchemaPropOrder let grammar = null if (grammarJsonSchemaFile) { - let schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8')) + const schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8')) const converter = new SchemaConverter({prop_order: propOrder, allow_fetch: true}) - schema = await converter.resolveRefs(schema, grammarJsonSchemaFile) converter.visit(schema, '') grammar = converter.formatGrammar() } diff --git a/examples/server/public/index-new.html b/examples/server/public/index-new.html index 5513e9121..f909b7372 100644 --- a/examples/server/public/index-new.html +++ b/examples/server/public/index-new.html @@ -558,14 +558,13 @@ const ConfigForm = (props) => { const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value const convertJSONSchemaGrammar = async () => { try { - let schema = JSON.parse(params.value.grammar) + const schema = JSON.parse(params.value.grammar) const converter = new SchemaConverter({ prop_order: grammarJsonSchemaPropOrder.value .split(',') .reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}), allow_fetch: true, }) - schema = await converter.resolveRefs(schema, 'input') converter.visit(schema, '') params.value = { ...params.value, diff --git a/examples/server/public/index.html b/examples/server/public/index.html index 2f60a76e8..d3865fef7 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -707,14 +707,13 @@ const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value const convertJSONSchemaGrammar = async () => { try { - let schema = JSON.parse(params.value.grammar) + const schema = JSON.parse(params.value.grammar) const converter = new SchemaConverter({ prop_order: grammarJsonSchemaPropOrder.value .split(',') .reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}), allow_fetch: true, }) - schema = await converter.resolveRefs(schema, 'input') converter.visit(schema, '') params.value = { ...params.value, diff --git a/examples/server/public/json-schema-to-grammar.mjs b/examples/server/public/json-schema-to-grammar.mjs index 06d76edde..8baf23c2c 100644 --- a/examples/server/public/json-schema-to-grammar.mjs +++ b/examples/server/public/json-schema-to-grammar.mjs @@ -268,7 +268,8 @@ export class SchemaConverter { this._dotall = options.dotall || false; this._rules = {'space': SPACE_RULE}; this._refs = {}; - this._refsBeingResolved = new Set(); + this._externalRefs = new Map(); + this._refContext = []; } _formatLiteral(literal) { @@ -306,60 +307,6 @@ export class SchemaConverter { return key; } - async resolveRefs(schema, url) { - const visit = async (n) => { - if (Array.isArray(n)) { - return Promise.all(n.map(visit)); - } else if (typeof n === 'object' && n !== null) { - let ref = n.$ref; - let target; - if (ref !== undefined && !this._refs[ref]) { - if (ref.startsWith('https://')) { - if (!this._allowFetch) { - throw new Error('Fetching remote schemas is not allowed (use --allow-fetch for force)'); - } - const fetch = (await import('node-fetch')).default; - - const fragSplit = ref.split('#'); - const baseUrl = fragSplit[0]; - - target = this._refs[baseUrl]; - if (!target) { - target = await this.resolveRefs(await fetch(ref).then(res => res.json()), baseUrl); - this._refs[baseUrl] = target; - } - - if (fragSplit.length === 1 || fragSplit[fragSplit.length - 1] === '') { - return target; - } - } else if (ref.startsWith('#/')) { - target = schema; - ref = `${url}${ref}`; - n.$ref = ref; - } else { - throw new Error(`Unsupported ref ${ref}`); - } - - const selectors = ref.split('#')[1].split('/').slice(1); - for (const sel of selectors) { - if (!target || !(sel in target)) { - throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`); - } - target = target[sel]; - } - - this._refs[ref] = target; - } else { - await Promise.all(Object.values(n).map(visit)); - } - } - - return n; - }; - - return visit(schema); - } - _generateUnionRule(name, altSchemas) { return altSchemas .map((altSchema, i) => this.visit(altSchema, `${name ?? ''}${name ? '-' : 'alternative-'}${i}`)) @@ -590,29 +537,69 @@ export class SchemaConverter { return out.join(''); } - _resolveRef(ref) { - let refName = ref.split('/').pop(); - if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) { - this._refsBeingResolved.add(ref); - const resolved = this._refs[ref]; - refName = this.visit(resolved, refName); - this._refsBeingResolved.delete(ref); - } - return refName; - } - _generateConstantRule(value) { return this._formatLiteral(JSON.stringify(value)); } + _resolveRef(ref) { + const parts = ref.split('#'); + if (parts.length !== 2) { + throw new Error(`Unsupported ref: ${ref}`); + } + const url = parts[0]; + let target = null; + let isLocal = !url; + if (isLocal) { + if (this._refContext.length === 0) { + throw new Error(`Error resolving ref ${ref}: no context`); + } + target = this._refContext[this._refContext.length - 1]; + } else { + target = this._externalRefs.get(url); + if (target === undefined) { + // Fetch the referenced schema and resolve its refs + target = this._fetchJson(url); + this._externalRefs.set(url, target); + } + } + const tokens = parts[1].split('/'); + for (const sel of tokens.slice(1)) { + if (target === null || !(sel in target)) { + throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`); + } + target = target[sel]; + } + const name = tokens[tokens.length - 1] || ''; + return {target, name, isLocal}; + } + visit(schema, name) { const schemaType = schema.type; const schemaFormat = schema.format; const ruleName = name in RESERVED_NAMES ? name + '-' : name == '' ? 'root' : name; + if (this._refContext.length === 0) { + this._refContext.push(schema); + try { + return this.visit(schema, name); + } finally { + this._refContext.pop(); + } + } + const ref = schema.$ref; if (ref !== undefined) { - return this._addRule(ruleName, this._resolveRef(ref)); + const resolved = this._resolveRef(ref); + if (!resolved.isLocal) { + this._refContext.push(resolved.target); + } + try { + return this.visit(resolved.target, name === '' || resolved.name === '' ? name : resolved.name); + } finally { + if (!resolved.isLocal) { + this._refContext.pop(); + } + } } else if (schema.oneOf || schema.anyOf) { return this._addRule(ruleName, this._generateUnionRule(name, schema.oneOf || schema.anyOf)); } else if (Array.isArray(schemaType)) { @@ -622,42 +609,6 @@ export class SchemaConverter { } else if ('enum' in schema) { const rule = '(' + schema.enum.map(v => this._generateConstantRule(v)).join(' | ') + ') space'; return this._addRule(ruleName, rule); - } else if ((schemaType === undefined || schemaType === 'object') && - ('properties' in schema || - ('additionalProperties' in schema && schema.additionalProperties !== true))) { - const required = new Set(schema.required || []); - const properties = Object.entries(schema.properties ?? {}); - return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties)); - } else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) { - const required = new Set(); - const properties = []; - const addComponent = (compSchema, isRequired) => { - const ref = compSchema.$ref; - if (ref !== undefined) { - compSchema = this._refs[ref]; - } - - if ('properties' in compSchema) { - for (const [propName, propSchema] of Object.entries(compSchema.properties)) { - properties.push([propName, propSchema]); - if (isRequired) { - required.add(propName); - } - } - } - }; - - for (const t of schema.allOf) { - if ('anyOf' in t) { - for (const tt of t.anyOf) { - addComponent(tt, false); - } - } else { - addComponent(t, true); - } - } - - return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null)); } else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) { const items = schema.items ?? schema.prefixItems; if (Array.isArray(items)) { @@ -706,8 +657,58 @@ export class SchemaConverter { _generateMinMaxInt(minValue, maxValue, out); out.push(") space"); return this._addRule(ruleName, out.join('')); - } else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) { - return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object'])); + } else if (schemaType === undefined || schemaType === 'object') { + const required = new Set(schema.required || []); + const properties = Object.entries(schema.properties ?? {}); + const isExplicitObject = schemaType === 'object' || 'properties' in schema || 'additionalProperties' in schema; + let additionalProperties = schema.additionalProperties; + + const addComponent = (compSchema, isRequired) => { + const ref = compSchema.$ref; + if (ref !== undefined) { + const resolved = this._resolveRef(ref); + compSchema = resolved.target; + } + + if ('properties' in compSchema) { + for (const [propName, propSchema] of Object.entries(compSchema.properties)) { + properties.push([propName, propSchema]); + if (isRequired) { + required.add(propName); + } + } + if ('additionalProperties' in compSchema) { + if (additionalProperties === null) { + additionalProperties = compSchema.additionalProperties; + } else if (additionalProperties !== compSchema.additionalProperties) { + throw new Error('Inconsistent additionalProperties in allOf'); + } + } + } + }; + + if ('allOf' in schema) { + for (const t of schema.allOf) { + if ('anyOf' in t) { + for (const tt of t.anyOf) { + addComponent(tt, false); + } + } else { + addComponent(t, true); + } + } + } + + if (properties.length === 0 && (additionalProperties === true || additionalProperties == null)) { + return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object'])); + } + + const defaultAdditionalProperties = isExplicitObject ? null : false; + return this._addRule( + ruleName, + this._buildObjectRule(properties, required, name, additionalProperties ?? defaultAdditionalProperties) + ); + } else { if (!(schemaType in PRIMITIVE_RULES)) { throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`); diff --git a/examples/server/themes/buttons-top/index.html b/examples/server/themes/buttons-top/index.html index 6af30d307..b25ff7a93 100644 --- a/examples/server/themes/buttons-top/index.html +++ b/examples/server/themes/buttons-top/index.html @@ -634,14 +634,13 @@ const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value const convertJSONSchemaGrammar = async () => { try { - let schema = JSON.parse(params.value.grammar) + const schema = JSON.parse(params.value.grammar) const converter = new SchemaConverter({ prop_order: grammarJsonSchemaPropOrder.value .split(',') .reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}), allow_fetch: true, }) - schema = await converter.resolveRefs(schema, 'input') converter.visit(schema, '') params.value = { ...params.value, diff --git a/examples/server/themes/wild/index.html b/examples/server/themes/wild/index.html index 772e716cd..5490fe62f 100644 --- a/examples/server/themes/wild/index.html +++ b/examples/server/themes/wild/index.html @@ -637,14 +637,13 @@ const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value const convertJSONSchemaGrammar = async () => { try { - let schema = JSON.parse(params.value.grammar) + const schema = JSON.parse(params.value.grammar) const converter = new SchemaConverter({ prop_order: grammarJsonSchemaPropOrder.value .split(',') .reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}), allow_fetch: true, }) - schema = await converter.resolveRefs(schema, 'input') converter.visit(schema, '') params.value = { ...params.value, diff --git a/grammars/README.md b/grammars/README.md index 40f666240..6920ca0bc 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -183,10 +183,9 @@ space ::= | " " | "\n" [ \t]{0,20} Here is also a list of known limitations (contributions welcome): - Unsupported features are skipped silently. It is currently advised to use the command-line Python converter (see above) to see any warnings, and to inspect the resulting grammar / test it w/ [llama-gbnf-validator](../examples/gbnf-validator/gbnf-validator.cpp). -- Can't mix `properties` w/ `anyOf` / `oneOf` in the same type (https://github.com/ggerganov/llama.cpp/issues/7703) +- Can't mix `properties` w/ `oneOf` in the same type (https://github.com/ggerganov/llama.cpp/issues/7703) - [prefixItems](https://json-schema.org/draft/2020-12/json-schema-core#name-prefixitems) is broken (but [items](https://json-schema.org/draft/2020-12/json-schema-core#name-items) works) - `minimum`, `exclusiveMinimum`, `maximum`, `exclusiveMaximum`: only supported for `"type": "integer"` for now, not `number` -- Nested `$ref`s are broken (https://github.com/ggerganov/llama.cpp/issues/8073) - [pattern](https://json-schema.org/draft/2020-12/json-schema-validation#name-pattern)s must start with `^` and end with `$` - Remote `$ref`s not supported in the C++ version (Python & JavaScript versions fetch https refs) - `string` [formats](https://json-schema.org/draft/2020-12/json-schema-validation#name-defined-formats) lack `uri`, `email` diff --git a/tests/run-json-schema-to-grammar.mjs b/tests/run-json-schema-to-grammar.mjs index 71bf62ed3..8d020a02e 100644 --- a/tests/run-json-schema-to-grammar.mjs +++ b/tests/run-json-schema-to-grammar.mjs @@ -3,8 +3,7 @@ import { SchemaConverter } from "../examples/server/public/json-schema-to-gramma const [, , file] = process.argv const url = `file://${file}` -let schema = JSON.parse(readFileSync(file, "utf8")); +const schema = JSON.parse(readFileSync(file, "utf8")); const converter = new SchemaConverter({}) -schema = await converter.resolveRefs(schema, url) converter.visit(schema, '') console.log(converter.formatGrammar()) diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 0e21dc795..dd277e3be 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -1266,6 +1266,48 @@ static void test_json_schema() { // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""", } ); + + test_schema( + "refs", + // Schema + R"""({ + "type": "array", + "minItems": 1, + "maxItems": 15, + "items": { "$ref": "#/$defs/TALK" }, + + "$defs": { + "characters": { "enum": ["Biff", "Alice"] }, + "emotes": { "enum": ["EXCLAMATION", "CONFUSION", "CHEERFUL", "LOVE", "ANGRY"] }, + + "TALK": { + "type": "object", + "required": [ "character", "emote", "dialog" ], + "properties": { + "character": { "$ref": "#/$defs/characters" }, + "emote": { "$ref": "#/$defs/emotes" }, + "dialog": { + "type": "string", + "minLength": 1, + "maxLength": 200 + } + }, + "additionalProperties": false + } + } + })""", + // Passing strings + { + R"""([{ + "character": "Alice", + "emote": "EXCLAMATION", + "dialog": "Hello, world!" + }])""", + }, + // Failing strings + { + } + ); } int main() { diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 3aaa11833..ec57dfc83 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -344,6 +344,48 @@ static void test_all(const std::string & lang, std::function