grammars: disallow a{,} (not allowed in regexps)

This commit is contained in:
Olivier Chafik 2024-04-12 18:10:50 +01:00
parent a9351b8f75
commit 9d8efa545f
3 changed files with 29 additions and 10 deletions

View file

@ -187,7 +187,7 @@ namespace grammar_parser {
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
std::vector<llama_grammar_element> sub_rule;
for (size_t i = 0; i < min_times; i++) {
for (int i = 0; i < min_times; i++) {
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, content_rule_id});
}
if (max_times < 0) {
@ -294,16 +294,15 @@ namespace grammar_parser {
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);
size_t min_times = 0;
int max_times = -1;
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
} else if (*pos != ',') {
if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int or ',' at ") + pos);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
int max_times = -1;
if (*pos == '}') {
max_times = min_times;

View file

@ -64,8 +64,8 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte
- `?` makes the preceding symbol or sequence optional (equivalent to `{0,1}`).
- `{m}` repeats the precedent symbol or sequence exactly `m` times
- `{m,}` repeats the precedent symbol or sequence at least `m` times
- `{m,n}` repeats the precedent symbol or sequence at betwen `m` and `n` times (included)
- `{,n}` repeats the precedent symbol or sequence at most `n` times (included)
- `{m,n}` repeats the precedent symbol or sequence between `m` and `n` times (included)
- `{0,n}` repeats the precedent symbol or sequence at most `n` times (included)
## Comments and newlines

View file

@ -69,6 +69,12 @@ static void verify_parsing(const char *grammar_bytes, const std::vector<std::pai
fprintf(stderr, "Testing grammar:%s\n", grammar_bytes);
if (parsed_grammar.symbol_ids.size() != expected.size()) {
fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n");
print_all();
assert(parsed_grammar.symbol_ids.size() == expected.size());
}
for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
{
std::string key = it->first;
@ -118,6 +124,12 @@ static void verify_parsing(const char *grammar_bytes, const std::vector<std::pai
}
}
// Negative-path check for the grammar parser: logs the grammar under test to
// stderr, parses it, and asserts that the returned parse state holds no rules.
static void verify_failure(const char *grammar_bytes) {
    fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);

    // The parse state is only inspected, never modified, so bind it as const.
    const auto result = grammar_parser::parse(grammar_bytes);

    // An empty rule set is what this test treats as "the grammar was rejected".
    assert(result.rules.empty() && "should have failed");
}
int main()
{
verify_parsing(R"""(
@ -289,6 +301,14 @@ int main()
{LLAMA_GRETYPE_END, 0},
});
verify_failure(R"""(
root ::= "a"{,}"
)""");
verify_failure(R"""(
root ::= "a"{,10}"
)""");
verify_parsing(R"""(
root ::= (expr "=" term "\n")+
expr ::= term ([-+*/] term)*