json: simplify range escapes

This commit is contained in:
ochafik 2024-03-10 17:32:45 +00:00
parent f57b467c74
commit d1fda6f450

View file

@ -26,7 +26,8 @@ PRIMITIVE_RULES = {
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"'}
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
class SchemaConverter:
@ -43,6 +44,11 @@ class SchemaConverter:
)
return f'"{escaped}"'
def _format_range_char(self, literal):
return GRAMMAR_RANGE_LITERAL_ESCAPE_RE.sub(
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)[1:-1]
)
def _add_rule(self, name, rule):
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
if esc_name not in self._rules or self._rules[esc_name] == rule:
@ -106,18 +112,6 @@ class SchemaConverter:
for i, alt_schema in enumerate(alt_schemas)
))
def _format_range_char(self, c):
if c in ('-', ']', '\\'):
return '\\' + chr(c)
elif c == '\n':
return '\\n'
elif c == '\r':
return '\\r'
elif c == '\t':
return '\\t'
else:
return c
def _visit_pattern(self, pattern, name):
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1]
@ -125,9 +119,9 @@ class SchemaConverter:
try:
def visit_seq(seq):
out = []
# Merge consecutive literals
for t, g in itertools.groupby(seq, lambda x: x[0]):
g = list(g)
# Merge consecutive literals
if t == re._parser.LITERAL and len(g) > 1:
out.append(self._format_literal(''.join(chr(x[1]) for x in g)))
else:
@ -149,14 +143,14 @@ class SchemaConverter:
raise ValueError('Unsupported pattern: "."')
elif pattern[0] == re._parser.IN:
def format_range_comp(c):
def format_range_component(c):
if c[0] == re._parser.LITERAL:
return self._format_range_char(chr(c[1]))
elif c[0] == re._parser.RANGE:
return f'{self._format_range_char(chr(c[1][0]))}-{self._format_range_char(chr(c[1][1]))}'
else:
raise ValueError(f'Unrecognized pattern: {c}')
return f'[{"".join(format_range_comp(c) for c in pattern[1])}]'
return f'[{"".join(format_range_component(c) for c in pattern[1])}]'
elif pattern[0] == re._parser.BRANCH:
return '(' + ' | '.join((visit(p) for p in pattern[1][1])) + ')'