From 8d37755bdc22b0bc2d6876cade32fa1534e75d3d Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Tue, 18 Jul 2023 21:54:44 -0400 Subject: [PATCH] add inverse char ranges --- examples/grammar-parser.cpp | 17 +++++++- grammars/json.gbnf | 6 +-- grammars/list.gbnf | 4 ++ llama.cpp | 80 +++++++++++++++++-------------------- llama.h | 7 +++- 5 files changed, 63 insertions(+), 51 deletions(-) create mode 100644 grammars/list.gbnf diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index 5ac0d41f8..019d5e1bf 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -149,13 +149,18 @@ namespace grammar_parser { pos = parse_space(pos + 1, is_nested); } else if (*pos == '[') { // char range(s) pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } last_sym_start = out_elements.size(); while (*pos != ']') { auto char_pair = parse_char(pos); pos = char_pair.second; enum llama_gretype type = last_sym_start < out_elements.size() ? LLAMA_GRETYPE_CHAR_ALT - : LLAMA_GRETYPE_CHAR; + : start_type; out_elements.push_back({type, char_pair.first}); if (pos[0] == '-' && pos[1] != ']') { @@ -292,6 +297,7 @@ namespace grammar_parser { bool is_char_element(llama_grammar_element elem) { switch (elem.type) { case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; case LLAMA_GRETYPE_CHAR_ALT: return true; case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; default: return false; @@ -305,8 +311,9 @@ namespace grammar_parser { case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; - case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -315,6 +322,7 @@ namespace grammar_parser { fprintf(file, "(%u) ", elem.value); break; case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "(\""); @@ -353,6 +361,10 @@ namespace grammar_parser { fprintf(file, "["); print_grammar_char(file, elem.value); break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: if (i == 0 || !is_char_element(rule[i - 1])) { throw std::runtime_error( @@ -394,6 +406,7 @@ namespace grammar_parser { // fprintf(file, "%zu: ", i); // print_rule_binary(file, state.rules[i]); print_rule(file, i, state.rules[i], symbol_id_names); + // fprintf(file, "\n"); } } catch (const std::exception & err) { fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); diff --git a/grammars/json.gbnf b/grammars/json.gbnf index 668836df8..40fa2b637 100644 --- a/grammars/json.gbnf +++ b/grammars/json.gbnf @@ -1,7 +1,7 @@ # Grammar for subset of JSON - doesn't support full string or number syntax root ::= object -value ::= object | array | string | number | boolean +value ::= object | array | string | number | boolean | "null" object ::= "{" ws ( @@ -17,9 +17,9 @@ array ::= string ::= "\"" ( - [\x20\x21\x23-\x5b\x5d-\U0010FFFF] | # any code point except " (\x22) and \ (\x5c) + [^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes - )* "\"" + )* "\"" ws # Only plain integers currently number ::= "-"? [0-9]+ ws diff --git a/grammars/list.gbnf b/grammars/list.gbnf new file mode 100644 index 000000000..51e6c9c4b --- /dev/null +++ b/grammars/list.gbnf @@ -0,0 +1,4 @@ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" diff --git a/llama.cpp b/llama.cpp index 8cd13807d..0f6eddb50 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1925,6 +1925,31 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) } } +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair llama_grammar_match_char( + const llama_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( @@ -1969,6 +1994,7 @@ static void llama_grammar_advance_stack( break; } case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: new_stacks.push_back(stack); break; default: @@ -1995,34 +2021,17 @@ static std::vector> llama_grammar_acc continue; } - const llama_grammar_element * pos = stack.back(); - LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR); + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; - bool found = false; - do { - bool matches_range; - if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { - // inclusive range, e.g. [a-z] - matches_range = pos->value <= chr && chr <= pos[1].value; - pos += 2; - } else { - // exact char match, e.g. [a] or "a" - matches_range = pos->value == chr; - pos += 1; + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); } - found = found || matches_range; - } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); - - if (!found) { - continue; + llama_grammar_advance_stack(rules, new_stack, new_stacks); } - - // update top of stack to next element, if any - std::vector new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos)) { - new_stack.push_back(pos); - } - llama_grammar_advance_stack(rules, new_stack, new_stacks); } return new_stacks; @@ -2038,25 +2047,8 @@ static bool llama_grammar_peek( if (!chr) { return true; } - } else { - const llama_grammar_element * pos = stack.back(); - LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR); - - do { - if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { - // inclusive range, e.g. [a-z] - if (pos->value <= chr && chr <= pos[1].value) { - return true; - } - pos += 2; - } else { - // exact char match, e.g. [a] or "a" - if (pos->value == chr) { - return true; - } - pos += 1; - } - } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + } else if (llama_grammar_match_char(stack.back(), chr).first) { + return true; } } return false; diff --git a/llama.h b/llama.h index 9d584b130..c9b221038 100644 --- a/llama.h +++ b/llama.h @@ -151,13 +151,16 @@ extern "C" { // terminal element: character (code point) LLAMA_GRETYPE_CHAR = 3, + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to // be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 4, + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, // modifies a preceding LLAMA_GRETYPE_CHAR or // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ALT = 5, + LLAMA_GRETYPE_CHAR_ALT = 6, }; typedef struct llama_grammar_element {