From f8baad235d1056527a5e594c80abc35c149129f8 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Tue, 20 Jun 2023 00:06:38 -0400 Subject: [PATCH] use struct for grammar elements and add Unicode support --- examples/grammar-parser.cpp | 323 ++++++++++++++++++++++-------------- examples/grammar-parser.h | 7 +- examples/main/main.cpp | 12 +- grammars/japanese.gbnf | 7 + llama.cpp | 260 ++++++++++++++++++----------- llama.h | 58 ++++--- 6 files changed, 416 insertions(+), 251 deletions(-) create mode 100644 grammars/japanese.gbnf diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index 56dc7da2f..206ccac56 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -7,18 +7,45 @@ #include namespace grammar_parser { - uint16_t get_symbol_id(parse_state & state, const char * src, size_t len) { - uint16_t next_id = static_cast(state.symbol_ids.size()); + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from llama.cpp + std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); + } + + uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { + uint32_t next_id = static_cast(state.symbol_ids.size()); auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); return result.first->second; } - uint16_t generate_symbol_id(parse_state & state, const std::string & base_name) { - uint16_t next_id = static_cast(state.symbol_ids.size()); + uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { + uint32_t next_id = static_cast(state.symbol_ids.size()); state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; return next_id; } + void add_rule( + parse_state & state, + uint32_t rule_id, + const std::vector & rule) { + if (state.rules.size() <= rule_id) { + state.rules.resize(rule_id + 1); + } + state.rules[rule_id] = rule; + } + bool is_word_char(char c) { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); } @@ -60,9 +87,10 @@ namespace grammar_parser { return pos; } - std::pair parse_char(const char * src) { + std::pair parse_char(const char * src) { if (*src == '\\') { char esc = src[1]; + // TODO: 16- and 32-bit escapes if (esc == 'x') { int first = hex_to_int(src[2]); if (first > -1) { @@ -83,7 +111,8 @@ namespace grammar_parser { } throw std::runtime_error(std::string("unknown escape at ") + src); } else if (*src) { - return std::make_pair(*src, src + 1); + auto decoded = decode_utf8(src); + return std::make_pair(decoded.first, decoded.second); } throw std::runtime_error("unexpected end of input"); } @@ -92,132 +121,101 @@ namespace grammar_parser { parse_state & state, const char * src, const std::string & rule_name, - uint16_t rule_id, + uint32_t rule_id, bool is_nested); const char * parse_sequence( - parse_state & state, - const char * src, - const std::string & rule_name, - std::vector & outbuf, - bool is_nested) { - size_t out_start = outbuf.size(); - - // sequence size, will be replaced at end when known - outbuf.push_back(0); - - size_t last_sym_start = outbuf.size(); + parse_state & state, + const char * src, + const std::string & rule_name, + std::vector & out_elements, + bool is_nested) { + size_t last_sym_start = out_elements.size(); const char * pos = src; while (*pos) { if (*pos == '"') { // literal string pos++; - last_sym_start = outbuf.size(); + last_sym_start = out_elements.size(); while (*pos != '"') { auto char_pair = parse_char(pos); pos = char_pair.second; - - // each char of a literal is encoded as a "range" of char - char - outbuf.push_back(2); - outbuf.push_back(char_pair.first); - outbuf.push_back(char_pair.first); + out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); } pos = parse_space(pos + 1, is_nested); } else if (*pos == '[') { // char range(s) pos++; - last_sym_start = outbuf.size(); - // num chars in range - replaced at end of loop - outbuf.push_back(0); + last_sym_start = out_elements.size(); while (*pos != ']') { auto char_pair = parse_char(pos); pos = char_pair.second; + enum llama_gretype type = last_sym_start < out_elements.size() + ? LLAMA_GRETYPE_CHAR_ALT + : LLAMA_GRETYPE_CHAR; - outbuf.push_back(char_pair.first); + out_elements.push_back({type, char_pair.first}); if (pos[0] == '-' && pos[1] != ']') { auto endchar_pair = parse_char(pos + 1); pos = endchar_pair.second; - outbuf.push_back(endchar_pair.first); - } else { - // chars that aren't part of a c1-c2 range are just doubled (i.e., c-c) - outbuf.push_back(char_pair.first); + out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); } } - // replace num chars with actual - outbuf[last_sym_start] = static_cast(outbuf.size() - last_sym_start - 1); pos = parse_space(pos + 1, is_nested); } else if (is_word_char(*pos)) { // rule reference const char * name_end = parse_name(pos); - uint16_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); pos = parse_space(name_end, is_nested); - last_sym_start = outbuf.size(); - outbuf.push_back(1); - outbuf.push_back(ref_rule_id); + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); } else if (*pos == '(') { // grouping // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); - uint16_t sub_rule_id = generate_symbol_id(state, rule_name); + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); - last_sym_start = outbuf.size(); + last_sym_start = out_elements.size(); // output reference to synthesized rule - outbuf.push_back(1); - outbuf.push_back(sub_rule_id); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); if (*pos != ')') { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator - if (outbuf.size() - out_start - 1 == 0) { + if (last_sym_start == out_elements.size()) { throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); } - std::vector & out_grammar = state.out_grammar; - // apply transformation to previous symbol (last_sym_start - - // end) according to rewrite rules: + // apply transformation to previous symbol (last_sym_start to end) according to + // rewrite rules: // S* --> S' ::= S S' | // S+ --> S' ::= S S' | S // S? --> S' ::= S | - uint16_t sub_rule_id = generate_symbol_id(state, rule_name); - out_grammar.push_back(sub_rule_id); - size_t sub_rule_start = out_grammar.size(); - // placeholder for size of 1st alternate - out_grammar.push_back(0); + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + std::vector sub_rule; // add preceding symbol to generated rule - out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end()); + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); if (*pos == '*' || *pos == '+') { // cause generated rule to recurse - out_grammar.push_back(1); - out_grammar.push_back(sub_rule_id); + sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); } - // apply actual size - out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start; - // mark end of 1st alternate - out_grammar.push_back(0); - sub_rule_start = out_grammar.size(); - // placeholder for size of 2nd alternate - out_grammar.push_back(0); + // mark start of alternate def + sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); if (*pos == '+') { - // add preceding symbol as alternate only for '+' - out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end()); + // add preceding symbol as alternate only for '+' (otherwise empty) + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); } - // apply actual size of 2nd alternate - out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start; - // mark end of 2nd alternate, then end of rule - out_grammar.push_back(0); - out_grammar.push_back(0); + sub_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, sub_rule_id, sub_rule); // in original rule, replace previous symbol with reference to generated rule - outbuf.resize(last_sym_start); - outbuf.push_back(1); - outbuf.push_back(sub_rule_id); + out_elements.resize(last_sym_start); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); pos = parse_space(pos + 1, is_nested); } else { break; } } - // apply actual size of this alternate sequence - outbuf[out_start] = static_cast(outbuf.size() - out_start); - // mark end of alternate - outbuf.push_back(0); return pos; } @@ -225,17 +223,17 @@ namespace grammar_parser { parse_state & state, const char * src, const std::string & rule_name, - uint16_t rule_id, + uint32_t rule_id, bool is_nested) { - std::vector outbuf; - const char * pos = parse_sequence(state, src, rule_name, outbuf, is_nested); + std::vector rule; + const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); pos = parse_space(pos + 1, true); - pos = parse_sequence(state, pos, rule_name, outbuf, is_nested); + pos = parse_sequence(state, pos, rule_name, rule, is_nested); } - state.out_grammar.push_back(rule_id); - state.out_grammar.insert(state.out_grammar.end(), outbuf.begin(), outbuf.end()); - state.out_grammar.push_back(0); + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rule_id, rule); return pos; } @@ -243,7 +241,7 @@ namespace grammar_parser { const char * name_end = parse_name(src); const char * pos = parse_space(name_end, false); size_t name_len = name_end - src; - uint16_t rule_id = get_symbol_id(state, src, name_len); + uint32_t rule_id = get_symbol_id(state, src, name_len); const std::string name(src, name_len); if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { @@ -270,7 +268,6 @@ namespace grammar_parser { while (*pos) { pos = parse_rule(state, pos); } - state.out_grammar.push_back(0xffff); return state; } catch (const std::exception & err) { fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); @@ -278,53 +275,131 @@ namespace grammar_parser { } } - const uint16_t * print_rule( - FILE * file, - const uint16_t * base, - const uint16_t * src, - const std::map & symbol_id_names) { - uint16_t rule_id = *src; - fprintf(file, "<%zu>%s ::= ", src - base, symbol_id_names.at(rule_id).c_str()); - const uint16_t * pos = src + 1; - while (*pos) { - if (pos - 1 > src) { - fprintf(file, "| "); - } - pos++; // sequence size, not needed here - while (*pos) { - if (*pos == 1) { - uint16_t ref_rule_id = pos[1]; - fprintf(file, "<%zu>%s ", pos - base, symbol_id_names.at(ref_rule_id).c_str()); - pos += 2; - } else { - fprintf(file, "<%zu>[", pos - base); - uint16_t num_chars = *pos; - pos++; + void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } + } - for (uint16_t i = 0; i < num_chars; i += 2) { - fprintf(file, "%lc-", static_cast(pos[i])); // REVIEW - if (i + 1 < num_chars) { - fprintf(file, "%lc", static_cast(pos[i + 1])); - } - } - fprintf(file, "] "); - pos += num_chars; - } + bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + default: return false; + } + } + + void print_rule_binary(FILE * file, const std::vector & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_RNG_UPPER"); break; + } + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); + } + + void print_rule( + FILE * file, + uint32_t rule_id, + const std::vector & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, elem.value); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + break; + default: + fprintf(file, "] "); + } } - pos++; } fprintf(file, "\n"); - return pos + 1; } void print_grammar(FILE * file, const parse_state & state) { - std::map symbol_id_names; - for (auto kv : state.symbol_ids) { - symbol_id_names[kv.second] = kv.first; - } - const uint16_t * pos = state.out_grammar.data(); - while (*pos != 0xffff) { - pos = print_rule(file, state.out_grammar.data(), pos, symbol_id_names); + try { + std::map symbol_id_names; + for (auto kv : state.symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = state.rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, state.rules[i]); + print_rule(file, i, state.rules[i], symbol_id_names); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); } } + + std::vector parse_state::c_rules() { + std::vector ret; + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; + } } diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h index c9e27d4cd..9037d7272 100644 --- a/examples/grammar-parser.h +++ b/examples/grammar-parser.h @@ -10,6 +10,7 @@ // space ::= [ \t\n]* #pragma once +#include "llama.h" #include #include #include @@ -17,8 +18,10 @@ namespace grammar_parser { struct parse_state { - std::map symbol_ids; - std::vector out_grammar; + std::map symbol_ids; + std::vector> rules; + + std::vector c_rules(); }; parse_state parse(const char * src); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7f9636aae..b6538ac13 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -300,14 +300,16 @@ int main(int argc, char ** argv) { if (!params.grammar.empty()) { parsed_grammar = grammar_parser::parse(params.grammar.c_str()); // will be empty (default) if there are parse errors - if (parsed_grammar.out_grammar.empty()) { + if (parsed_grammar.rules.empty()) { return 1; } fprintf(stderr, "%s: grammar:\n", __func__); grammar_parser::print_grammar(stderr, parsed_grammar); fprintf(stderr, "\n"); + + std::vector grammar_rules(parsed_grammar.c_rules()); grammar = llama_grammar_init( - parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root")); + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } // TODO: replace with ring-buffer @@ -653,8 +655,12 @@ int main(int argc, char ** argv) { // reset grammar state if we're restarting generation if (grammar != NULL) { llama_grammar_free(grammar); + + std::vector grammar_rules( + parsed_grammar.c_rules()); grammar = llama_grammar_init( - parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root")); + grammar_rules.data(), grammar_rules.size(), + parsed_grammar.symbol_ids.at("root")); } } is_interacting = false; diff --git a/grammars/japanese.gbnf b/grammars/japanese.gbnf new file mode 100644 index 000000000..43f25ab59 --- /dev/null +++ b/grammars/japanese.gbnf @@ -0,0 +1,7 @@ +# A probably incorrect grammar for Japanese +root ::= jp-char+ ([ \t\n] jp-char+)* +jp-char ::= hiragana | katakana | punctuation | cjk +hiragana ::= [ぁ-ゟ] +katakana ::= [ァ-ヿ] +punctuation ::= [、-〾] +cjk ::= [一-鿿] diff --git a/llama.cpp b/llama.cpp index 37d19ea91..0986c2489 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1841,45 +1841,86 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co // struct llama_grammar { - const std::vector rules; - std::vector> stacks; + const std::vector> rules; + std::vector> stacks; }; -// transforms a grammar pushdown stack into N possible stacks, all terminating +// NOTE: assumes valid utf8 (but checks for overrun) +std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; // may overrun! + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { + switch (pos->type) { + case LLAMA_GRETYPE_END: return true; + case LLAMA_GRETYPE_ALT: return true; + default: return false; + } +} + +// transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( - const std::vector & rules, - const std::vector & stack, - std::vector> & new_stacks) { + const std::vector> & rules, + const std::vector & stack, + std::vector> & new_stacks) { if (stack.empty()) { new_stacks.push_back(stack); return; } - const uint16_t * pos = stack.back(); + const llama_grammar_element * pos = stack.back(); - if (*pos == 1) { - // rule reference, apply rule to stack - const uint16_t * subpos = rules[pos[1]] + 1; - while (*subpos) { - // init new stack without the top (pos) - std::vector new_stack(stack.begin(), stack.end() - 1); - if (pos[2]) { - // if the rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 2); - } - if (subpos[1]) { - // if this alternate is nonempty, add that to the stack - new_stack.push_back(subpos + 1); - } - llama_grammar_advance_stack(rules, new_stack, new_stacks); - subpos += 1 + *subpos; + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; } - } else { - // rule element size > 1 -> character reference - LLAMA_ASSERT(*pos); - new_stacks.push_back(stack); + case LLAMA_GRETYPE_CHAR: + new_stacks.push_back(stack); + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + LLAMA_ASSERT(false); } } @@ -1887,39 +1928,43 @@ static void llama_grammar_advance_stack( // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -static std::vector> llama_grammar_accept( - const std::vector & rules, - const std::vector> & stacks, - const uint16_t chr) { +static std::vector> llama_grammar_accept( + const std::vector> & rules, + const std::vector> & stacks, + const uint32_t chr) { - std::vector> new_stacks; + std::vector> new_stacks; for (const auto & stack : stacks) { if (stack.empty()) { continue; } - const uint16_t * pos = stack.back(); - const uint16_t num_chars = *pos; - LLAMA_ASSERT(num_chars > 1); + const llama_grammar_element * pos = stack.back(); + LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR); - pos++; // skip num chars indicator bool found = false; - // loop over the inclusive char pairs to find a match on the given char - for (int i = 0; i < num_chars; i += 2) { - if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) { - found = true; - break; + do { + bool matches_range; + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + matches_range = pos->value <= chr && chr <= pos[1].value; + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + matches_range = pos->value == chr; + pos += 1; } - } + found = found || matches_range; + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + if (!found) { continue; } - // advance past char range, updating top of stack to next element, if any - pos += num_chars; - std::vector new_stack(stack.begin(), stack.end() - 1); - if (*pos) { + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } llama_grammar_advance_stack(rules, new_stack, new_stacks); @@ -1930,8 +1975,8 @@ static std::vector> llama_grammar_accept( // returns `true` if one of the pushdown stacks can accept the given char. static bool llama_grammar_peek( - const std::vector> & stacks, - const uint16_t chr) { + const std::vector> & stacks, + const uint32_t chr) { for (const auto & stack : stacks) { if (stack.empty()) { @@ -1939,16 +1984,24 @@ static bool llama_grammar_peek( return true; } } else { - const uint16_t * pos = stack.back(); - const uint16_t num_chars = *pos; - LLAMA_ASSERT(num_chars > 1); + const llama_grammar_element * pos = stack.back(); + LLAMA_ASSERT(pos->type == LLAMA_GRETYPE_CHAR); - pos++; - for (int i = 0; i < num_chars; i += 2) { - if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) { - return true; + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= chr && chr <= pos[1].value) { + return true; + } + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + if (pos->value == chr) { + return true; + } + pos += 1; } - } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); } } return false; @@ -1959,45 +2012,44 @@ static bool llama_grammar_peek( // grammar - external // -struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id) { - const uint16_t * pos = src; - std::vector rules; +struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; - // build `rules` as list of pointers to rules embedded in binary grammar `src` - while (*pos != 0xffff) { - uint16_t rule_id = *pos; - if (rules.size() <= rule_id) { - rules.resize(rule_id + 1); + // copy rule definitions into vectors + std::vector> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); } - rules[rule_id] = pos; - // skip rule id - pos++; - // skip rule alternates - while (*pos) { - pos += 1 + *pos; - } - // skip 0 denoting end of rule - pos++; + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); } - const uint16_t * start_rule = rules[start_rule_id]; - - LLAMA_ASSERT(*start_rule == start_rule_id); - // loop over alternates of start rule to build initial stacks - pos = start_rule + 1; - std::vector> stacks; - while (*pos) { - std::vector stack; - if (pos[1]) { - // if alernate is nonempty, add to stack - stack.push_back(pos + 1); + std::vector> stacks; + pos = rules[start_rule_index]; + do { + std::vector stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); } - llama_grammar_advance_stack(rules, stack, stacks); - pos += 1 + *pos; - } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); - return new llama_grammar{ rules, stacks }; + return new llama_grammar{ std::move(vec_rules), std::move(stacks) }; } void llama_grammar_free(struct llama_grammar * grammar) { @@ -2285,7 +2337,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c const int64_t t_start_sample_us = ggml_time_us(); const llama_token eos = llama_token_eos(); // since many llama tokens are prefixed with a single space, special case a lookahead on ' ' - const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, ' '); + const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, U' '); for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; @@ -2296,10 +2348,15 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c bool valid = false; if (id == eos) { valid = llama_grammar_peek(grammar->stacks, 0); - } else if (str[0] == ' ') { - valid = llama_grammar_peek(stacks_after_space, str[1]); - } else if (str[0] != 0) { - valid = llama_grammar_peek(grammar->stacks, str[0]); + } else { + const auto decoded = decode_utf8(str); + const uint32_t chr = decoded.first; + if (chr == U' ') { + const char * next = decoded.second; + valid = llama_grammar_peek(stacks_after_space, decode_utf8(next).first); + } else if (chr != 0) { + valid = llama_grammar_peek(grammar->stacks, chr); + } } if (!valid) { @@ -2451,13 +2508,15 @@ llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_ const char * suffix = str; // Find prefix of selected token that matches grammar, expecting at least 1 char - auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *suffix); + auto decoded = decode_utf8(suffix); + auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first); LLAMA_ASSERT(!new_stacks.empty()); if (*suffix) { - ++suffix; - for ( ; *suffix; ++suffix) { - new_stacks = llama_grammar_accept(grammar->rules, new_stacks, *suffix); - if (new_stacks.empty()) { + suffix = decoded.second; + for ( ; *suffix; suffix = decoded.second) { + decoded = decode_utf8(suffix); + new_stacks = llama_grammar_accept(grammar->rules, new_stacks, decoded.first); + if (new_stacks.empty() ) { break; } } @@ -2480,8 +2539,9 @@ llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_ // accept the first token of the matching prefix into the grammar llama_token first_prefix_token = tokens[0]; const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token); - for ( ; *first_prefix_str; ++first_prefix_str) { - grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *first_prefix_str); + for ( ; *first_prefix_str; first_prefix_str = decoded.second) { + decoded = decode_utf8(first_prefix_str); + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, decoded.first); LLAMA_ASSERT(!grammar->stacks.empty()); } diff --git a/llama.h b/llama.h index c7f2c841e..e827fa33c 100644 --- a/llama.h +++ b/llama.h @@ -55,8 +55,6 @@ extern "C" { struct llama_context; - struct llama_grammar; - typedef int llama_token; typedef struct llama_token_data { @@ -125,6 +123,37 @@ extern "C" { bool quantize_output_tensor; // quantize output.weight } llama_model_quantize_params; + // grammar types + struct llama_grammar; + + // grammar element type + enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 5, + }; + + typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID + } llama_grammar_element; + LLAMA_API struct llama_context_params llama_context_default_params(); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(); @@ -243,26 +272,11 @@ extern "C" { // Grammar // - // Accepts a binary encoding of a context-free grammar. The returned struct can be used to - // constrain sampled tokens (see below). - // - // The binary format represents one or more production rules, each with one or more alternate - // defininitions: - // - // ( ( )+ 0000)+ FFFF - // - // rule_ids should be assigned sequentially from zero but may appear out of order. Each - // rule alternate is a sequence of zero or more symbols, each prefixed with size: - // - // ( )* 0000 - // - // A symbol of size 1 is interpreted as a rule reference (whose value is the single following - // u16). Symbols sized greater than 1 are interpreted as inclusive pairs of 16-bit chars to - // match. Note that symbol sizes greater than 7FFF are reserved for future use. - // - // The provided `src` must be kept valid for the lifetime of the `llama_grammar`. - // - LLAMA_API struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id); + LLAMA_API struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); // Sampling functions