diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp
index 2a1301569..8a8496c31 100644
--- a/common/grammar-parser.cpp
+++ b/common/grammar-parser.cpp
@@ -46,8 +46,13 @@ namespace grammar_parser {
         state.rules[rule_id] = rule;
     }
 
+    // returns true if c is an ASCII decimal digit ('0'..'9')
+    static bool is_digit_char(char c) {
+        return '0' <= c && c <= '9';
+    }
+
     static bool is_word_char(char c) {
-        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
+        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
     }
 
     static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
@@ -99,6 +104,19 @@ namespace grammar_parser {
         return pos;
     }
 
+    // Skips a run of one or more decimal digits starting at src and returns a pointer to the
+    // first non-digit character; throws std::runtime_error if src does not start with a digit.
+    static const char * parse_int(const char * src) {
+        const char * pos = src;
+        while (is_digit_char(*pos)) {
+            pos++;
+        }
+        if (pos == src) {
+            throw std::runtime_error(std::string("expecting integer at ") + src);
+        }
+        return pos;
+    }
+
     static std::pair<uint32_t, const char *> parse_char(const char * src) {
         if (*src == '\\') {
             switch (src[1]) {
@@ -137,6 +155,95 @@ namespace grammar_parser {
             bool                                 is_nested) {
         size_t last_sym_start = out_elements.size();
         const char * pos = src;
+
+        // Rewrites the most recently parsed symbol (out_elements from last_sym_start onward)
+        // so that it repeats at least min_times and at most max_times times; a negative
+        // max_times means there is no upper bound. Throws std::runtime_error when there is
+        // no preceding symbol or when max_times is smaller than min_times.
+        auto handle_repetitions = [&](size_t min_times, int max_times) {
+
+            if (last_sym_start == out_elements.size()) {
+                throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+            }
+            if (max_times >= 0 && static_cast<size_t>(max_times) < min_times) {
+                throw std::runtime_error(std::string("expecting max repetitions >= min repetitions before ") + pos);
+            }
+
+            // S* --> S{0,}
+            // S+ --> S{1,}
+            // S? --> S{0,1}
+            // S{m,n} --> S' ::= Scopy Scopy Scopy... (m times) S(n-m)
+            //            Scopy ::= S
+            //            S(x) ::= Scopy S(x-1) |
+            //            S(x-1) ::= Scopy S(x-2) |
+            //            S(1) ::= Scopy |
+            // S{m,m} --> S' ::= Scopy Scopy Scopy... (m times)
+            // S{m,} --> S' ::= Scopy Scopy Scopy (m times) Sstar
+            //           Scopy ::= S
+            //           Sstar ::= Scopy Sstar |
+
+            uint32_t content_rule_id = 0;
+            if (out_elements[last_sym_start].type == LLAMA_GRETYPE_RULE_REF) {
+                // The repeated content is already a rule ref, no need to copy it
+                content_rule_id = out_elements[last_sym_start].value;
+            } else {
+                content_rule_id = generate_symbol_id(state, rule_name);
+                // add preceding symbol to generated copy rule
+                std::vector<llama_grammar_element> copy_rule(out_elements.begin() + last_sym_start, out_elements.end());
+                copy_rule.push_back({LLAMA_GRETYPE_END, 0});
+                add_rule(state, content_rule_id, copy_rule);
+            }
+
+            uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+            std::vector<llama_grammar_element> sub_rule;
+            for (size_t i = 0; i < min_times; i++) {
+                sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, content_rule_id});
+            }
+            if (max_times < 0) {
+                uint32_t star_rule_id = generate_symbol_id(state, rule_name + "_star");
+                add_rule(state, star_rule_id, {
+                    {LLAMA_GRETYPE_RULE_REF, content_rule_id},
+                    {LLAMA_GRETYPE_RULE_REF, star_rule_id},
+                    {LLAMA_GRETYPE_ALT, 0},
+                    {LLAMA_GRETYPE_END, 0}
+                });
+                sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, star_rule_id});
+            } else {
+                // number of optional copies allowed on top of the min_times mandatory ones
+                const int n_optional = max_times - static_cast<int>(min_times);
+                uint32_t last_rec_rule_id = 0;
+                for (int i = 0; i < n_optional; i++) {
+                    uint32_t rec_rule_id = generate_symbol_id(state, rule_name + "_" + std::to_string(i + 1));
+                    if (i == 0) {
+                        add_rule(state, rec_rule_id, {
+                            {LLAMA_GRETYPE_RULE_REF, content_rule_id},
+                            {LLAMA_GRETYPE_ALT, 0},
+                            {LLAMA_GRETYPE_END, 0}
+                        });
+                    } else {
+                        add_rule(state, rec_rule_id, {
+                            {LLAMA_GRETYPE_RULE_REF, content_rule_id},
+                            {LLAMA_GRETYPE_RULE_REF, last_rec_rule_id},
+                            {LLAMA_GRETYPE_ALT, 0},
+                            {LLAMA_GRETYPE_END, 0}
+                        });
+                    }
+                    last_rec_rule_id = rec_rule_id;
+                }
+                // S{m,m} generates no optional chain; without this guard the rule would
+                // reference rule id 0 instead of a generated rule
+                if (n_optional > 0) {
+                    sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+                }
+            }
+            sub_rule.push_back({LLAMA_GRETYPE_END, 0});
+            add_rule(state, sub_rule_id, sub_rule);
+
+            // in original rule, replace previous symbol with reference to generated rule
+            out_elements.resize(last_sym_start);
+            out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+        };
+
         while (*pos) {
             if (*pos == '"') { // literal string
                 pos++;
@@ -188,40 +295,42 @@ namespace grammar_parser {
                     throw std::runtime_error(std::string("expecting ')' at ") + pos);
                 }
                 pos = parse_space(pos + 1, is_nested);
-            } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
-                if (last_sym_start == out_elements.size()) {
-                    throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos);
-                }
-
-                // apply transformation to previous symbol (last_sym_start to end) according to
-                // rewrite rules:
-                // S* --> S' ::= S S' |
-                // S+ --> S' ::= S S' | S
-                // S? --> S' ::= S |
-                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
-                std::vector<llama_grammar_element> sub_rule;
-                // add preceding symbol to generated rule
-                sub_rule.insert(
-                    sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
-                if (*pos == '*' || *pos == '+') {
-                    // cause generated rule to recurse
-                    sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
-                }
-                // mark start of alternate def
-                sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
-                if (*pos == '+') {
-                    // add preceding symbol as alternate only for '+' (otherwise empty)
-                    sub_rule.insert(
-                        sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
-                }
-                sub_rule.push_back({LLAMA_GRETYPE_END, 0});
-                add_rule(state, sub_rule_id, sub_rule);
-
-                // in original rule, replace previous symbol with reference to generated rule
-                out_elements.resize(last_sym_start);
-                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
-
+            } else if (*pos == '*') {
                 pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(0, -1);
+            } else if (*pos == '+') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(1, -1);
+            } else if (*pos == '?') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(0, 1);
+            } else if (*pos == '{') {
+                // bounded repetition: "{m,n}", "{m,}", "{,n}" or "{,}" (a missing m means 0,
+                // a missing n means unbounded)
+                pos = parse_space(pos + 1, is_nested);
+                size_t min_times = 0;
+                int max_times = -1;
+                if (is_digit_char(*pos)) {
+                    const char * int_end = parse_int(pos);
+                    min_times = std::stoul(std::string(pos, int_end - pos));
+                    pos = parse_space(int_end, is_nested);
+                }
+                if (*pos != ',') {
+                    throw std::runtime_error(std::string("expecting ',' at ") + pos);
+                }
+                pos = parse_space(pos + 1, is_nested);
+                if (is_digit_char(*pos)) {
+                    const char * int_end = parse_int(pos);
+                    // std::stoi throws std::out_of_range for values that do not fit in an int,
+                    // rather than letting them be narrowed to a possibly negative ("unbounded") value
+                    max_times = std::stoi(std::string(pos, int_end - pos));
+                    pos = parse_space(int_end, is_nested);
+                }
+                if (*pos != '}') {
+                    throw std::runtime_error(std::string("expecting '}' at ") + pos);
+                }
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(min_times, max_times);
             } else {
                 break;
             }