json: remove recursion in opt_repetitions (avoids Python stack overflow)

This commit is contained in:
ochafik 2024-04-12 10:03:27 +01:00
parent 958bdda559
commit 64e305901e
4 changed files with 1740 additions and 1692 deletions

View file

@ -8,6 +8,7 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "ggml.h"
using json = nlohmann::ordered_json;
@ -36,6 +37,23 @@ static std::string build_repetition(const std::string & item_rule, int min_items
}
std::function<std::string(int, bool)> opt_repetitions = [&](int up_to_n, bool prefix_with_sep) -> std::string {
auto content = prefix_with_sep && !separator_rule.empty() ? separator_rule + " " + item_rule : item_rule;
if (up_to_n == 0) {
return "";
} else if (up_to_n == 1) {
return "(" + content + ")?";
} else if (!separator_rule.empty() && !prefix_with_sep) {
return "(" + content + " " + opt_repetitions(up_to_n - 1, true) + ")?";
} else {
std::string res = repeat("(" + content + " ", up_to_n);
// strip trailing space
GGML_ASSERT(!res.empty());
res = res.substr(0, res.length() - 1);
res += repeat(")?", up_to_n);
return res;
}
if (up_to_n == 0) {
return "";
}

View file

@ -22,13 +22,21 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)
def opt_repetitions(up_to_n, prefix_with_sep=False):
    '''
    Build a grammar fragment matching between 0 and up_to_n repetitions of
    item_rule, expressed as nested optional groups.

    item_rule and separator_rule are read from the enclosing scope. When
    prefix_with_sep is true and separator_rule is non-empty, every item is
    preceded by the separator; otherwise only the items after the first are.

    Examples:
    - n=4, no sep: '(a (a (a (a)?)?)?)?'
    - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
    - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'

    Returns '' when up_to_n is 0.
    '''
    content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
    if up_to_n == 0:
        return ''
    elif up_to_n == 1:
        return f'({content})?'
    elif separator_rule and not prefix_with_sep:
        # Only the outermost item lacks the separator. The nested call always
        # passes prefix_with_sep=True, so it cannot reach this branch again:
        # recursion depth is at most one regardless of up_to_n.
        return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
    else:
        # Every level has identical content, so build the nesting by string
        # repetition instead of one recursive call per item: open up_to_n
        # groups, drop the trailing space, then close them all.
        return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)
if min_items > 0 and max_items != min_items:
result += ' '

File diff suppressed because it is too large Load diff

View file

@ -24,14 +24,16 @@ function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
}
/**
 * Builds a grammar fragment matching between 0 and `upToN` repetitions of
 * `itemRule`, expressed as nested optional groups, e.g. `(a (a (a)?)?)?`.
 *
 * `itemRule` and `separatorRule` are read from the enclosing scope. When
 * `prefixWithSep` is true and `separatorRule` is non-empty, every item is
 * preceded by the separator; otherwise only the items after the first are.
 *
 * @param {number} upToN - Maximum number of optional repetitions.
 * @param {boolean} [prefixWithSep=false] - Whether the first item is also preceded by the separator.
 * @returns {string} The nested optional rule, or '' when `upToN` is 0.
 */
const optRepetitions = (upToN, prefixWithSep=false) => {
  const content = separatorRule !== '' && prefixWithSep ? `${separatorRule} ${itemRule}` : itemRule;
  if (upToN === 0) {
    return '';
  } else if (upToN === 1) {
    return `(${content})?`;
  } else if (separatorRule !== '' && !prefixWithSep) {
    // Only the outermost item lacks the separator. The nested call always
    // passes prefixWithSep=true, so recursion depth is at most one.
    return `(${content} ${optRepetitions(upToN - 1, true)})?`;
  } else {
    // Every level has identical content: open upToN groups, drop the
    // trailing space, then close them all (no per-item recursion).
    return `(${content} `.repeat(upToN).trim() + ')?'.repeat(upToN);
  }
};
if (minItems > 0 && maxItems !== minItems) {