add inverse char ranges
This commit is contained in:
parent
b2e071dd86
commit
8d37755bdc
5 changed files with 63 additions and 51 deletions
|
@ -149,13 +149,18 @@ namespace grammar_parser {
|
||||||
pos = parse_space(pos + 1, is_nested);
|
pos = parse_space(pos + 1, is_nested);
|
||||||
} else if (*pos == '[') { // char range(s)
|
} else if (*pos == '[') { // char range(s)
|
||||||
pos++;
|
pos++;
|
||||||
|
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
||||||
|
if (*pos == '^') {
|
||||||
|
pos++;
|
||||||
|
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||||
|
}
|
||||||
last_sym_start = out_elements.size();
|
last_sym_start = out_elements.size();
|
||||||
while (*pos != ']') {
|
while (*pos != ']') {
|
||||||
auto char_pair = parse_char(pos);
|
auto char_pair = parse_char(pos);
|
||||||
pos = char_pair.second;
|
pos = char_pair.second;
|
||||||
enum llama_gretype type = last_sym_start < out_elements.size()
|
enum llama_gretype type = last_sym_start < out_elements.size()
|
||||||
? LLAMA_GRETYPE_CHAR_ALT
|
? LLAMA_GRETYPE_CHAR_ALT
|
||||||
: LLAMA_GRETYPE_CHAR;
|
: start_type;
|
||||||
|
|
||||||
out_elements.push_back({type, char_pair.first});
|
out_elements.push_back({type, char_pair.first});
|
||||||
if (pos[0] == '-' && pos[1] != ']') {
|
if (pos[0] == '-' && pos[1] != ']') {
|
||||||
|
@ -292,6 +297,7 @@ namespace grammar_parser {
|
||||||
bool is_char_element(llama_grammar_element elem) {
|
bool is_char_element(llama_grammar_element elem) {
|
||||||
switch (elem.type) {
|
switch (elem.type) {
|
||||||
case LLAMA_GRETYPE_CHAR: return true;
|
case LLAMA_GRETYPE_CHAR: return true;
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT: return true;
|
||||||
case LLAMA_GRETYPE_CHAR_ALT: return true;
|
case LLAMA_GRETYPE_CHAR_ALT: return true;
|
||||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
|
||||||
default: return false;
|
default: return false;
|
||||||
|
@ -305,8 +311,9 @@ namespace grammar_parser {
|
||||||
case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
||||||
case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
||||||
case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
||||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||||
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_RNG_UPPER"); break;
|
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||||
}
|
}
|
||||||
switch (elem.type) {
|
switch (elem.type) {
|
||||||
case LLAMA_GRETYPE_END:
|
case LLAMA_GRETYPE_END:
|
||||||
|
@ -315,6 +322,7 @@ namespace grammar_parser {
|
||||||
fprintf(file, "(%u) ", elem.value);
|
fprintf(file, "(%u) ", elem.value);
|
||||||
break;
|
break;
|
||||||
case LLAMA_GRETYPE_CHAR:
|
case LLAMA_GRETYPE_CHAR:
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||||
case LLAMA_GRETYPE_CHAR_ALT:
|
case LLAMA_GRETYPE_CHAR_ALT:
|
||||||
fprintf(file, "(\"");
|
fprintf(file, "(\"");
|
||||||
|
@ -353,6 +361,10 @@ namespace grammar_parser {
|
||||||
fprintf(file, "[");
|
fprintf(file, "[");
|
||||||
print_grammar_char(file, elem.value);
|
print_grammar_char(file, elem.value);
|
||||||
break;
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
|
fprintf(file, "[^");
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
break;
|
||||||
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
@ -394,6 +406,7 @@ namespace grammar_parser {
|
||||||
// fprintf(file, "%zu: ", i);
|
// fprintf(file, "%zu: ", i);
|
||||||
// print_rule_binary(file, state.rules[i]);
|
// print_rule_binary(file, state.rules[i]);
|
||||||
print_rule(file, i, state.rules[i], symbol_id_names);
|
print_rule(file, i, state.rules[i], symbol_id_names);
|
||||||
|
// fprintf(file, "\n");
|
||||||
}
|
}
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Grammar for subset of JSON - doesn't support full string or number syntax
|
# Grammar for subset of JSON - doesn't support full string or number syntax
|
||||||
|
|
||||||
root ::= object
|
root ::= object
|
||||||
value ::= object | array | string | number | boolean
|
value ::= object | array | string | number | boolean | "null"
|
||||||
|
|
||||||
object ::=
|
object ::=
|
||||||
"{" ws (
|
"{" ws (
|
||||||
|
@ -17,9 +17,9 @@ array ::=
|
||||||
|
|
||||||
string ::=
|
string ::=
|
||||||
"\"" (
|
"\"" (
|
||||||
[\x20\x21\x23-\x5b\x5d-\U0010FFFF] | # any code point except " (\x22) and \ (\x5c)
|
[^"\\] |
|
||||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
)* "\""
|
)* "\"" ws
|
||||||
|
|
||||||
# Only plain integers currently
|
# Only plain integers currently
|
||||||
number ::= "-"? [0-9]+ ws
|
number ::= "-"? [0-9]+ ws
|
||||||
|
|
4
grammars/list.gbnf
Normal file
4
grammars/list.gbnf
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
root ::= item+
|
||||||
|
|
||||||
|
# Excludes various line break characters
|
||||||
|
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
|
80
llama.cpp
80
llama.cpp
|
@ -1925,6 +1925,31 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns true iff chr satisfies the char range at pos (regular or inverse range)
|
||||||
|
// asserts that pos is pointing to a char range element
|
||||||
|
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
|
||||||
|
const llama_grammar_element * pos,
|
||||||
|
const uint32_t chr) {
|
||||||
|
|
||||||
|
bool found = false;
|
||||||
|
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
|
||||||
|
LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
|
||||||
|
|
||||||
|
do {
|
||||||
|
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
|
||||||
|
// inclusive range, e.g. [a-z]
|
||||||
|
found = found || (pos->value <= chr && chr <= pos[1].value);
|
||||||
|
pos += 2;
|
||||||
|
} else {
|
||||||
|
// exact char match, e.g. [a] or "a"
|
||||||
|
found = found || pos->value == chr;
|
||||||
|
pos += 1;
|
||||||
|
}
|
||||||
|
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
|
||||||
|
|
||||||
|
return std::make_pair(found == is_positive_char, pos);
|
||||||
|
}
|
||||||
|
|
||||||
// transforms a grammar pushdown stack into N possible stacks, all ending
|
// transforms a grammar pushdown stack into N possible stacks, all ending
|
||||||
// at a character range (terminal element)
|
// at a character range (terminal element)
|
||||||
static void llama_grammar_advance_stack(
|
static void llama_grammar_advance_stack(
|
||||||
|
@ -1969,6 +1994,7 @@ static void llama_grammar_advance_stack(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case LLAMA_GRETYPE_CHAR:
|
case LLAMA_GRETYPE_CHAR:
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
new_stacks.push_back(stack);
|
new_stacks.push_back(stack);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -1995,34 +2021,17 @@ static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_acc
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_grammar_element * pos = stack.back();
|
auto match = llama_grammar_match_char(stack.back(), chr);
|
||||||
LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR);
|
if (match.first) {
|
||||||
|
const llama_grammar_element * pos = match.second;
|
||||||
|
|
||||||
bool found = false;
|
// update top of stack to next element, if any
|
||||||
do {
|
std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
|
||||||
bool matches_range;
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
|
new_stack.push_back(pos);
|
||||||
// inclusive range, e.g. [a-z]
|
|
||||||
matches_range = pos->value <= chr && chr <= pos[1].value;
|
|
||||||
pos += 2;
|
|
||||||
} else {
|
|
||||||
// exact char match, e.g. [a] or "a"
|
|
||||||
matches_range = pos->value == chr;
|
|
||||||
pos += 1;
|
|
||||||
}
|
}
|
||||||
found = found || matches_range;
|
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
||||||
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
|
|
||||||
|
|
||||||
if (!found) {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// update top of stack to next element, if any
|
|
||||||
std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
|
|
||||||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
|
||||||
new_stack.push_back(pos);
|
|
||||||
}
|
|
||||||
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return new_stacks;
|
return new_stacks;
|
||||||
|
@ -2038,25 +2047,8 @@ static bool llama_grammar_peek(
|
||||||
if (!chr) {
|
if (!chr) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} else {
|
} else if (llama_grammar_match_char(stack.back(), chr).first) {
|
||||||
const llama_grammar_element * pos = stack.back();
|
return true;
|
||||||
LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR);
|
|
||||||
|
|
||||||
do {
|
|
||||||
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
|
|
||||||
// inclusive range, e.g. [a-z]
|
|
||||||
if (pos->value <= chr && chr <= pos[1].value) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
pos += 2;
|
|
||||||
} else {
|
|
||||||
// exact char match, e.g. [a] or "a"
|
|
||||||
if (pos->value == chr) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
pos += 1;
|
|
||||||
}
|
|
||||||
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|
7
llama.h
7
llama.h
|
@ -151,13 +151,16 @@ extern "C" {
|
||||||
// terminal element: character (code point)
|
// terminal element: character (code point)
|
||||||
LLAMA_GRETYPE_CHAR = 3,
|
LLAMA_GRETYPE_CHAR = 3,
|
||||||
|
|
||||||
|
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||||
|
LLAMA_GRETYPE_CHAR_NOT = 4,
|
||||||
|
|
||||||
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||||
// be an inclusive range ([a-z])
|
// be an inclusive range ([a-z])
|
||||||
LLAMA_GRETYPE_CHAR_RNG_UPPER = 4,
|
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||||
|
|
||||||
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
||||||
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||||
LLAMA_GRETYPE_CHAR_ALT = 5,
|
LLAMA_GRETYPE_CHAR_ALT = 6,
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef struct llama_grammar_element {
|
typedef struct llama_grammar_element {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue