From ba57964f928a74fef5a3f7d766e214e0f2d6a72a Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 10 Mar 2024 14:42:39 +0000 Subject: [PATCH] Update json-schema-to-grammar.py --- examples/json-schema-to-grammar.py | 95 +++++++++++++++++++----------- 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/examples/json-schema-to-grammar.py b/examples/json-schema-to-grammar.py index 21bfd1ffc..8c8017d58 100755 --- a/examples/json-schema-to-grammar.py +++ b/examples/json-schema-to-grammar.py @@ -33,6 +33,8 @@ class SchemaConverter: def __init__(self, prop_order): self._prop_order = prop_order self._rules = {'space': SPACE_RULE} + self._refs = {} + self._refs_being_resolved = set() def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( @@ -52,41 +54,47 @@ class SchemaConverter: self._rules[key] = rule return key - @staticmethod - def resolve_refs(schema: dict, url_cache: Dict[str, dict] = None): - if url_cache is None: - url_cache = {} - + def resolve_refs(self, schema: dict, url: str): + ''' + Resolves all $ref fields in the given schema, fetching any remote schemas, + replacing $ref with absolute reference URL and populating self._refs with the + respective referenced (sub)schema dictionaries. + ''' def visit(n: dict): if isinstance(n, list): return [visit(x) for x in n] elif isinstance(n, dict): ref = n.get('$ref') - if ref is not None: - if ref.startswith('#/'): - target = schema - for sel in ref.split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' - target = target[sel] - - return target - elif ref.startswith('https://'): + if ref is not None and ref not in self._refs: + if ref.startswith('https://'): import requests - linked_schema = url_cache.get(ref) - if linked_schema is None: - linked_schema = requests.get(ref).json() - url_cache[ref] = linked_schema - return SchemaConverter.resolve_refs(linked_schema, url_cache) + + frag_split = ref.split('#') + base_url = frag_split[0] + + target = self._refs.get(base_url) + if target is None: + target = self.resolve_refs(requests.get(ref).json(), base_url) + self._refs[base_url] = target + + if len(frag_split) == 1 or frag_split[-1] == '': + return + elif ref.startswith('#/'): + target = schema + ref = f'{url}{ref}' + n['$ref'] = ref else: raise ValueError(f'Unsupported ref {ref}') + + for sel in ref.split('#')[-1].split('/')[1:]: + assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] + + self._refs[ref] = target else: - for k in n.keys(): - v = n[k] - vv = visit(v) - if vv is not v: - n[k] = vv - else: - pass + for v in n.values(): + visit(v) + return n return visit(schema) @@ -96,7 +104,6 @@ class SchemaConverter: for i, alt_schema in enumerate(alt_schemas) )) - def _format_range_char(self, c): if c in ('-', ']', '\\'): return '\\' + chr(c) @@ -185,12 +192,23 @@ class SchemaConverter: except BaseException as e: raise Exception(f'Error processing pattern: {pattern}: {e}') from e + def _resolve_ref(self, ref): + ref_name = ref.split('/')[-1] + if ref_name not in self._rules and ref not in self._refs_being_resolved: + self._refs_being_resolved.add(ref) + resolved = self._refs[ref] + ref_name = self.visit(resolved, ref_name) + self._refs_being_resolved.remove(ref) + return ref_name + def visit(self, schema, name): - assert '$ref' not in schema, f'Unresolved $ref in {schema} (make sure to use {SchemaConverter.resolve_refs.__name__})' schema_type = schema.get('type') rule_name = name or 'root' - if 'oneOf' in schema or 'anyOf' in schema: + if (ref := schema.get('$ref')) is not None: + return self._resolve_ref(ref) + + elif 'oneOf' in schema or 'anyOf' in schema: return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) elif isinstance(schema_type, list): @@ -205,14 +223,17 @@ class SchemaConverter: elif schema_type in (None, 'object') and 'properties' in schema: required = set(schema.get('required', [])) - properties = schema['properties'] - return self._add_rule(rule_name, self._build_object_rule(properties.items(), required, name)) + properties = list(schema['properties'].items()) + return self._add_rule(rule_name, self._build_object_rule(properties, required, name)) elif schema_type in (None, 'object') and 'allOf' in schema: required = set() properties = [] + hybrid_name = name def add_component(comp_schema, is_required): - + if (ref := comp_schema.get('$ref')) is not None: + comp_schema = self._refs[ref] + if 'properties' in comp_schema: for prop_name, prop_schema in comp_schema['properties'].items(): properties.append((prop_name, prop_schema)) @@ -226,7 +247,7 @@ class SchemaConverter: else: add_component(t, is_required=True) - return self._add_rule(rule_name, self._build_object_rule(properties, required, name)) + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name)) elif schema_type in (None, 'object') and 'additionalProperties' in schema: additional_properties = schema['additionalProperties'] @@ -361,17 +382,19 @@ def main(args_in = None): parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)') args = parser.parse_args(args_in) - if args.schema.startswith('https://'): + if (url := args.schema.startswith('https://')): import requests - schema = requests.get(args.schema).json() + schema = requests.get(url).json() elif args.schema == '-': + url = 'stdin' schema = json.load(sys.stdin) else: + url = f'file://{args.schema}' with open(args.schema) as f: schema = json.load(f) - schema = SchemaConverter.resolve_refs(schema) prop_order = {name: idx for idx, name in enumerate(args.prop_order)} converter = SchemaConverter(prop_order) + schema = converter.resolve_refs(schema, url) converter.visit(schema, '') print(converter.format_grammar())