add inverse char ranges

This commit is contained in:
Evan Jones 2023-07-18 21:54:44 -04:00
parent b2e071dd86
commit 8d37755bdc
5 changed files with 63 additions and 51 deletions

View file

@ -149,13 +149,18 @@ namespace grammar_parser {
pos = parse_space(pos + 1, is_nested); pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s) } else if (*pos == '[') { // char range(s)
pos++; pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = out_elements.size(); last_sym_start = out_elements.size();
while (*pos != ']') { while (*pos != ']') {
auto char_pair = parse_char(pos); auto char_pair = parse_char(pos);
pos = char_pair.second; pos = char_pair.second;
enum llama_gretype type = last_sym_start < out_elements.size() enum llama_gretype type = last_sym_start < out_elements.size()
? LLAMA_GRETYPE_CHAR_ALT ? LLAMA_GRETYPE_CHAR_ALT
: LLAMA_GRETYPE_CHAR; : start_type;
out_elements.push_back({type, char_pair.first}); out_elements.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') { if (pos[0] == '-' && pos[1] != ']') {
@ -292,6 +297,7 @@ namespace grammar_parser {
bool is_char_element(llama_grammar_element elem) { bool is_char_element(llama_grammar_element elem) {
switch (elem.type) { switch (elem.type) {
case LLAMA_GRETYPE_CHAR: return true; case LLAMA_GRETYPE_CHAR: return true;
case LLAMA_GRETYPE_CHAR_NOT: return true;
case LLAMA_GRETYPE_CHAR_ALT: return true; case LLAMA_GRETYPE_CHAR_ALT: return true;
case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
default: return false; default: return false;
@ -305,8 +311,9 @@ namespace grammar_parser {
case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
} }
switch (elem.type) { switch (elem.type) {
case LLAMA_GRETYPE_END: case LLAMA_GRETYPE_END:
@ -315,6 +322,7 @@ namespace grammar_parser {
fprintf(file, "(%u) ", elem.value); fprintf(file, "(%u) ", elem.value);
break; break;
case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR:
case LLAMA_GRETYPE_CHAR_NOT:
case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_RNG_UPPER:
case LLAMA_GRETYPE_CHAR_ALT: case LLAMA_GRETYPE_CHAR_ALT:
fprintf(file, "(\""); fprintf(file, "(\"");
@ -353,6 +361,10 @@ namespace grammar_parser {
fprintf(file, "["); fprintf(file, "[");
print_grammar_char(file, elem.value); print_grammar_char(file, elem.value);
break; break;
case LLAMA_GRETYPE_CHAR_NOT:
fprintf(file, "[^");
print_grammar_char(file, elem.value);
break;
case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_RNG_UPPER:
if (i == 0 || !is_char_element(rule[i - 1])) { if (i == 0 || !is_char_element(rule[i - 1])) {
throw std::runtime_error( throw std::runtime_error(
@ -394,6 +406,7 @@ namespace grammar_parser {
// fprintf(file, "%zu: ", i); // fprintf(file, "%zu: ", i);
// print_rule_binary(file, state.rules[i]); // print_rule_binary(file, state.rules[i]);
print_rule(file, i, state.rules[i], symbol_id_names); print_rule(file, i, state.rules[i], symbol_id_names);
// fprintf(file, "\n");
} }
} catch (const std::exception & err) { } catch (const std::exception & err) {
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());

View file

@ -1,7 +1,7 @@
# Grammar for subset of JSON - doesn't support full string or number syntax # Grammar for subset of JSON - doesn't support full string or number syntax
root ::= object root ::= object
value ::= object | array | string | number | boolean value ::= object | array | string | number | boolean | "null"
object ::= object ::=
"{" ws ( "{" ws (
@ -17,9 +17,9 @@ array ::=
string ::= string ::=
"\"" ( "\"" (
[\x20\x21\x23-\x5b\x5d-\U0010FFFF] | # any code point except " (\x22) and \ (\x5c) [^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" )* "\"" ws
# Only plain integers currently # Only plain integers currently
number ::= "-"? [0-9]+ ws number ::= "-"? [0-9]+ ws

4
grammars/list.gbnf Normal file
View file

@ -0,0 +1,4 @@
root ::= item+
# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"

View file

@ -1925,6 +1925,31 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
} }
} }
// Tests chr against the char range that starts at pos (regular or inverse range).
// Returns a pair:
//   first  - true iff chr satisfies the range: for a LLAMA_GRETYPE_CHAR range chr must
//            match one of the listed chars/ranges, for a LLAMA_GRETYPE_CHAR_NOT range
//            it must match none of them
//   second - the element immediately following the range (after all CHAR_ALT entries)
// Asserts that pos is pointing to a char range element.
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
        const llama_grammar_element * pos,
        const uint32_t                chr) {
    const bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
    LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);

    bool matched = false;
    do {
        // an element followed by CHAR_RNG_UPPER is an inclusive range, e.g. [a-z];
        // otherwise it is a single char, e.g. [a] or "a", i.e. the range [a-a]
        const bool is_range = pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER;
        const auto lower    = pos->value;
        const auto upper    = is_range ? pos[1].value : lower;

        matched = matched || (lower <= chr && chr <= upper);

        // skip the element just examined (and its upper bound, if any)
        pos += is_range ? 2 : 1;
    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);

    // an inverse range is satisfied exactly when nothing matched
    return std::make_pair(matched == is_positive_char, pos);
}
// transforms a grammar pushdown stack into N possible stacks, all ending // transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element) // at a character range (terminal element)
static void llama_grammar_advance_stack( static void llama_grammar_advance_stack(
@ -1969,6 +1994,7 @@ static void llama_grammar_advance_stack(
break; break;
} }
case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR:
case LLAMA_GRETYPE_CHAR_NOT:
new_stacks.push_back(stack); new_stacks.push_back(stack);
break; break;
default: default:
@ -1995,27 +2021,9 @@ static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_acc
continue; continue;
} }
const llama_grammar_element * pos = stack.back(); auto match = llama_grammar_match_char(stack.back(), chr);
LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR); if (match.first) {
const llama_grammar_element * pos = match.second;
bool found = false;
do {
bool matches_range;
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
// inclusive range, e.g. [a-z]
matches_range = pos->value <= chr && chr <= pos[1].value;
pos += 2;
} else {
// exact char match, e.g. [a] or "a"
matches_range = pos->value == chr;
pos += 1;
}
found = found || matches_range;
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
if (!found) {
continue;
}
// update top of stack to next element, if any // update top of stack to next element, if any
std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1); std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
@ -2024,6 +2032,7 @@ static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_acc
} }
llama_grammar_advance_stack(rules, new_stack, new_stacks); llama_grammar_advance_stack(rules, new_stack, new_stacks);
} }
}
return new_stacks; return new_stacks;
} }
@ -2038,26 +2047,9 @@ static bool llama_grammar_peek(
if (!chr) { if (!chr) {
return true; return true;
} }
} else { } else if (llama_grammar_match_char(stack.back(), chr).first) {
const llama_grammar_element * pos = stack.back();
LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR);
do {
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
// inclusive range, e.g. [a-z]
if (pos->value <= chr && chr <= pos[1].value) {
return true; return true;
} }
pos += 2;
} else {
// exact char match, e.g. [a] or "a"
if (pos->value == chr) {
return true;
}
pos += 1;
}
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
}
} }
return false; return false;
} }

View file

@ -151,13 +151,16 @@ extern "C" {
// terminal element: character (code point) // terminal element: character (code point)
LLAMA_GRETYPE_CHAR = 3, LLAMA_GRETYPE_CHAR = 3,
// inverse char(s) ([^a], [^a-b], [^abc])
LLAMA_GRETYPE_CHAR_NOT = 4,
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
// be an inclusive range ([a-z]) // be an inclusive range ([a-z])
LLAMA_GRETYPE_CHAR_RNG_UPPER = 4, LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
// modifies a preceding LLAMA_GRETYPE_CHAR or // modifies a preceding LLAMA_GRETYPE_CHAR or
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
LLAMA_GRETYPE_CHAR_ALT = 5, LLAMA_GRETYPE_CHAR_ALT = 6,
}; };
typedef struct llama_grammar_element { typedef struct llama_grammar_element {