diff --git a/Makefile b/Makefile index 39265164b..71b1baecf 100644 --- a/Makefile +++ b/Makefile @@ -250,6 +250,9 @@ llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h common.o: examples/common.cpp examples/common.h $(CXX) $(CXXFLAGS) -c $< -o $@ +grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + libllama.so: llama.o ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) @@ -260,7 +263,7 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 3deff4077..bd043ed68 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -13,6 +13,8 @@ set(TARGET common) add_library(${TARGET} OBJECT common.h common.cpp + grammar-parser.h + grammar-parser.cpp ) if (BUILD_SHARED_LIBS) diff --git a/examples/common.cpp b/examples/common.cpp index f5d886acf..b20a68260 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -388,6 +388,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.input_suffix = argv[i]; + } else if (arg == "--grammar") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.grammar = argv[i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, default_params); @@ -458,6 +464,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n"); fprintf(stderr, " i.e. 
`--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); + fprintf(stderr, " --grammar GRAMMAR BNF-like grammar (TODO explain) to constrain generations\n"); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); diff --git a/examples/common.h b/examples/common.h index 826e2ae59..5eb611841 100644 --- a/examples/common.h +++ b/examples/common.h @@ -52,6 +52,7 @@ struct gpt_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with + std::string grammar = ""; // optional BNF-like grammar to constrain sampling std::vector antiprompt; // string upon seeing which more user input is prompted std::string lora_adapter = ""; // lora adapter path diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp new file mode 100644 index 000000000..53f9b26d5 --- /dev/null +++ b/examples/grammar-parser.cpp @@ -0,0 +1,315 @@ +#include "grammar-parser.h" +#include +#include +#include +#include + +namespace grammar_parser { + uint16_t get_symbol_id(parse_state & state, const char * src, size_t len) { + uint16_t next_id = static_cast(state.symbol_ids.size()); + auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); + return result.first->second; + } + + uint16_t generate_symbol_id(parse_state & state, const std::string & base_name) { + uint16_t next_id = static_cast(state.symbol_ids.size()); + state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; + } + + bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + } + + int hex_to_int(char c) { + if ('a' <= c && c <= 'f') { + return c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + return c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + return c - '0'; + } + return -1; + } + + const char * parse_space(const char * src) { + const char * pos = src; + // TODO: support newlines in some cases + while (*pos == ' ' || *pos == '\t') { + pos++; + } + return pos; + } + + std::pair parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::string("expecting name at ") + src; + } + return std::make_pair(pos, parse_space(pos)); + } + + std::pair parse_char(const char * src) { + if (*src == '\\') { + char esc = src[1]; + if (esc == 'x') { + int first = hex_to_int(src[2]); + if (first > -1) { + int second = hex_to_int(src[3]); + if (second > -1) { + return std::make_pair((first << 4) + second, src + 4); + } + } + throw std::string("expecting \\xNN at ") + src; + } else if (esc == '"' || esc == '[' || esc == ']') { + return std::make_pair(esc, src + 2); + } else if (esc == 'r') { + return std::make_pair('\r', src + 2); + } else if (esc == 'n') { + return std::make_pair('\n', src + 2); + } else if (esc == 't') { + return std::make_pair('\t', src + 2); + } + throw std::string("unknown escape at ") + src; + } else if (*src) { + return std::make_pair(*src, src + 1); + } + throw std::string("unexpected end of input"); + } + + const char * parse_alternates( + 
parse_state & state, + const char * src, + const std::string & rule_name, + uint16_t rule_id); + + const char * parse_sequence( + parse_state & state, + const char * src, + const std::string & rule_name, + std::vector & outbuf) { + size_t out_start = outbuf.size(); + + // sequence size, will be replaced at end when known + outbuf.push_back(0); + + size_t last_sym_start = outbuf.size(); + const char * pos = src; + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = outbuf.size(); + while (*pos != '"') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + + // each char of a literal is encoded as a "range" of char - char + outbuf.push_back(2); + outbuf.push_back(char_pair.first); + outbuf.push_back(char_pair.first); + } + pos = parse_space(pos + 1); + } else if (*pos == '[') { // char range(s) + pos++; + last_sym_start = outbuf.size(); + // num chars in range - replaced at end of loop + outbuf.push_back(0); + while (*pos != ']') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + + outbuf.push_back(char_pair.first); + if (pos[0] == '-' && pos[1] != ']') { + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + outbuf.push_back(endchar_pair.first); + } else { + // chars that aren't part of a c1-c2 range are just doubled (i.e., c-c) + outbuf.push_back(char_pair.first); + } + } + // replace num chars with actual + outbuf[last_sym_start] = static_cast(outbuf.size() - last_sym_start - 1); + pos = parse_space(pos + 1); + } else if (is_word_char(*pos)) { // rule reference + auto name_pair = parse_name(pos); + uint16_t ref_rule_id = get_symbol_id(state, pos, name_pair.first - pos); + pos = name_pair.second; + last_sym_start = outbuf.size(); + outbuf.push_back(1); + outbuf.push_back(ref_rule_id); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1); + uint16_t sub_rule_id = generate_symbol_id(state, rule_name); + pos = parse_alternates(state, pos, rule_name, sub_rule_id); + last_sym_start = outbuf.size(); + // output reference to synthesized rule + outbuf.push_back(1); + outbuf.push_back(sub_rule_id); + if (*pos != ')') { + throw std::string("expecting ')' at ") + pos; + } + pos = parse_space(pos + 1); + } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator + if (outbuf.size() - out_start - 1 == 0) { + throw std::string("expecting preceeding item to */+/? at ") + pos; + } + std::vector & out_grammar = state.out_grammar; + + // apply transformation to previous symbol (last_sym_start - + // end) according to rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? 
--> S' ::= S | + uint16_t sub_rule_id = generate_symbol_id(state, rule_name); + out_grammar.push_back(sub_rule_id); + size_t sub_rule_start = out_grammar.size(); + // placeholder for size of 1st alternate + out_grammar.push_back(0); + // add preceding symbol to generated rule + out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end()); + if (*pos == '*' || *pos == '+') { + // cause generated rule to recurse + out_grammar.push_back(1); + out_grammar.push_back(sub_rule_id); + } + // apply actual size + out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start; + // mark end of 1st alternate + out_grammar.push_back(0); + sub_rule_start = out_grammar.size(); + // placeholder for size of 2nd alternate + out_grammar.push_back(0); + if (*pos == '+') { + // add preceding symbol as alternate only for '+' + out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end()); + } + // apply actual size of 2nd alternate + out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start; + // mark end of 2nd alternate, then end of rule + out_grammar.push_back(0); + out_grammar.push_back(0); + + // in original rule, replace previous symbol with reference to generated rule + outbuf.resize(last_sym_start); + outbuf.push_back(1); + outbuf.push_back(sub_rule_id); + + pos = parse_space(pos + 1); + } else { + break; + } + } + // apply actual size of this alternate sequence + outbuf[out_start] = static_cast(outbuf.size() - out_start); + // mark end of alternate + outbuf.push_back(0); + return pos; + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint16_t rule_id) { + std::vector outbuf; + const char * pos = parse_sequence(state, src, rule_name, outbuf); + while (*pos == '|') { + pos = parse_space(pos + 1); + pos = parse_sequence(state, pos, rule_name, outbuf); + } + state.out_grammar.push_back(rule_id); + state.out_grammar.insert(state.out_grammar.end(), outbuf.begin(), outbuf.end()); + state.out_grammar.push_back(0); + return pos; + } + + const char * parse_rule(parse_state & state, const char * src) { + auto name_pair = parse_name(src); + const char * pos = name_pair.second; + size_t name_len = name_pair.first - src; + uint16_t rule_id = get_symbol_id(state, src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::string("expecting ::= at ") + pos; + } + pos = parse_space(pos + 3); + + pos = parse_alternates(state, pos, name, rule_id); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 
2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::string("expecting newline or end at ") + pos; + } + return parse_space(pos); + } + + parse_state parse(const char * src) { + parse_state state; + const char * pos = parse_space(src); + while (*pos) { + pos = parse_rule(state, pos); + } + state.out_grammar.push_back(0xffff); + return state; + } + + const uint16_t * print_rule( + FILE * file, + const uint16_t * base, + const uint16_t * src, + const std::map & symbol_id_names) { + uint16_t rule_id = *src; + fprintf(file, "<%zu>%s ::= ", src - base, symbol_id_names.at(rule_id).c_str()); + const uint16_t * pos = src + 1; + while (*pos) { + if (pos - 1 > src) { + fprintf(file, "| "); + } + pos++; // sequence size, not needed here + while (*pos) { + if (*pos == 1) { + uint16_t ref_rule_id = pos[1]; + fprintf(file, "<%zu>%s ", pos - base, symbol_id_names.at(ref_rule_id).c_str()); + pos += 2; + } else { + fprintf(file, "<%zu>[", pos - base); + uint16_t num_chars = *pos; + pos++; + + for (uint16_t i = 0; i < num_chars; i += 2) { + fprintf(file, "%lc-", static_cast(pos[i])); // REVIEW + if (i + 1 < num_chars) { + fprintf(file, "%lc", static_cast(pos[i + 1])); + } + } + fprintf(file, "] "); + pos += num_chars; + } + } + pos++; + } + fprintf(file, "\n"); + return pos + 1; + } + + void print_grammar(FILE * file, const parse_state & state) { + std::map symbol_id_names; + for (auto kv : state.symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + const uint16_t * pos = state.out_grammar.data(); + while (*pos != 0xffff) { + pos = print_rule(file, state.out_grammar.data(), pos, symbol_id_names); + } + } +} + diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h new file mode 100644 index 000000000..c9e27d4cd --- /dev/null +++ b/examples/grammar-parser.h @@ -0,0 +1,26 @@ +// Implements a parser for an extended Backus-Naur form (BNF), producing the +// binary context-free grammar format specified by llama.h. Supports character +// ranges, grouping, and repetition operators. 
As an example, a grammar for +// arithmetic might look like: +// +// root ::= expr +// expr ::= term ([-+*/] term)* +// term ::= num | "(" space expr ")" space +// num ::= [0-9]+ space +// space ::= [ \t\n]* + +#pragma once +#include +#include +#include +#include + +namespace grammar_parser { + struct parse_state { + std::map symbol_ids; + std::vector out_grammar; + }; + + parse_state parse(const char * src); + void print_grammar(FILE * file, const parse_state & state); +} diff --git a/examples/main/main.cpp b/examples/main/main.cpp index de63faa3e..f43eb4fe4 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "llama.h" #include "build-info.h" +#include "grammar-parser.h" #include #include @@ -291,6 +292,17 @@ int main(int argc, char ** argv) { fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); + grammar_parser::parse_state parsed_grammar; + llama_grammar * grammar = NULL; + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, parsed_grammar); + fprintf(stderr, "\n"); + grammar = llama_grammar_init( + parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root")); + } + // TODO: replace with ring-buffer std::vector last_n_tokens(n_ctx); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); @@ -454,6 +466,10 @@ int main(int argc, char ** argv) { logits[llama_token_nl()] = nl_logit; } + if (grammar != NULL) { + llama_sample_grammar(ctx, &candidates_p, grammar); + } + if (temp <= 0) { // Greedy sampling id = llama_sample_token_greedy(ctx, &candidates_p); @@ -479,6 +495,10 @@ int main(int argc, char ** argv) { } // printf("`%d`", candidates_p.size); + if (grammar != NULL) { + id = llama_grammar_accept_token(ctx, grammar, id); + } + last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); } @@ -609,6 +629,17 @@ int main(int argc, char ** argv) { } if (n_past > 0) { + if (is_interacting) { + // reset grammar state if we're restarting generation + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + if (grammar != NULL) { + llama_grammar_free(grammar); + } + grammar = llama_grammar_init( + parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root")); + } + } is_interacting = false; } } @@ -638,5 +669,9 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); + if (grammar != NULL) { + llama_grammar_free(grammar); + } + return 0; } diff --git a/llama.cpp b/llama.cpp index 16d6f6ef1..877a6a2a1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1821,6 +1821,168 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co return output; } +// +// grammar - internal +// + +struct llama_grammar { + const std::vector rules; + std::vector> stacks; +}; + +// transforms a grammar pushdown stack into N possible stacks, all terminating +// at a character range (terminal element) +static void llama_grammar_advance_stack( + const std::vector & rules, + const std::vector & stack, + std::vector> & new_stacks) { + + if (stack.empty()) { + new_stacks.push_back(stack); + return; + } + + const uint16_t * pos = stack.back(); + + if (*pos == 1) { + // rule reference, apply rule to stack + const uint16_t * subpos = rules[pos[1]] + 1; + while (*subpos) { + // init new stack without the top (pos) + std::vector 
new_stack(stack.begin(), stack.end() - 1); + if (pos[2]) { + // if the rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 2); + } + if (subpos[1]) { + // if the referenced rule is nonempty, add that to the stack + new_stack.push_back(subpos + 1); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + subpos += 1 + *subpos; + } + } else { + // rule element size > 1 -> character reference + LLAMA_ASSERT(*pos); + new_stacks.push_back(stack); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector> llama_grammar_accept( + const std::vector & rules, + const std::vector> & stacks, + const uint16_t chr) { + + std::vector> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + const uint16_t * pos = stack.back(); + const uint16_t num_chars = *pos; + LLAMA_ASSERT(num_chars > 1); + + pos++; // skip num chars indicator + bool found = false; + // loop over the inclusive char pairs to find a match on the given char + for (int i = 0; i < num_chars; i += 2) { + if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) { + found = true; + break; + } + } + if (!found) { + continue; + } + + // advance past char range, updating top of stack to next element, if any + pos += num_chars; + std::vector new_stack(stack.begin(), stack.end() - 1); + if (*pos) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + } + + return new_stacks; +} + +// returns `true` if one of the pushdown stacks can accept the given char. +static bool llama_grammar_peek( + const std::vector> & stacks, + const uint16_t chr) { + + for (const auto & stack : stacks) { + if (stack.empty()) { + if (!chr) { + return true; + } + } else { + const uint16_t * pos = stack.back(); + const uint16_t num_chars = *pos; + LLAMA_ASSERT(num_chars > 1); + + pos++; + for (int i = 0; i < num_chars; i += 2) { + if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) { + return true; + } + } + } + } + return false; +} + + +// +// grammar - external +// + +struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id) { + const uint16_t * pos = src; + std::vector rules; + + // build `rules` as list of pointers to rules embedded in binary grammar `src` + while (*pos != 0xffff) { + uint16_t rule_id = *pos; + if (rules.size() <= rule_id) { + rules.resize(rule_id + 1); + } + rules[rule_id] = pos; + // skip rule id + pos++; + // skip rule alternates + while (*pos) { + pos += 1 + *pos; + } + // skip 0 denoting end of rule + pos++; + } + + // TODO: handle if start rule has alternates + const uint16_t * start_rule = rules[start_rule_id]; + + // rule starts with rule id and 1st alternate's size; skip that so initial + // stack starts at 1st element in 1st alternate + LLAMA_ASSERT(start_rule[0] == start_rule_id && start_rule[1]); + const std::vector stack = { start_rule + 2 }; + + std::vector> stacks; + llama_grammar_advance_stack(rules, stack, stacks); + + return new llama_grammar{ rules, stacks }; +} + +void llama_grammar_free(struct llama_grammar * grammar) { + delete grammar; +} + // // sampling // @@ -2097,6 +2259,30 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l } } +void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * 
candidates, const struct llama_grammar * grammar) { + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + const llama_token eos = llama_token_eos(); + // since many llama tokens are prefixed with a single space, special case a lookahead on ' ' + const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, ' '); + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const char * str = llama_token_to_str(ctx, id); + + // prune tokens based on first char only - in `llama_grammar_accept_token` we will find the + // full matching prefix of the selected token + const bool valid = str[0] == ' ' + ? llama_grammar_peek(stacks_after_space, str[1]) + : llama_grammar_peek(grammar->stacks, id == eos ? 0 : str[0]); + + if (!valid) { + candidates->data[i].logit = -INFINITY; + } + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { assert(ctx); @@ -2223,6 +2409,60 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } +llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { + const int64_t t_start_sample_us = ggml_time_us(); + + if (token == llama_token_eos()) { + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + return token; + } + LLAMA_ASSERT(false); + } + } + + const char * str = llama_token_to_str(ctx, token); + const char * suffix = str; + + // Find prefix of selected token that matches grammar, expecting at least 1 char + auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *suffix); + LLAMA_ASSERT(!new_stacks.empty()); + if (*suffix) { + ++suffix; + for ( ; *suffix; ++suffix) { + new_stacks = llama_grammar_accept(grammar->rules, new_stacks, *suffix); + if (new_stacks.empty()) { + break; + } + } + } + + // if full token is matched, accept new stacks + if (!(*suffix)) { + grammar->stacks = new_stacks; + return token; + } + + // otherwise, tokenize the string prefix that did match + llama_token tokens[32]; // TODO - determine actual max token size + const std::string prefix_str(str, suffix - str); + int n_tokens = llama_tokenize(ctx, prefix_str.c_str(), tokens, 32, false); + if (n_tokens < 1) { + return token; // REVIEW + } + + // accept the first token of the matching prefix into the grammar + llama_token first_prefix_token = tokens[0]; + const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token); + for ( ; *first_prefix_str; ++first_prefix_str) { + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *first_prefix_str); + LLAMA_ASSERT(!grammar->stacks.empty()); + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + return first_prefix_token; +} + // // quantization // diff --git a/llama.h b/llama.h index dc033b71d..49e26b1b8 100644 --- a/llama.h +++ b/llama.h @@ -55,6 +55,8 @@ extern "C" { struct llama_context; + struct llama_grammar; + typedef int llama_token; typedef struct llama_token_data { @@ -233,6 +235,30 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); LLAMA_API llama_token llama_token_nl(); + // Grammar + // + // Accepts a binary encoding of a context-free grammar. The returned struct can be used to + // constrain sampled tokens (see below). 
+    //
+    // The binary format represents one or more production rules, each with one or more alternate
+    // definitions:
+    //
+    //     ( <rule_id> ( <alternate> )+ 0000 )+ FFFF
+    //
+    // rule_ids should be assigned sequentially from zero but may appear out of order. Each
+    // rule alternate is a sequence of zero or more symbols, each prefixed with size:
+    //
+    //     ( <symbol_size> <symbol> )* 0000
+    //
+    // A symbol of size 1 is interpreted as a rule reference (whose value is the single following
+    // u16). Symbols sized greater than 1 are interpreted as inclusive pairs of 16-bit chars to
+    // match. Note that symbol sizes greater than 7FFF are reserved for future use.
+    //
+    // The provided `src` must be kept valid for the lifetime of the `llama_grammar`.
+    //
+    LLAMA_API struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id);
+    LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
+
     // Sampling functions

     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@@ -257,6 +283,9 @@ extern "C" {
     LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
     LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);

+    /// @details Apply constraints from grammar
+    LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
+
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -278,6 +307,9 @@ extern "C" {
     /// @details Randomly selects a token from the candidates based on their probabilities.
     LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);

+    /// @details Accepts the sampled token into the grammar, possibly transforming to a new token
+    LLAMA_API llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
+
     // Performance information
     LLAMA_API void llama_print_timings(struct llama_context * ctx);
     LLAMA_API void llama_reset_timings(struct llama_context * ctx);
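
To make the binary grammar encoding documented in llama.h above concrete, here is a tiny grammar encoded by hand in that format and handed to `llama_grammar_init`. This is an illustrative sketch only: the `root ::= ans` / `ans ::= "yes" | "no"` grammar and the `k_grammar` name are invented for the example, and in practice the array is produced by `grammar_parser::parse` rather than written out manually. The layout mirrors what `parse_sequence`/`parse_alternates` emit: a rule is its id followed by one or more alternates and a terminating 0; an alternate is a total size (counting the size slot itself plus its symbols) followed by the symbols and a terminating 0; the whole grammar ends with 0xFFFF.

```cpp
// Hand-encoded equivalent of:
//   root ::= ans
//   ans  ::= "yes" | "no"
#include <cstdint>
#include "llama.h"

static const uint16_t k_grammar[] = {
    // rule 0: root ::= ans
    0,                                          // rule id
    3,  1, 1,                                   // alternate (size 3): one symbol, a reference to rule 1
    0,                                          // end of alternate
    0,                                          // end of rule
    // rule 1: ans ::= "yes" | "no"
    1,                                          // rule id
    10, 2, 'y', 'y', 2, 'e', 'e', 2, 's', 's',  // alternate (size 10): literal chars as 1-char ranges
    0,                                          // end of alternate
    7,  2, 'n', 'n', 2, 'o', 'o',               // alternate (size 7)
    0,                                          // end of alternate
    0,                                          // end of rule
    0xFFFF,                                     // end of grammar
};

int main() {
    llama_grammar * grammar = llama_grammar_init(k_grammar, /*start_rule_id=*/0);
    // during generation: llama_sample_grammar(ctx, &candidates_p, grammar);
    //                    id = llama_grammar_accept_token(ctx, grammar, id);
    llama_grammar_free(grammar);
    return 0;
}
```

Note that `llama_grammar_init` currently only seeds the stack from the first alternate of the start rule (see the TODO in the implementation), so the start rule here deliberately has a single alternate.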
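
The `*`/`+`/`?` operators need no runtime support at all; `parse_sequence` rewrites them into a synthesized recursive rule following the `S* --> S' ::= S S' |` rules quoted in its comment. As a concrete illustration, the `num ::= [0-9]+ space` rule from the arithmetic example in grammar-parser.h is parsed as if it had been written with a helper rule (the name `num_1` below is a stand-in; the real name is generated internally from the rule name and the next free symbol id):

```
num   ::= num_1 space
num_1 ::= [0-9] num_1 | [0-9]
```

For `*` the second alternate of the helper would be empty, and for `?` the first alternate would not recurse.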
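
On the user-facing side, the only new surface is the `--grammar` flag, which takes the grammar text itself rather than a path to a file, so an invocation would look something like `./main -m models/7B/ggml-model.bin --grammar 'root ::= [0-9]+' -p 'The answer is '` or, for a larger grammar kept in a file, `--grammar "$(cat grammar.txt)"` (model path, prompt, and file name are placeholders). In interactive mode the grammar is re-parsed and its state reset whenever generation restarts, as shown in the main.cpp hunk.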