json: handle pattern repetitions

This commit is contained in:
ochafik 2024-03-05 03:40:23 +00:00
parent d5ef412f31
commit 4e7c26c32c

View file

@ -73,6 +73,19 @@ class SchemaConverter:
for i, alt_schema in enumerate(alt_schemas)
))
def _format_range_char(self, c):
if c in ('-', ']', '\\'):
return '\\' + chr(c)
elif c == '\n':
return '\\n'
elif c == '\r':
return '\\r'
elif c == '\t':
return '\\t'
else:
return c
def _visit_pattern(self, pattern):
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1]
@ -81,6 +94,7 @@ class SchemaConverter:
out = []
for t, g in itertools.groupby(seq, lambda x: x[0]):
g = list(g)
# Merge consecutive literals
if t == re._parser.LITERAL and len(g) > 1:
out.append(self._format_literal(''.join(chr(x[1]) for x in g)))
else:
@ -92,48 +106,53 @@ class SchemaConverter:
def visit(pattern):
if pattern[0] == re._parser.LITERAL:
return json.dumps(chr(pattern[1]))
elif pattern[0] == re._parser.NOT_LITERAL:
ch = chr(pattern[1])
esc_ch = '\\' + ch if ch in ('-', ']', '\\') else ch
return f'[^{esc_ch}]'
return f'[^{self._format_range_char(chr(pattern[1]))}]'
elif pattern[0] == re._parser.ANY:
raise ValueError('Unsupported pattern: "."')
elif pattern[0] == re._parser.IN:
def format_range_char(c):
if chr(c) in ('-', ']', '\\', '\n', '\r', '\t'):
return '\\' + chr(c)
else:
return chr(c)
def format_range_comp(c):
if c[0] == re._parser.LITERAL:
return format_range_char(c[1])
return self._format_range_char(chr(c[1]))
elif c[0] == re._parser.RANGE:
return f'{format_range_char(c[1][0])}-{format_range_char(c[1][1])}'
return f'{self._format_range_char(chr(c[1][0]))}-{self._format_range_char(chr(c[1][1]))}'
else:
raise ValueError(f'Unrecognized pattern: {c}')
return f'[{"".join(format_range_comp(c) for c in pattern[1])}]'
elif pattern[0] == re._parser.BRANCH:
return ' | '.join((visit(p) for p in pattern[1][1]))
return '(' + ' | '.join((visit(p) for p in pattern[1][1])) + ')'
elif pattern[0] == re._parser.SUBPATTERN:
return visit(pattern[1][3])
return '(' + visit(pattern[1][3]) + ')'
elif pattern[0] == re._parser.MAX_REPEAT:
min_times = pattern[1][0]
max_times = pattern[1][1] if not pattern[1][1] == re._parser.MAXREPEAT else None
sub_pattern = pattern[1][2]
sub = visit(pattern[1][2])
if min_times == 0 and max_times is None:
return f'{visit(sub_pattern)}*'
return f'{sub}*'
elif min_times == 0 and max_times == 1:
return f'{visit(sub_pattern)}?'
return f'{sub}?'
elif min_times == 1 and max_times is None:
return f'{visit(sub_pattern)}+'
return f'{sub}+'
else:
raise ValueError(f'Unrecognized pattern: {pattern} ({type(pattern)}; min: {min_times}, max: {max_times})')
return ' '.join([sub] * min_times +
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*']))
elif isinstance(pattern, re._parser.SubPattern):
return visit_seq(pattern.data)
elif isinstance(pattern, list):
return visit_seq(pattern)
else:
raise ValueError(f'Unrecognized pattern: {pattern} ({type(pattern)})')
return visit(re._parser.parse(pattern))
except BaseException as e:
raise Exception(f'Error processing pattern: {pattern}: {e}') from e