grammars: x{min,max} repetition operator + tweak +/*/? to avoid duplication of original over alternates
This commit is contained in:
parent
4cc120c744
commit
01604690c1
1 changed files with 122 additions and 34 deletions
|
@ -46,8 +46,12 @@ namespace grammar_parser {
|
|||
state.rules[rule_id] = rule;
|
||||
}
|
||||
|
||||
static bool is_digit_char(char c) {
|
||||
return '0' <= c && c <= '9';
|
||||
}
|
||||
|
||||
static bool is_word_char(char c) {
|
||||
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
||||
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
|
||||
}
|
||||
|
||||
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||
|
@ -99,6 +103,17 @@ namespace grammar_parser {
|
|||
return pos;
|
||||
}
|
||||
|
||||
static const char * parse_int(const char * src) {
|
||||
const char * pos = src;
|
||||
while (is_digit_char(*pos)) {
|
||||
pos++;
|
||||
}
|
||||
if (pos == src) {
|
||||
throw std::runtime_error(std::string("expecting name at ") + src);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
static std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||
if (*src == '\\') {
|
||||
switch (src[1]) {
|
||||
|
@ -137,6 +152,81 @@ namespace grammar_parser {
|
|||
bool is_nested) {
|
||||
size_t last_sym_start = out_elements.size();
|
||||
const char * pos = src;
|
||||
|
||||
auto handle_repetitions = [&](size_t min_times, int max_times) {
|
||||
|
||||
if (last_sym_start == out_elements.size()) {
|
||||
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
||||
}
|
||||
|
||||
// S* --> S{0,}
|
||||
// S+ --> S{1,}
|
||||
// S? --> S{0,1}
|
||||
// S{m,n} --> S' ::= Scopy Scopy Scopy... (m times) S(n-m)
|
||||
// Scopy ::= S
|
||||
// S(x) ::= Scopy S(x-1) |
|
||||
// S(x-1) ::= Scopy S(x-2) |
|
||||
// S(1) ::= Scopy |
|
||||
// S{m,} --> S' ::= Scopy Scopy Scopy (m times) Sstar
|
||||
// Scopy ::= S
|
||||
// Sstar ::= Scopy Sstar |
|
||||
|
||||
uint32_t content_rule_id = 0;
|
||||
if (out_elements[last_sym_start].type == LLAMA_GRETYPE_RULE_REF) {
|
||||
// The repeated content is already a rule ref, no need to copy it
|
||||
content_rule_id = out_elements[last_sym_start].value;
|
||||
} else {
|
||||
content_rule_id = generate_symbol_id(state, rule_name);
|
||||
// add preceding symbol to generated copy rule
|
||||
std::vector<llama_grammar_element> copy_rule(out_elements.begin() + last_sym_start, out_elements.end());
|
||||
copy_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||
add_rule(state, content_rule_id, copy_rule);
|
||||
}
|
||||
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
std::vector<llama_grammar_element> sub_rule;
|
||||
for (size_t i = 0; i < min_times; i++) {
|
||||
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, content_rule_id});
|
||||
}
|
||||
if (max_times < 0) {
|
||||
uint32_t star_rule_id = generate_symbol_id(state, rule_name + "_star");
|
||||
add_rule(state, star_rule_id, {
|
||||
{LLAMA_GRETYPE_RULE_REF, content_rule_id},
|
||||
{LLAMA_GRETYPE_RULE_REF, star_rule_id},
|
||||
{LLAMA_GRETYPE_ALT, 0},
|
||||
{LLAMA_GRETYPE_END, 0}
|
||||
});
|
||||
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, star_rule_id});
|
||||
} else {
|
||||
uint32_t last_rec_rule_id = 0;
|
||||
for (int i = 0, n = max_times - min_times; i < n; i++) {
|
||||
uint32_t rec_rule_id = generate_symbol_id(state, rule_name + "_" + std::to_string(i + 1));
|
||||
if (i == 0) {
|
||||
add_rule(state, rec_rule_id, {
|
||||
{LLAMA_GRETYPE_RULE_REF, content_rule_id},
|
||||
{LLAMA_GRETYPE_ALT, 0},
|
||||
{LLAMA_GRETYPE_END, 0}
|
||||
});
|
||||
} else {
|
||||
add_rule(state, rec_rule_id, {
|
||||
{LLAMA_GRETYPE_RULE_REF, content_rule_id},
|
||||
{LLAMA_GRETYPE_RULE_REF, last_rec_rule_id},
|
||||
{LLAMA_GRETYPE_ALT, 0},
|
||||
{LLAMA_GRETYPE_END, 0}
|
||||
});
|
||||
}
|
||||
last_rec_rule_id = rec_rule_id;
|
||||
}
|
||||
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
||||
}
|
||||
sub_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||
add_rule(state, sub_rule_id, sub_rule);
|
||||
|
||||
// in original rule, replace previous symbol with reference to generated rule
|
||||
out_elements.resize(last_sym_start);
|
||||
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
||||
};
|
||||
|
||||
while (*pos) {
|
||||
if (*pos == '"') { // literal string
|
||||
pos++;
|
||||
|
@ -188,40 +278,38 @@ namespace grammar_parser {
|
|||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
|
||||
if (last_sym_start == out_elements.size()) {
|
||||
throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos);
|
||||
}
|
||||
|
||||
// apply transformation to previous symbol (last_sym_start to end) according to
|
||||
// rewrite rules:
|
||||
// S* --> S' ::= S S' |
|
||||
// S+ --> S' ::= S S' | S
|
||||
// S? --> S' ::= S |
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
std::vector<llama_grammar_element> sub_rule;
|
||||
// add preceding symbol to generated rule
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
if (*pos == '*' || *pos == '+') {
|
||||
// cause generated rule to recurse
|
||||
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
||||
}
|
||||
// mark start of alternate def
|
||||
sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||
if (*pos == '+') {
|
||||
// add preceding symbol as alternate only for '+' (otherwise empty)
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
}
|
||||
sub_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||
add_rule(state, sub_rule_id, sub_rule);
|
||||
|
||||
// in original rule, replace previous symbol with reference to generated rule
|
||||
out_elements.resize(last_sym_start);
|
||||
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
||||
|
||||
} else if (*pos == '*') {
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
handle_repetitions(0, -1);
|
||||
} else if (*pos == '+') {
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
handle_repetitions(1, -1);
|
||||
} else if (*pos == '?') {
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
handle_repetitions(0, 1);
|
||||
} else if (*pos == '{') {
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
size_t min_times = 0;
|
||||
int max_times = -1;
|
||||
if (is_digit_char(*pos)) {
|
||||
const char * int_end = parse_int(pos);
|
||||
min_times = std::stoul(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
}
|
||||
if (*pos != ',') {
|
||||
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
if (is_digit_char(*pos)) {
|
||||
const char * int_end = parse_int(pos);
|
||||
max_times = std::stoul(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
}
|
||||
if (*pos != '}') {
|
||||
throw std::runtime_error(std::string("expecting '}' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
handle_repetitions(min_times, max_times);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue