diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index 8a8496c31..783290eb5 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -199,7 +199,8 @@ namespace grammar_parser { sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, star_rule_id}); } else { uint32_t last_rec_rule_id = 0; - for (int i = 0, n = max_times - min_times; i < n; i++) { + auto n_opt = max_times - min_times; + for (int i = 0; i < n_opt; i++) { uint32_t rec_rule_id = generate_symbol_id(state, rule_name + "_" + std::to_string(i + 1)); if (i == 0) { add_rule(state, rec_rule_id, { @@ -217,7 +218,9 @@ namespace grammar_parser { } last_rec_rule_id = rec_rule_id; } - sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + if (n_opt > 0) { + sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } } sub_rule.push_back({LLAMA_GRETYPE_END, 0}); add_rule(state, sub_rule_id, sub_rule); @@ -291,24 +294,34 @@ namespace grammar_parser { pos = parse_space(pos + 1, is_nested); size_t min_times = 0; int max_times = -1; + if (is_digit_char(*pos)) { const char * int_end = parse_int(pos); min_times = std::stoul(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); + } else if (*pos != ',') { + throw std::runtime_error(std::string("expecting an int or ',' at ") + pos); } - if (*pos != ',') { + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { throw std::runtime_error(std::string("expecting ',' at ") + pos); } - pos = parse_space(pos + 1, is_nested); - if (is_digit_char(*pos)) { - const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - } - if (*pos != '}') { - throw std::runtime_error(std::string("expecting '}' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); handle_repetitions(min_times, max_times); } else { break;