json: unify all repetition code (w/ or w/o sep)

This commit is contained in:
ochafik 2024-04-08 23:06:42 +01:00
parent dcf5d3283a
commit 181f984def
5 changed files with 1885 additions and 1862 deletions

View file

@ -11,21 +11,63 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
static std::string build_repetition(const std::string & content, int upToN) { template <typename Iterator>
std::ostringstream out; static std::string join(Iterator begin, Iterator end, const std::string & separator);
std::function<void(int)> aux = [&](int n) {
if (n == 0) { static std::string repeat(const std::string & str, size_t n);
return;
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "", bool item_rule_is_literal = false) {
if (separator_rule.empty()) {
if (min_items == 0 && max_items == 1) {
return item_rule + "?";
} else if (min_items == 1 && max_items == std::numeric_limits<int>::max()) {
return item_rule + "+";
} }
out << "(" << content; }
if (n > 1) {
out << " "; std::string result;
aux(n - 1); if (min_items > 0) {
if (item_rule_is_literal && separator_rule.empty()) {
result = "\"" + repeat(std::string(item_rule.begin() + 1, item_rule.end() - 1), min_items) + "\"";
} else {
std::vector<std::string> items(min_items, item_rule);
result = join(items.begin(), items.end(), separator_rule.empty() ? " " : " " + separator_rule + " ");
} }
out << ")?"; }
std::function<std::string(int, bool)> opt_repetitions = [&](int up_to_n, bool prefix_with_sep) -> std::string {
if (up_to_n == 0) {
return "";
}
std::string res;
if (!separator_rule.empty() && prefix_with_sep) {
res = separator_rule + " " + item_rule;
} else {
res = item_rule;
}
if (up_to_n > 1) {
res += " " + opt_repetitions(up_to_n - 1, true);
}
return "(" + res + ")?";
}; };
aux(upToN);
return out.str(); if (min_items > 0 && max_items != min_items) {
result += " ";
}
if (max_items != std::numeric_limits<int>::max()) {
result += opt_repetitions(max_items - min_items, min_items > 0);
} else {
std::string item_operator = "(" + (separator_rule.empty() ? "" : separator_rule + " ") + item_rule + ")";
if (min_items == 0 && !separator_rule.empty()) {
result = "(" + item_rule + " " + item_operator + "*)?";
} else {
result += item_operator + "*";
}
}
return result;
} }
const std::string SPACE_RULE = "\" \"?"; const std::string SPACE_RULE = "\" \"?";
@ -35,7 +77,7 @@ struct BuiltinRule {
std::vector<std::string> deps; std::vector<std::string> deps;
}; };
const std::string _up_to_15_digits = build_repetition("[0-9]", 15); const std::string _up_to_15_digits = build_repetition("[0-9]", 0, 15);
std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = { std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
{"boolean", {"(\"true\" | \"false\") space", {}}}, {"boolean", {"(\"true\" | \"false\") space", {}}},
@ -333,42 +375,21 @@ private:
auto &sub = last.first; auto &sub = last.first;
auto sub_is_literal = last.second; auto sub_is_literal = last.second;
if (min_times == 0 && max_times == std::numeric_limits<int>::max()) { if (!sub_is_literal) {
sub += "*"; std::string & sub_id = sub_rule_ids[sub];
} else if (min_times == 0 && max_times == 1) { if (sub_id.empty()) {
sub += "?"; sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
} else if (min_times == 1 && max_times == std::numeric_limits<int>::max()) {
sub += "+";
} else {
if (!sub_is_literal) {
std::string & sub_id = sub_rule_ids[sub];
if (sub_id.empty()) {
sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
}
sub = sub_id;
} }
std::string result; sub = sub_id;
if (sub_is_literal && min_times > 0) {
result = "\"" + repeat(sub.substr(1, sub.length() - 2), min_times) + "\"";
} else {
for (int j = 0; j < min_times; j++) {
if (j > 0) {
result += " ";
}
result += sub;
}
}
if (min_times > 0 && min_times < max_times) {
result += " ";
}
if (max_times == std::numeric_limits<int>::max()) {
result += sub + "*";
} else {
result += build_repetition(sub, max_times - min_times);
}
seq.back().first = result;
seq.back().second = false;
} }
seq.back().first = build_repetition(
sub_is_literal ? "\"" + sub + "\"" : sub,
min_times,
max_times,
"",
sub_is_literal
);
seq.back().second = false;
} else { } else {
std::string literal; std::string literal;
auto is_non_literal = [&](char c) { auto is_non_literal = [&](char c) {
@ -686,27 +707,11 @@ public:
return _add_rule(rule_name, rule); return _add_rule(rule_name, rule);
} else { } else {
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item"); std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
std::string list_item_operator = "( \",\" space " + item_rule_name + " )";
std::string successive_items;
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0; int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : -1; int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
if (min_items > 0) {
successive_items += repeat(list_item_operator, min_items - 1); return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
min_items--;
}
if (max_items >= 0 && max_items > min_items) {
successive_items += build_repetition(list_item_operator, max_items - min_items - 1);
} else {
successive_items += list_item_operator + "*";
}
std::string rule;
if (min_items == 0) {
rule = "\"[\" space ( " + item_rule_name + " " + successive_items + " )? \"]\" space";
} else {
rule = "\"[\" space " + item_rule_name + " " + successive_items + " \"]\" space";
}
return _add_rule(rule_name, rule);
} }
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
return _visit_pattern(schema["pattern"], rule_name); return _visit_pattern(schema["pattern"], rule_name);

View file

@ -6,18 +6,52 @@ import re
import sys import sys
from typing import Any, Dict, List, Set, Tuple, Union from typing import Any, Dict, List, Set, Tuple, Union
def _build_repetition(content, up_to_n): def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
# return ' '.join([content] * n) if not separator_rule:
if up_to_n == 0: if min_items == 0 and max_items == 1:
return '' return f'{item_rule}?'
return f'({content}{" " + _build_repetition(content, up_to_n-1) if up_to_n > 1 else ""})?' elif min_items == 1 and max_items is None:
return f'{item_rule}+'
result = ''
if min_items > 0:
if item_rule_is_literal and separator_rule is None:
result = '"' + (item_rule[1:-1] * min_items) + '"'
else:
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)
def opt_repetitions(up_to_n, prefix_with_sep=False):
if up_to_n == 0:
return ''
res = separator_rule + ' ' + item_rule if separator_rule and prefix_with_sep else item_rule
if up_to_n > 1:
res += ' ' + opt_repetitions(up_to_n - 1, prefix_with_sep=True)
return f'({res})?'
if min_items > 0 and max_items != min_items:
result += ' '
if max_items is not None:
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
else:
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
if min_items == 0 and separator_rule:
result = f'({item_rule} {item_operator}*)?'
else:
result += f'{item_operator}*'
return result
class BuiltinRule: class BuiltinRule:
def __init__(self, content: str, deps: list[str] = None): def __init__(self, content: str, deps: list[str] = None):
self.content = content self.content = content
self.deps = deps or [] self.deps = deps or []
_up_to_15_digits = _build_repetition('[0-9]', 15) _up_to_15_digits = _build_repetition('[0-9]', 0, 15)
# whitespace is constrained to a single space char to prevent model "running away" in # whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality? # whitespace. Also maybe improves generation quality?
@ -286,34 +320,14 @@ class SchemaConverter:
(sub, sub_is_literal) = seq[-1] (sub, sub_is_literal) = seq[-1]
if min_times == 0 and max_times is None: if not sub_is_literal:
sub = f'"{sub}"' if sub_is_literal else sub id = sub_rule_ids.get(sub)
seq[-1] = (f'{sub}*', False) if id is None:
elif min_times == 0 and max_times == 1: id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub = f'"{sub}"' if sub_is_literal else sub sub_rule_ids[sub] = id
seq[-1] = (f'{sub}?', False) sub = id
elif min_times == 1 and max_times is None:
sub = f'"{sub}"' if sub_is_literal else sub
seq[-1] = (f'{sub}+', False)
else:
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub_rule_ids[sub] = id
sub = id
if sub_is_literal and min_times > 0: seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
result = '"' + (sub[1:-1] * min_times) + '"'
else:
result = ' '.join([sub] * min_times)
if min_times < max_times:
if min_times > 0:
result += ' '
result += _build_repetition(sub, max_times - min_times)
seq[-1] = (result, False)
else: else:
literal = '' literal = ''
while i < length: while i < length:
@ -421,22 +435,9 @@ class SchemaConverter:
' "]" space') ' "]" space')
else: else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
list_item_operator = f'( "," space {item_rule_name} )'
successive_items = ""
min_items = schema.get("minItems", 0) min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems") max_items = schema.get("maxItems")
if min_items > 0: return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
successive_items = list_item_operator * (min_items - 1)
min_items -= 1
if max_items is not None and max_items > min_items:
successive_items += _build_repetition(list_item_operator, max_items - min_items - 1)
else:
successive_items += list_item_operator + "*"
if min_items == 0:
rule = f'"[" space ( {item_rule_name} {successive_items} )? "]" space'
else:
rule = f'"[" space {item_rule_name} {successive_items} "]" space'
return self._add_rule(rule_name, rule)
elif schema_type in (None, 'string') and 'pattern' in schema: elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name) return self._visit_pattern(schema['pattern'], rule_name)

File diff suppressed because it is too large Load diff

View file

@ -1,11 +1,56 @@
// WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first. // WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first.
const SPACE_RULE = '" "?'; const SPACE_RULE = '" "?';
function _buildRepetition(content, upToN) { function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
if (upToN === 0) { const separatorRule = opts.separatorRule ?? '';
return ''; const itemRuleIsLiteral = opts.itemRuleIsLiteral ?? false
if (separatorRule === '') {
if (minItems === 0 && maxItems === 1) {
return `${itemRule}?`;
} else if (minItems === 1 && maxItems === undefined) {
return `${itemRule}+`;
}
} }
return `(${content}${upToN > 1 ? ` ${_buildRepetition(content, upToN - 1)}` : ''})?`;
let result = '';
if (minItems > 0) {
if (itemRuleIsLiteral && separatorRule === '') {
result = `"${itemRule.slice(1, -1).repeat(minItems)}"`;
} else {
result = Array.from({ length: minItems }, () => itemRule)
.join(separatorRule !== '' ? ` ${separatorRule} ` : ' ');
}
}
const optRepetitions = (upToN, prefixWithSep=false) => {
if (upToN === 0) {
return '';
}
let res = separatorRule !== '' && prefixWithSep ? separatorRule + ' ' + itemRule : itemRule;
if (upToN > 1) {
res += ' ' + optRepetitions(upToN - 1, true);
}
return `(${res})?`;
};
if (minItems > 0 && maxItems !== minItems) {
result += ' ';
}
if (maxItems !== undefined) {
result += optRepetitions(maxItems - minItems, minItems > 0);
} else {
const itemOperator = `(${separatorRule !== '' ? separatorRule + ' ' : ''}${itemRule})`;
if (minItems === 0 && separatorRule !== '') {
result = `(${itemRule} ${itemOperator}*)?`;
} else {
result += `${itemOperator}*`;
}
}
return result;
} }
class BuiltinRule { class BuiltinRule {
@ -15,7 +60,7 @@ class BuiltinRule {
} }
} }
const UP_TO_15_DIGITS = _buildRepetition('[0-9]', 15); const UP_TO_15_DIGITS = _buildRepetition('[0-9]', 0, 15);
const PRIMITIVE_RULES = { const PRIMITIVE_RULES = {
boolean : new BuiltinRule('("true" | "false") space', []), boolean : new BuiltinRule('("true" | "false") space', []),
@ -276,37 +321,19 @@ export class SchemaConverter {
let [sub, subIsLiteral] = seq[seq.length - 1]; let [sub, subIsLiteral] = seq[seq.length - 1];
if (minTimes === 0 && maxTimes === Infinity) { if (!subIsLiteral) {
seq[seq.length - 1] = [`${sub}*`, false]; let id = subRuleIds[sub];
} else if (minTimes === 0 && maxTimes === 1) { if (id === undefined) {
seq[seq.length - 1] = [`${sub}?`, false]; id = this._addRule(`${name}-${Object.keys(subRuleIds).length + 1}`, sub);
} else if (minTimes === 1 && maxTimes === Infinity) { subRuleIds[sub] = id;
seq[seq.length - 1] = [`${sub}+`, false];
} else {
if (!subIsLiteral) {
let id = subRuleIds[sub];
if (id === undefined) {
id = this._addRule(`${name}-${Object.keys(subRuleIds).length + 1}`, sub);
subRuleIds[sub] = id;
}
sub = id;
} }
sub = id;
let result;
if (subIsLiteral && minTimes > 0) {
result = `"${sub.slice(1, -1).repeat(minTimes)}"`;
} else {
result = Array.from({ length: minTimes }, () => sub).join(' ');
}
if (minTimes < maxTimes) {
if (minTimes > 0) {
result += ' ';
}
result += _buildRepetition(sub, maxTimes - minTimes);
}
seq[seq.length - 1] = [result, false];
} }
seq[seq.length - 1] = [
_buildRepetition(subIsLiteral ? `"${sub}"` : sub, minTimes, maxTimes, {itemRuleIsLiteral: subIsLiteral}),
false
];
} else { } else {
let literal = ''; let literal = '';
while (i < length) { while (i < length) {
@ -422,23 +449,9 @@ export class SchemaConverter {
); );
} else { } else {
const itemRuleName = this.visit(items, `${name ?? ''}${name ? '-' : ''}item`); const itemRuleName = this.visit(items, `${name ?? ''}${name ? '-' : ''}item`);
const listItemOperator = `( "," space ${itemRuleName} )`; const minItems = schema.minItems || 0;
let successiveItems = '';
let minItems = schema.minItems || 0;
const maxItems = schema.maxItems; const maxItems = schema.maxItems;
if (minItems > 0) { return this._addRule(ruleName, '"[" space ' + _buildRepetition(itemRuleName, minItems, maxItems, {separatorRule: '"," space'}) + ' "]" space');
successiveItems = listItemOperator.repeat(minItems - 1);
minItems--;
}
if (maxItems !== undefined && maxItems > minItems) {
successiveItems += _buildRepetition(listItemOperator, maxItems - minItems - 1);
} else {
successiveItems += `${listItemOperator}*`;
}
const rule = minItems === 0
? `"[" space ( ${itemRuleName} ${successiveItems} )? "]" space`
: `"[" space ${itemRuleName} ${successiveItems} "]" space`;
return this._addRule(ruleName, rule);
} }
} else if ((schemaType === undefined || schemaType === 'string') && 'pattern' in schema) { } else if ((schemaType === undefined || schemaType === 'string') && 'pattern' in schema) {
return this._visitPattern(schema.pattern, ruleName); return this._visitPattern(schema.pattern, ruleName);

View file

@ -282,7 +282,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""", })""",
R"""( R"""(
boolean ::= ("true" | "false") space boolean ::= ("true" | "false") space
root ::= "[" space boolean ( "," space boolean )( "," space boolean )* "]" space root ::= "[" space boolean "," space boolean ("," space boolean)* "]" space
space ::= " "? space ::= " "?
)""" )"""
}); });
@ -298,7 +298,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""", })""",
R"""( R"""(
boolean ::= ("true" | "false") space boolean ::= ("true" | "false") space
root ::= "[" space ( boolean )? "]" space root ::= "[" space (boolean)? "]" space
space ::= " "? space ::= " "?
)""" )"""
}); });
@ -314,7 +314,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""", })""",
R"""( R"""(
boolean ::= ("true" | "false") space boolean ::= ("true" | "false") space
root ::= "[" space ( boolean (( "," space boolean ))? )? "]" space root ::= "[" space (boolean ("," space boolean)?)? "]" space
space ::= " "? space ::= " "?
)""" )"""
}); });
@ -335,7 +335,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)? integral-part ::= [0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)?
item ::= number | integer item ::= number | integer
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "[" space item ( "," space item )( "," space item )(( "," space item ) (( "," space item ))?)? "]" space root ::= "[" space item "," space item "," space item ("," space item ("," space item)?)? "]" space
space ::= " "? space ::= " "?
)""" )"""
}); });
@ -384,11 +384,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"regexp", "regexp",
R"""({ R"""({
"type": "string", "type": "string",
"pattern": "^(\\([0-9]{1,3}\\))?[0-9]{3}-[0-9]{4} and...$" "pattern": "^(\\([0-9]{1,3}\\))?[0-9]{3}-[0-9]{4} a{3,5}nd...$"
})""", })""",
R"""( R"""(
dot ::= [\U00000000-\x09\x0B\x0C\x0E-\U0010FFFF] dot ::= [\U00000000-\x09\x0B\x0C\x0E-\U0010FFFF]
root ::= "\"" ("(" root-1 (root-1 (root-1)?)? ")")? root-1 root-1 root-1 "-" root-1 root-1 root-1 root-1 " and" dot dot dot "\"" space root ::= "\"" ("(" root-1 (root-1 (root-1)?)? ")")? root-1 root-1 root-1 "-" root-1 root-1 root-1 root-1 " " "aaa" ("a" ("a")?)? "nd" dot dot dot "\"" space
root-1 ::= [0-9] root-1 ::= [0-9]
space ::= " "? space ::= " "?
)""" )"""
@ -511,7 +511,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""( R"""(
additional-kv ::= string ":" space additional-value additional-kv ::= string ":" space additional-value
additional-kvs ::= additional-kv ( "," space additional-kv )* additional-kvs ::= additional-kv ( "," space additional-kv )*
additional-value ::= "[" space ( number ( "," space number )* )? "]" space additional-value ::= "[" space (number ("," space number)*)? "]" space
decimal-part ::= [0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)? decimal-part ::= [0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)?
integral-part ::= [0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)? integral-part ::= [0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9])?)?)?)?)?)?)?)?)?)?)?)?)?)?)?
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space