json: custom regex parser, adds dot support & JS-portable

This commit is contained in:
ochafik 2024-03-11 00:24:34 +00:00
parent 27b1fefdf4
commit 0e9494183b
2 changed files with 160 additions and 76 deletions

View file

@ -31,9 +31,10 @@ GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']'
class SchemaConverter: class SchemaConverter:
def __init__(self, prop_order, allow_fetch): def __init__(self, *, prop_order, allow_fetch, dotall):
self._prop_order = prop_order self._prop_order = prop_order
self._allow_fetch = allow_fetch self._allow_fetch = allow_fetch
self._dotall = dotall
self._rules = {'space': SPACE_RULE} self._rules = {'space': SPACE_RULE}
self._refs = {} self._refs = {}
self._refs_being_resolved = set() self._refs_being_resolved = set()
@ -113,89 +114,162 @@ class SchemaConverter:
)) ))
def _visit_pattern(self, pattern, name): def _visit_pattern(self, pattern, name):
'''
Transforms a regular expression pattern into a GBNF rule.
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
we define sub-rules to keep the output lean.
'''
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1] pattern = pattern[1:-1]
sub_rule_ids = {} sub_rule_ids = {}
try:
def visit_seq(seq): i = 0
out = [] length = len(pattern)
# Merge consecutive literals
for t, g in itertools.groupby(seq, lambda x: x[0]): def transform() -> Tuple[str, bool]:
g = list(g) '''
if t == re._parser.LITERAL and len(g) > 1: Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
out.append(self._format_literal(''.join(chr(x[1]) for x in g))) '''
nonlocal i
nonlocal pattern
nonlocal sub_rule_ids
start = i
# For each component of this sequence, store its string representation and whether it's a literal.
# We only need a flat structure here to apply repetition operators to the last item, and
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
# (GBNF's syntax is luckily very close to regular expressions!)
seq: list[Tuple[str, bool]] = []
def get_dot():
if self._dotall:
rule = '[\\U00000000-\\U0010FFFF]'
else:
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]'
return self._add_rule(f'dot', rule)
def join_seq():
nonlocal seq
ret = []
for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
if is_literal:
lit = ''.join(x[0][1:-1] for x in g)
ret.append((f'"{lit}"', True))
else: else:
out.extend(visit(x) for x in g) ret.extend(g)
if len(out) == 1: if len(ret) == 1:
return out[0] return ret[0]
return '(' + ' '.join(out) + ')' return (' '.join(x[0] for x in seq), False)
def visit(pattern):
nonlocal sub_rule_ids
if pattern[0] == re._parser.LITERAL: while i < length:
return json.dumps(chr(pattern[1])) c = pattern[i]
if c == '.':
elif pattern[0] == re._parser.NOT_LITERAL: seq.append((get_dot(), False))
return f'[^{self._format_range_char(chr(pattern[1]))}]' i += 1
elif c == '(':
elif pattern[0] == re._parser.ANY: i += 1
raise ValueError('Unsupported pattern: "."') if i < length:
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
elif pattern[0] == re._parser.IN: seq.append((f'({transform()[0]})', False))
def format_range_component(c): elif c == ')':
if c[0] == re._parser.LITERAL: i += 1
return self._format_range_char(chr(c[1])) assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
elif c[0] == re._parser.RANGE: return join_seq()
return f'{self._format_range_char(chr(c[1][0]))}-{self._format_range_char(chr(c[1][1]))}' elif c == '[':
square_brackets = c
i += 1
while i < length and pattern[i] != ']':
if pattern[i] == '\\':
square_brackets += pattern[i:i+2]
i += 2
else: else:
raise ValueError(f'Unrecognized pattern: {c} (pattern = {pattern})') square_brackets += pattern[i]
components = pattern[1] i += 1
prefix = '' assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
if len(components) > 0 and components[0][0] == re._parser.NEGATE: square_brackets += ']'
prefix = '^' i += 1
components = components[1:] seq.append((square_brackets, False))
return f'[{prefix}{"".join(format_range_component(c) for c in components)}]' elif c == '|':
seq.append(('|', False))
elif pattern[0] == re._parser.BRANCH: i += 1
return '(' + ' | '.join((visit(p) for p in pattern[1][1])) + ')' elif c in ('*', '+', '?'):
seq[-1] = (f'{seq[-1][0]}{c}', False)
elif pattern[0] == re._parser.SUBPATTERN: i += 1
return '(' + visit(pattern[1][3]) + ')' elif c == '{':
curly_brackets = c
elif pattern[0] == re._parser.MAX_REPEAT: i += 1
min_times = pattern[1][0] while i < length and pattern[i] != '}':
max_times = pattern[1][1] if not pattern[1][1] == re._parser.MAXREPEAT else None curly_brackets += pattern[i]
sub = visit(pattern[1][2]) i += 1
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
id = sub_rule_ids.get(sub) curly_brackets += '}'
if id is None: i += 1
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
sub_rule_ids[sub] = id if len(nums) == 1:
sub = id min_times = int(nums[0])
max_times = min_times
else:
assert len(nums) == 2
min_times = int(nums[0]) if nums[0] else 0
max_times = int(nums[1]) if nums[1] else None
(sub, sub_is_literal) = seq[-1]
if min_times == 0 and max_times is None: if min_times == 0 and max_times is None:
return f'{sub}*' seq[-1] = (f'{sub}*', False)
elif min_times == 0 and max_times == 1: elif min_times == 0 and max_times == 1:
return f'{sub}?' seq[-1] = (f'{sub}?', False)
elif min_times == 1 and max_times is None: elif min_times == 1 and max_times is None:
return f'{sub}+' seq[-1] = (f'{sub}+', False)
else: else:
return ' '.join([sub] * min_times + if not sub_is_literal:
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])) id = sub_rule_ids.get(sub)
if id is None:
elif isinstance(pattern, re._parser.SubPattern): id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
return visit_seq(pattern.data) sub_rule_ids[sub] = id
sub = id
elif isinstance(pattern, list):
return visit_seq(pattern) seq[-1] = (
' '.join(
else: ([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) +
raise ValueError(f'Unrecognized pattern: {pattern} ({type(pattern)})') ([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])),
False
)
else:
lit = ''
while i < length and pattern[i] not in ('.', '(', ')', '|', '[', '{', '*', '+', '?') \
and not (i < length - 1 and pattern[i+1] in ('{', '*', '+', '?')):
c = pattern[i]
if c == '\\' and i < length - 1:
i += 1
if c in ('.', '[', ']', '{', '}', '(', ')', '|', '*', '+', '?'):
# Escapes in regular expressions that aren't escaped in GBNF literals
lit += c
else:
lit += f'\\{c}'
i += 1
else:
lit += c
i += 1
if lit:
seq.append((f'"{lit}"', True))
if i < length and pattern[i] not in ('.', '(', ')', '|', '[', '{', '*', '+', '?'):
seq.append((f'"{pattern[i]}"', True))
i += 1
return join_seq()
return self._add_rule(name, transform()[0])
return self._add_rule(name, visit(re._parser.parse(pattern)))
except BaseException as e:
raise Exception(f'Error processing pattern: {pattern}: {e}') from e
def _resolve_ref(self, ref): def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1] ref_name = ref.split('/')[-1]
@ -392,8 +466,15 @@ def main(args_in = None):
) )
parser.add_argument( parser.add_argument(
'--allow-fetch', '--allow-fetch',
action='store_true',
default=False, default=False,
help='Whether to allow fetching referenced schemas over HTTPS') help='Whether to allow fetching referenced schemas over HTTPS')
parser.add_argument(
'--dotall',
action='store_true',
default=False,
help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)') parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
@ -408,8 +489,10 @@ def main(args_in = None):
url = f'file://{args.schema}' url = f'file://{args.schema}'
with open(args.schema) as f: with open(args.schema) as f:
schema = json.load(f) schema = json.load(f)
prop_order = {name: idx for idx, name in enumerate(args.prop_order)} converter = SchemaConverter(
converter = SchemaConverter(prop_order, args.allow_fetch) prop_order={name: idx for idx, name in enumerate(args.prop_order)},
allow_fetch=args.allow_fetch,
dotall=args.dotall)
schema = converter.resolve_refs(schema, url) schema = converter.resolve_refs(schema, url)
converter.visit(schema, '') converter.visit(schema, '')
print(converter.format_grammar()) print(converter.format_grammar())

View file

@ -1,7 +1,7 @@
import json, subprocess, sys, os import json, subprocess, sys, os
assert len(sys.argv) == 2 assert len(sys.argv) >= 2
[_, pattern] = sys.argv [_, pattern, *rest] = sys.argv
print(subprocess.check_output( print(subprocess.check_output(
[ [
@ -9,6 +9,7 @@ print(subprocess.check_output(
os.path.join( os.path.join(
os.path.dirname(os.path.realpath(__file__)), os.path.dirname(os.path.realpath(__file__)),
"json-schema-to-grammar.py"), "json-schema-to-grammar.py"),
*rest,
"-", "-",
], ],
text=True, text=True,