From 0e9494183bdfe78e07040c73c37cb867b4e75b31 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 11 Mar 2024 00:24:34 +0000 Subject: [PATCH] json: custom regex parser, adds dot support & JS-portable --- examples/json-schema-to-grammar.py | 231 ++++++++++++++++++++--------- examples/regex-to-grammar.py | 5 +- 2 files changed, 160 insertions(+), 76 deletions(-) diff --git a/examples/json-schema-to-grammar.py b/examples/json-schema-to-grammar.py index 4eb0f6c66..a8ce11bdd 100755 --- a/examples/json-schema-to-grammar.py +++ b/examples/json-schema-to-grammar.py @@ -31,9 +31,10 @@ GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']' class SchemaConverter: - def __init__(self, prop_order, allow_fetch): + def __init__(self, *, prop_order, allow_fetch, dotall): self._prop_order = prop_order self._allow_fetch = allow_fetch + self._dotall = dotall self._rules = {'space': SPACE_RULE} self._refs = {} self._refs_being_resolved = set() @@ -113,89 +114,162 @@ class SchemaConverter: )) def _visit_pattern(self, pattern, name): + ''' + Transforms a regular expression pattern into a GBNF rule. + + Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions + Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + + Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. + + Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which + we define sub-rules to keep the output lean. + ''' + assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' pattern = pattern[1:-1] sub_rule_ids = {} - try: - def visit_seq(seq): - out = [] - # Merge consecutive literals - for t, g in itertools.groupby(seq, lambda x: x[0]): - g = list(g) - if t == re._parser.LITERAL and len(g) > 1: - out.append(self._format_literal(''.join(chr(x[1]) for x in g))) + + i = 0 + length = len(pattern) + + def transform() -> Tuple[str, bool]: + ''' + Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. + ''' + nonlocal i + nonlocal pattern + nonlocal sub_rule_ids + + start = i + # For each component of this sequence, store its string representation and whether it's a literal. + # We only need a flat structure here to apply repetition operators to the last item, and + # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially + # (GBNF's syntax is luckily very close to regular expressions!) + seq: list[Tuple[str, bool]] = [] + + def get_dot(): + if self._dotall: + rule = '[\\U00000000-\\U0010FFFF]' + else: + # Accept any character... except \n and \r line break chars (\x0A and \xOD) + rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]' + return self._add_rule(f'dot', rule) + + def join_seq(): + nonlocal seq + ret = [] + for is_literal, g in itertools.groupby(seq, lambda x: x[1]): + if is_literal: + lit = ''.join(x[0][1:-1] for x in g) + ret.append((f'"{lit}"', True)) else: - out.extend(visit(x) for x in g) - if len(out) == 1: - return out[0] - return '(' + ' '.join(out) + ')' - - def visit(pattern): - nonlocal sub_rule_ids + ret.extend(g) + if len(ret) == 1: + return ret[0] + return (' '.join(x[0] for x in seq), False) - if pattern[0] == re._parser.LITERAL: - return json.dumps(chr(pattern[1])) - - elif pattern[0] == re._parser.NOT_LITERAL: - return f'[^{self._format_range_char(chr(pattern[1]))}]' - - elif pattern[0] == re._parser.ANY: - raise ValueError('Unsupported pattern: "."') - - elif pattern[0] == re._parser.IN: - def format_range_component(c): - if c[0] == re._parser.LITERAL: - return self._format_range_char(chr(c[1])) - elif c[0] == re._parser.RANGE: - return f'{self._format_range_char(chr(c[1][0]))}-{self._format_range_char(chr(c[1][1]))}' + while i < length: + c = pattern[i] + if c == '.': + seq.append((get_dot(), False)) + i += 1 + elif c == '(': + i += 1 + if i < length: + assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' + seq.append((f'({transform()[0]})', False)) + elif c == ')': + i += 1 + assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}' + return join_seq() + elif c == '[': + square_brackets = c + i += 1 + while i < length and pattern[i] != ']': + if pattern[i] == '\\': + square_brackets += pattern[i:i+2] + i += 2 else: - raise ValueError(f'Unrecognized pattern: {c} (pattern = {pattern})') - components = pattern[1] - prefix = '' - if len(components) > 0 and components[0][0] == re._parser.NEGATE: - prefix = '^' - components = components[1:] - return f'[{prefix}{"".join(format_range_component(c) for c in components)}]' - - elif pattern[0] == re._parser.BRANCH: - return '(' + ' | '.join((visit(p) for p in pattern[1][1])) + ')' - - elif pattern[0] == re._parser.SUBPATTERN: - return '(' + visit(pattern[1][3]) + ')' - - elif pattern[0] == re._parser.MAX_REPEAT: - min_times = pattern[1][0] - max_times = pattern[1][1] if not pattern[1][1] == re._parser.MAXREPEAT else None - sub = visit(pattern[1][2]) - - id = sub_rule_ids.get(sub) - if id is None: - id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) - sub_rule_ids[sub] = id - sub = id + square_brackets += pattern[i] + i += 1 + assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}' + square_brackets += ']' + i += 1 + seq.append((square_brackets, False)) + elif c == '|': + seq.append(('|', False)) + i += 1 + elif c in ('*', '+', '?'): + seq[-1] = (f'{seq[-1][0]}{c}', False) + i += 1 + elif c == '{': + curly_brackets = c + i += 1 + while i < length and pattern[i] != '}': + curly_brackets += pattern[i] + i += 1 + assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}' + curly_brackets += '}' + i += 1 + nums = [s.strip() for s in curly_brackets[1:-1].split(',')] + if len(nums) == 1: + min_times = int(nums[0]) + max_times = min_times + else: + assert len(nums) == 2 + min_times = int(nums[0]) if nums[0] else 0 + max_times = int(nums[1]) if nums[1] else None + + (sub, sub_is_literal) = seq[-1] if min_times == 0 and max_times is None: - return f'{sub}*' + seq[-1] = (f'{sub}*', False) elif min_times == 0 and max_times == 1: - return f'{sub}?' + seq[-1] = (f'{sub}?', False) elif min_times == 1 and max_times is None: - return f'{sub}+' + seq[-1] = (f'{sub}+', False) else: - return ' '.join([sub] * min_times + - ([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])) - - elif isinstance(pattern, re._parser.SubPattern): - return visit_seq(pattern.data) - - elif isinstance(pattern, list): - return visit_seq(pattern) - - else: - raise ValueError(f'Unrecognized pattern: {pattern} ({type(pattern)})') + if not sub_is_literal: + id = sub_rule_ids.get(sub) + if id is None: + id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) + sub_rule_ids[sub] = id + sub = id + + seq[-1] = ( + ' '.join( + ([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) + + ([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])), + False + ) + else: + lit = '' + while i < length and pattern[i] not in ('.', '(', ')', '|', '[', '{', '*', '+', '?') \ + and not (i < length - 1 and pattern[i+1] in ('{', '*', '+', '?')): + c = pattern[i] + if c == '\\' and i < length - 1: + i += 1 + if c in ('.', '[', ']', '{', '}', '(', ')', '|', '*', '+', '?'): + # Escapes in regular expressions that aren't escaped in GBNF literals + lit += c + else: + lit += f'\\{c}' + i += 1 + else: + lit += c + i += 1 + if lit: + seq.append((f'"{lit}"', True)) + + if i < length and pattern[i] not in ('.', '(', ')', '|', '[', '{', '*', '+', '?'): + seq.append((f'"{pattern[i]}"', True)) + i += 1 + + return join_seq() + + return self._add_rule(name, transform()[0]) - return self._add_rule(name, visit(re._parser.parse(pattern))) - except BaseException as e: - raise Exception(f'Error processing pattern: {pattern}: {e}') from e def _resolve_ref(self, ref): ref_name = ref.split('/')[-1] @@ -392,8 +466,15 @@ def main(args_in = None): ) parser.add_argument( '--allow-fetch', + action='store_true', default=False, help='Whether to allow fetching referenced schemas over HTTPS') + parser.add_argument( + '--dotall', + action='store_true', + default=False, + help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns') + parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)') args = parser.parse_args(args_in) @@ -408,8 +489,10 @@ def main(args_in = None): url = f'file://{args.schema}' with open(args.schema) as f: schema = json.load(f) - prop_order = {name: idx for idx, name in enumerate(args.prop_order)} - converter = SchemaConverter(prop_order, args.allow_fetch) + converter = SchemaConverter( + prop_order={name: idx for idx, name in enumerate(args.prop_order)}, + allow_fetch=args.allow_fetch, + dotall=args.dotall) schema = converter.resolve_refs(schema, url) converter.visit(schema, '') print(converter.format_grammar()) diff --git a/examples/regex-to-grammar.py b/examples/regex-to-grammar.py index cec4fe01e..5671332e6 100644 --- a/examples/regex-to-grammar.py +++ b/examples/regex-to-grammar.py @@ -1,7 +1,7 @@ import json, subprocess, sys, os -assert len(sys.argv) == 2 -[_, pattern] = sys.argv +assert len(sys.argv) >= 2 +[_, pattern, *rest] = sys.argv print(subprocess.check_output( [ @@ -9,6 +9,7 @@ print(subprocess.check_output( os.path.join( os.path.dirname(os.path.realpath(__file__)), "json-schema-to-grammar.py"), + *rest, "-", ], text=True,