grammars: x{min,max} repetition operator + tweak +/*/? to avoid duplication of original over alternates

This commit is contained in:
Olivier Chafik 2024-04-12 15:11:47 +01:00
parent 4cc120c744
commit 01604690c1

View file

@ -46,8 +46,12 @@ namespace grammar_parser {
state.rules[rule_id] = rule; state.rules[rule_id] = rule;
} }
static bool is_digit_char(char c) {
return '0' <= c && c <= '9';
}
static bool is_word_char(char c) { static bool is_word_char(char c) {
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
} }
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) { static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
@ -99,6 +103,17 @@ namespace grammar_parser {
return pos; return pos;
} }
static const char * parse_int(const char * src) {
const char * pos = src;
while (is_digit_char(*pos)) {
pos++;
}
if (pos == src) {
throw std::runtime_error(std::string("expecting name at ") + src);
}
return pos;
}
static std::pair<uint32_t, const char *> parse_char(const char * src) { static std::pair<uint32_t, const char *> parse_char(const char * src) {
if (*src == '\\') { if (*src == '\\') {
switch (src[1]) { switch (src[1]) {
@ -137,6 +152,81 @@ namespace grammar_parser {
bool is_nested) { bool is_nested) {
size_t last_sym_start = out_elements.size(); size_t last_sym_start = out_elements.size();
const char * pos = src; const char * pos = src;
auto handle_repetitions = [&](size_t min_times, int max_times) {
if (last_sym_start == out_elements.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}
// S* --> S{0,}
// S+ --> S{1,}
// S? --> S{0,1}
// S{m,n} --> S' ::= Scopy Scopy Scopy... (m times) S(n-m)
// Scopy ::= S
// S(x) ::= Scopy S(x-1) |
// S(x-1) ::= Scopy S(x-2) |
// S(1) ::= Scopy |
// S{m,} --> S' ::= Scopy Scopy Scopy (m times) Sstar
// Scopy ::= S
// Sstar ::= Scopy Sstar |
uint32_t content_rule_id = 0;
if (out_elements[last_sym_start].type == LLAMA_GRETYPE_RULE_REF) {
// The repeated content is already a rule ref, no need to copy it
content_rule_id = out_elements[last_sym_start].value;
} else {
content_rule_id = generate_symbol_id(state, rule_name);
// add preceding symbol to generated copy rule
std::vector<llama_grammar_element> copy_rule(out_elements.begin() + last_sym_start, out_elements.end());
copy_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule(state, content_rule_id, copy_rule);
}
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
std::vector<llama_grammar_element> sub_rule;
for (size_t i = 0; i < min_times; i++) {
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, content_rule_id});
}
if (max_times < 0) {
uint32_t star_rule_id = generate_symbol_id(state, rule_name + "_star");
add_rule(state, star_rule_id, {
{LLAMA_GRETYPE_RULE_REF, content_rule_id},
{LLAMA_GRETYPE_RULE_REF, star_rule_id},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_END, 0}
});
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, star_rule_id});
} else {
uint32_t last_rec_rule_id = 0;
for (int i = 0, n = max_times - min_times; i < n; i++) {
uint32_t rec_rule_id = generate_symbol_id(state, rule_name + "_" + std::to_string(i + 1));
if (i == 0) {
add_rule(state, rec_rule_id, {
{LLAMA_GRETYPE_RULE_REF, content_rule_id},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_END, 0}
});
} else {
add_rule(state, rec_rule_id, {
{LLAMA_GRETYPE_RULE_REF, content_rule_id},
{LLAMA_GRETYPE_RULE_REF, last_rec_rule_id},
{LLAMA_GRETYPE_ALT, 0},
{LLAMA_GRETYPE_END, 0}
});
}
last_rec_rule_id = rec_rule_id;
}
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
sub_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule(state, sub_rule_id, sub_rule);
// in original rule, replace previous symbol with reference to generated rule
out_elements.resize(last_sym_start);
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
};
while (*pos) { while (*pos) {
if (*pos == '"') { // literal string if (*pos == '"') { // literal string
pos++; pos++;
@ -188,40 +278,38 @@ namespace grammar_parser {
throw std::runtime_error(std::string("expecting ')' at ") + pos); throw std::runtime_error(std::string("expecting ')' at ") + pos);
} }
pos = parse_space(pos + 1, is_nested); pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator } else if (*pos == '*') {
if (last_sym_start == out_elements.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos);
}
// apply transformation to previous symbol (last_sym_start to end) according to
// rewrite rules:
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
std::vector<llama_grammar_element> sub_rule;
// add preceding symbol to generated rule
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
if (*pos == '*' || *pos == '+') {
// cause generated rule to recurse
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
}
// mark start of alternate def
sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
if (*pos == '+') {
// add preceding symbol as alternate only for '+' (otherwise empty)
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
}
sub_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule(state, sub_rule_id, sub_rule);
// in original rule, replace previous symbol with reference to generated rule
out_elements.resize(last_sym_start);
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
pos = parse_space(pos + 1, is_nested); pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);
size_t min_times = 0;
int max_times = -1;
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
if (*pos != ',') {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
handle_repetitions(min_times, max_times);
} else { } else {
break; break;
} }