@staticmethod
def resolve_refs(schema: dict, url_cache: Dict[str, dict] = None):
    """Replace every ``$ref`` node in *schema* with the schema it points to.

    The schema is walked recursively. Dicts are updated in place; lists are
    rebuilt. Two kinds of reference are supported:

    * ``#/a/b/c`` -- a JSON Pointer into the document currently being
      resolved (``~1`` and ``~0`` escapes are decoded to ``/`` and ``~``).
      A target that is itself a ``$ref`` is followed until a concrete schema
      is reached.
    * ``https://...`` -- a remote document, fetched with ``requests`` unless
      it is already present in *url_cache*. The whole linked document is
      resolved and substituted; a fragment in the URL is not applied.

    Args:
        schema: The JSON schema to resolve.
        url_cache: Optional mapping of URL -> already-loaded schema. Newly
            fetched documents are stored in it. A fresh dict is used when
            omitted.

    Returns:
        ``schema`` itself (mutated), or the referenced schema when the root
        node is a ``$ref``.

    Raises:
        ValueError: If a local pointer cannot be resolved, a chain of local
            references loops back on itself, or the reference scheme is not
            supported.
    """
    if url_cache is None:
        url_cache = {}

    def lookup(root, ref):
        # Walk the pointer one segment at a time, starting from the root of
        # the document that contains the reference.
        target = root
        for raw in ref.split('/')[1:]:
            sel = raw.replace('~1', '/').replace('~0', '~')
            # Explicit check instead of `assert`: asserts vanish under -O,
            # and `in` on a str/list target would not mean "has this key".
            if not isinstance(target, dict) or sel not in target:
                raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
            target = target[sel]
        return target

    def fetch(url):
        linked = url_cache.get(url)
        if linked is None:
            # Imported lazily so `requests` is only needed on a cache miss.
            import requests
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            linked = response.json()
            url_cache[url] = linked
        return linked

    def visit(node, root, chain=()):
        if isinstance(node, list):
            return [visit(item, root) for item in node]
        if not isinstance(node, dict):
            return node

        ref = node.get('$ref')
        if not isinstance(ref, str):
            # Ordinary mapping (this includes a `properties` map that merely
            # has a property named "$ref"): resolve the children in place.
            for key, value in list(node.items()):
                resolved = visit(value, root)
                if resolved is not value:
                    node[key] = resolved
            return node

        if ref in chain:
            raise ValueError(f'Error resolving ref {ref}: circular reference')
        if ref.startswith('#/'):
            target = lookup(root, ref)
            # The definition may not have been visited yet; if it is only an
            # alias for another ref, follow it now so no `$ref` node leaks
            # into the result.
            if isinstance(target, dict) and isinstance(target.get('$ref'), str):
                return visit(target, root, chain + (ref,))
            return target
        if ref.startswith('https://'):
            linked = fetch(ref)
            # Local pointers inside the linked document are relative to it.
            return visit(linked, linked)
        raise ValueError(f'Unsupported ref {ref}')

    return visit(schema, schema)
""}{i}') @@ -163,17 +186,8 @@ class SchemaConverter: raise Exception(f'Error processing pattern: {pattern}: {e}') from e def visit(self, schema, name): - old_ref_base = self.ref_base - if 'definitions' in schema: - self.ref_base = schema - try: - return self._visit(schema, name) - finally: - self.ref_base = old_ref_base - - def _visit(self, schema, name): + assert '$ref' not in schema, f'Unresolved $ref in {schema} (make sure to use {SchemaConverter.resolve_refs.__name__})' schema_type = schema.get('type') - ref = schema.get('$ref') rule_name = name or 'root' if 'oneOf' in schema or 'anyOf' in schema: @@ -198,9 +212,6 @@ class SchemaConverter: required = set() properties = [] def add_component(comp_schema, is_required): - ref = comp_schema.get('$ref') - if ref is not None and (resolved := self._resolve_ref(ref)) is not None: - comp_schema = resolved[1] if 'properties' in comp_schema: for prop_name, prop_schema in comp_schema['properties'].items(): @@ -262,16 +273,6 @@ class SchemaConverter: elif schema_type in (None, 'string') and 'pattern' in schema: return self._add_rule(rule_name, self._visit_pattern(schema['pattern'], rule_name)) - elif (resolved := self._resolve_ref(ref)) is not None: - (ref_name, definition) = resolved - def_name = f'{name}-{ref_name}' if name else '' - return self.visit(definition, def_name) - - elif ref is not None and ref.startswith('https://'): - import requests - ref_schema = requests.get(ref).json() - return self.visit(ref_schema, ref) - elif schema_type == 'object' and len(schema) == 1 or schema_type is None and len(schema) == 0: # return 'object' for t, r in PRIMITIVE_RULES.items(): @@ -286,9 +287,7 @@ class SchemaConverter: ) def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str): - # TODO: `required` keyword prop_order = self._prop_order - print(f'# properties: {properties}', file=sys.stderr) # sort by position in prop_order (if specified) then by original order sorted_props = [kv[0] for _, kv 
in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] @@ -370,6 +369,7 @@ def main(args_in = None): else: with open(args.schema) as f: schema = json.load(f) + schema = SchemaConverter.resolve_refs(schema) prop_order = {name: idx for idx, name in enumerate(args.prop_order)} converter = SchemaConverter(prop_order) converter.visit(schema, '')