json: custom regex parser, adds dot support & JS-portable

This commit is contained in:
ochafik 2024-03-11 00:24:34 +00:00
parent 27b1fefdf4
commit 0e9494183b
2 changed files with 160 additions and 76 deletions

View file

@ -31,9 +31,10 @@ GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']'
class SchemaConverter:
def __init__(self, prop_order, allow_fetch):
def __init__(self, *, prop_order, allow_fetch, dotall):
self._prop_order = prop_order
self._allow_fetch = allow_fetch
self._dotall = dotall
self._rules = {'space': SPACE_RULE}
self._refs = {}
self._refs_being_resolved = set()
@ -113,89 +114,162 @@ class SchemaConverter:
))
def _visit_pattern(self, pattern, name):
'''
Transforms a regular expression pattern into a GBNF rule.
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
we define sub-rules to keep the output lean.
'''
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1]
sub_rule_ids = {}
try:
def visit_seq(seq):
out = []
# Merge consecutive literals
for t, g in itertools.groupby(seq, lambda x: x[0]):
g = list(g)
if t == re._parser.LITERAL and len(g) > 1:
out.append(self._format_literal(''.join(chr(x[1]) for x in g)))
else:
out.extend(visit(x) for x in g)
if len(out) == 1:
return out[0]
return '(' + ' '.join(out) + ')'
def visit(pattern):
i = 0
length = len(pattern)
def transform() -> Tuple[str, bool]:
'''
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
'''
nonlocal i
nonlocal pattern
nonlocal sub_rule_ids
if pattern[0] == re._parser.LITERAL:
return json.dumps(chr(pattern[1]))
start = i
# For each component of this sequence, store its string representation and whether it's a literal.
# We only need a flat structure here to apply repetition operators to the last item, and
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
# (GBNF's syntax is luckily very close to regular expressions!)
seq: list[Tuple[str, bool]] = []
elif pattern[0] == re._parser.NOT_LITERAL:
return f'[^{self._format_range_char(chr(pattern[1]))}]'
elif pattern[0] == re._parser.ANY:
raise ValueError('Unsupported pattern: "."')
elif pattern[0] == re._parser.IN:
def format_range_component(c):
if c[0] == re._parser.LITERAL:
return self._format_range_char(chr(c[1]))
elif c[0] == re._parser.RANGE:
return f'{self._format_range_char(chr(c[1][0]))}-{self._format_range_char(chr(c[1][1]))}'
def get_dot():
if self._dotall:
rule = '[\\U00000000-\\U0010FFFF]'
else:
raise ValueError(f'Unrecognized pattern: {c} (pattern = {pattern})')
components = pattern[1]
prefix = ''
if len(components) > 0 and components[0][0] == re._parser.NEGATE:
prefix = '^'
components = components[1:]
return f'[{prefix}{"".join(format_range_component(c) for c in components)}]'
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]'
return self._add_rule(f'dot', rule)
elif pattern[0] == re._parser.BRANCH:
return '(' + ' | '.join((visit(p) for p in pattern[1][1])) + ')'
def join_seq():
nonlocal seq
ret = []
for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
if is_literal:
lit = ''.join(x[0][1:-1] for x in g)
ret.append((f'"{lit}"', True))
else:
ret.extend(g)
if len(ret) == 1:
return ret[0]
return (' '.join(x[0] for x in seq), False)
elif pattern[0] == re._parser.SUBPATTERN:
return '(' + visit(pattern[1][3]) + ')'
while i < length:
c = pattern[i]
if c == '.':
seq.append((get_dot(), False))
i += 1
elif c == '(':
i += 1
if i < length:
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
seq.append((f'({transform()[0]})', False))
elif c == ')':
i += 1
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
return join_seq()
elif c == '[':
square_brackets = c
i += 1
while i < length and pattern[i] != ']':
if pattern[i] == '\\':
square_brackets += pattern[i:i+2]
i += 2
else:
square_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
square_brackets += ']'
i += 1
seq.append((square_brackets, False))
elif c == '|':
seq.append(('|', False))
i += 1
elif c in ('*', '+', '?'):
seq[-1] = (f'{seq[-1][0]}{c}', False)
i += 1
elif c == '{':
curly_brackets = c
i += 1
while i < length and pattern[i] != '}':
curly_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
curly_brackets += '}'
i += 1
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
if len(nums) == 1:
min_times = int(nums[0])
max_times = min_times
else:
assert len(nums) == 2
min_times = int(nums[0]) if nums[0] else 0
max_times = int(nums[1]) if nums[1] else None
elif pattern[0] == re._parser.MAX_REPEAT:
min_times = pattern[1][0]
max_times = pattern[1][1] if not pattern[1][1] == re._parser.MAXREPEAT else None
sub = visit(pattern[1][2])
(sub, sub_is_literal) = seq[-1]
if min_times == 0 and max_times is None:
seq[-1] = (f'{sub}*', False)
elif min_times == 0 and max_times == 1:
seq[-1] = (f'{sub}?', False)
elif min_times == 1 and max_times is None:
seq[-1] = (f'{sub}+', False)
else:
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub_rule_ids[sub] = id
sub = id
if min_times == 0 and max_times is None:
return f'{sub}*'
elif min_times == 0 and max_times == 1:
return f'{sub}?'
elif min_times == 1 and max_times is None:
return f'{sub}+'
seq[-1] = (
' '.join(
([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) +
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])),
False
)
else:
return ' '.join([sub] * min_times +
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*']))
elif isinstance(pattern, re._parser.SubPattern):
return visit_seq(pattern.data)
elif isinstance(pattern, list):
return visit_seq(pattern)
lit = ''
while i < length and pattern[i] not in ('.', '(', ')', '|', '[', '{', '*', '+', '?') \
and not (i < length - 1 and pattern[i+1] in ('{', '*', '+', '?')):
c = pattern[i]
if c == '\\' and i < length - 1:
i += 1
if c in ('.', '[', ']', '{', '}', '(', ')', '|', '*', '+', '?'):
# Escapes in regular expressions that aren't escaped in GBNF literals
lit += c
else:
raise ValueError(f'Unrecognized pattern: {pattern} ({type(pattern)})')
lit += f'\\{c}'
i += 1
else:
lit += c
i += 1
if lit:
seq.append((f'"{lit}"', True))
if i < length and pattern[i] not in ('.', '(', ')', '|', '[', '{', '*', '+', '?'):
seq.append((f'"{pattern[i]}"', True))
i += 1
return join_seq()
return self._add_rule(name, transform()[0])
return self._add_rule(name, visit(re._parser.parse(pattern)))
except BaseException as e:
raise Exception(f'Error processing pattern: {pattern}: {e}') from e
def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1]
@ -392,8 +466,15 @@ def main(args_in = None):
)
parser.add_argument(
'--allow-fetch',
action='store_true',
default=False,
help='Whether to allow fetching referenced schemas over HTTPS')
parser.add_argument(
'--dotall',
action='store_true',
default=False,
help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
args = parser.parse_args(args_in)
@ -408,8 +489,10 @@ def main(args_in = None):
url = f'file://{args.schema}'
with open(args.schema) as f:
schema = json.load(f)
prop_order = {name: idx for idx, name in enumerate(args.prop_order)}
converter = SchemaConverter(prop_order, args.allow_fetch)
converter = SchemaConverter(
prop_order={name: idx for idx, name in enumerate(args.prop_order)},
allow_fetch=args.allow_fetch,
dotall=args.dotall)
schema = converter.resolve_refs(schema, url)
converter.visit(schema, '')
print(converter.format_grammar())

View file

@ -1,7 +1,7 @@
import json, subprocess, sys, os
assert len(sys.argv) == 2
[_, pattern] = sys.argv
assert len(sys.argv) >= 2
[_, pattern, *rest] = sys.argv
print(subprocess.check_output(
[
@ -9,6 +9,7 @@ print(subprocess.check_output(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"json-schema-to-grammar.py"),
*rest,
"-",
],
text=True,