json: handle pattern repetitions

This commit is contained in:
ochafik 2024-03-05 03:40:23 +00:00
parent d5ef412f31
commit 4e7c26c32c

View file

@ -73,6 +73,19 @@ class SchemaConverter:
for i, alt_schema in enumerate(alt_schemas) for i, alt_schema in enumerate(alt_schemas)
)) ))
def _format_range_char(self, c):
if c in ('-', ']', '\\'):
return '\\' + chr(c)
elif c == '\n':
return '\\n'
elif c == '\r':
return '\\r'
elif c == '\t':
return '\\t'
else:
return c
def _visit_pattern(self, pattern): def _visit_pattern(self, pattern):
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1] pattern = pattern[1:-1]
@ -81,6 +94,7 @@ class SchemaConverter:
out = [] out = []
for t, g in itertools.groupby(seq, lambda x: x[0]): for t, g in itertools.groupby(seq, lambda x: x[0]):
g = list(g) g = list(g)
# Merge consecutive literals
if t == re._parser.LITERAL and len(g) > 1: if t == re._parser.LITERAL and len(g) > 1:
out.append(self._format_literal(''.join(chr(x[1]) for x in g))) out.append(self._format_literal(''.join(chr(x[1]) for x in g)))
else: else:
@ -92,48 +106,53 @@ class SchemaConverter:
def visit(pattern): def visit(pattern):
if pattern[0] == re._parser.LITERAL: if pattern[0] == re._parser.LITERAL:
return json.dumps(chr(pattern[1])) return json.dumps(chr(pattern[1]))
elif pattern[0] == re._parser.NOT_LITERAL: elif pattern[0] == re._parser.NOT_LITERAL:
ch = chr(pattern[1]) return f'[^{self._format_range_char(chr(pattern[1]))}]'
esc_ch = '\\' + ch if ch in ('-', ']', '\\') else ch
return f'[^{esc_ch}]'
elif pattern[0] == re._parser.ANY: elif pattern[0] == re._parser.ANY:
raise ValueError('Unsupported pattern: "."') raise ValueError('Unsupported pattern: "."')
elif pattern[0] == re._parser.IN: elif pattern[0] == re._parser.IN:
def format_range_char(c):
if chr(c) in ('-', ']', '\\', '\n', '\r', '\t'):
return '\\' + chr(c)
else:
return chr(c)
def format_range_comp(c): def format_range_comp(c):
if c[0] == re._parser.LITERAL: if c[0] == re._parser.LITERAL:
return format_range_char(c[1]) return self._format_range_char(chr(c[1]))
elif c[0] == re._parser.RANGE: elif c[0] == re._parser.RANGE:
return f'{format_range_char(c[1][0])}-{format_range_char(c[1][1])}' return f'{self._format_range_char(chr(c[1][0]))}-{self._format_range_char(chr(c[1][1]))}'
else: else:
raise ValueError(f'Unrecognized pattern: {c}') raise ValueError(f'Unrecognized pattern: {c}')
return f'[{"".join(format_range_comp(c) for c in pattern[1])}]' return f'[{"".join(format_range_comp(c) for c in pattern[1])}]'
elif pattern[0] == re._parser.BRANCH: elif pattern[0] == re._parser.BRANCH:
return ' | '.join((visit(p) for p in pattern[1][1])) return '(' + ' | '.join((visit(p) for p in pattern[1][1])) + ')'
elif pattern[0] == re._parser.SUBPATTERN: elif pattern[0] == re._parser.SUBPATTERN:
return visit(pattern[1][3]) return '(' + visit(pattern[1][3]) + ')'
elif pattern[0] == re._parser.MAX_REPEAT: elif pattern[0] == re._parser.MAX_REPEAT:
min_times = pattern[1][0] min_times = pattern[1][0]
max_times = pattern[1][1] if not pattern[1][1] == re._parser.MAXREPEAT else None max_times = pattern[1][1] if not pattern[1][1] == re._parser.MAXREPEAT else None
sub_pattern = pattern[1][2] sub = visit(pattern[1][2])
if min_times == 0 and max_times is None: if min_times == 0 and max_times is None:
return f'{visit(sub_pattern)}*' return f'{sub}*'
elif min_times == 0 and max_times == 1: elif min_times == 0 and max_times == 1:
return f'{visit(sub_pattern)}?' return f'{sub}?'
elif min_times == 1 and max_times is None: elif min_times == 1 and max_times is None:
return f'{visit(sub_pattern)}+' return f'{sub}+'
else: else:
raise ValueError(f'Unrecognized pattern: {pattern} ({type(pattern)}; min: {min_times}, max: {max_times})') return ' '.join([sub] * min_times +
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*']))
elif isinstance(pattern, re._parser.SubPattern): elif isinstance(pattern, re._parser.SubPattern):
return visit_seq(pattern.data) return visit_seq(pattern.data)
elif isinstance(pattern, list): elif isinstance(pattern, list):
return visit_seq(pattern) return visit_seq(pattern)
else: else:
raise ValueError(f'Unrecognized pattern: {pattern} ({type(pattern)})') raise ValueError(f'Unrecognized pattern: {pattern} ({type(pattern)})')
return visit(re._parser.parse(pattern)) return visit(re._parser.parse(pattern))
except BaseException as e: except BaseException as e:
raise Exception(f'Error processing pattern: {pattern}: {e}') from e raise Exception(f'Error processing pattern: {pattern}: {e}') from e