llama, main : constrain sampling to grammar

This commit is contained in:
Evan Jones 2023-05-31 00:20:51 -04:00
parent 72ff5282bf
commit fd0eb663ce
9 changed files with 662 additions and 1 deletions

View file

@ -250,6 +250,9 @@ llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h
common.o: examples/common.cpp examples/common.h
$(CXX) $(CXXFLAGS) -c $< -o $@
grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h
$(CXX) $(CXXFLAGS) -c $< -o $@
libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
@ -260,7 +263,7 @@ clean:
# Examples
#
main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS)
main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
@echo
@echo '==== Run ./main -h for help. ===='

View file

@ -13,6 +13,8 @@ set(TARGET common)
add_library(${TARGET} OBJECT
common.h
common.cpp
grammar-parser.h
grammar-parser.cpp
)
if (BUILD_SHARED_LIBS)

View file

@ -388,6 +388,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.input_suffix = argv[i];
} else if (arg == "--grammar") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.grammar = argv[i];
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, default_params);
@ -458,6 +464,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
fprintf(stderr, " --grammar GRAMMAR BNF-like grammar (TODO explain) to constrain generations\n");
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");

View file

@ -52,6 +52,7 @@ struct gpt_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string lora_adapter = ""; // lora adapter path

315
examples/grammar-parser.cpp Normal file
View file

@ -0,0 +1,315 @@
#include "grammar-parser.h"
#include <cstdint>
#include <cwchar>
#include <string>
#include <utility>
namespace grammar_parser {
// Returns the id registered for the symbol name `src[0..len)`. A name that
// has not been seen before is registered under the next free id, which is the
// current number of entries in `state.symbol_ids`.
uint16_t get_symbol_id(parse_state & state, const char * src, size_t len) {
    const std::string name(src, len);
    const uint16_t candidate_id = static_cast<uint16_t>(state.symbol_ids.size());
    // emplace leaves an already-present entry untouched, so the id returned
    // is the existing one whenever the name is already known
    return state.symbol_ids.emplace(name, candidate_id).first->second;
}
// Registers a synthesized symbol named "<base_name>_<id>" under the next free
// id (the current number of entries in `state.symbol_ids`) and returns that id.
uint16_t generate_symbol_id(parse_state & state, const std::string & base_name) {
    const uint16_t fresh_id = static_cast<uint16_t>(state.symbol_ids.size());
    std::string synthesized_name = base_name;
    synthesized_name += '_';
    synthesized_name += std::to_string(fresh_id);
    state.symbol_ids[synthesized_name] = fresh_id;
    return fresh_id;
}
// True for the characters allowed in a rule name: ASCII letters, ASCII digits
// and '-'.
bool is_word_char(char c) {
    if (c == '-') {
        return true;
    }
    const bool is_lower = 'a' <= c && c <= 'z';
    const bool is_upper = 'A' <= c && c <= 'Z';
    const bool is_digit = '0' <= c && c <= '9';
    return is_lower || is_upper || is_digit;
}
// Converts one hexadecimal digit (either case) to its value 0..15.
// Returns -1 when `c` is not a hexadecimal digit.
int hex_to_int(char c) {
    if ('0' <= c && c <= '9') {
        return c - '0';
    }
    if ('a' <= c && c <= 'f') {
        return 10 + (c - 'a');
    }
    if ('A' <= c && c <= 'F') {
        return 10 + (c - 'A');
    }
    return -1;
}
// Skips spaces and tabs starting at `src` and returns a pointer to the first
// character that is neither. Newlines are not skipped.
const char * parse_space(const char * src) {
    // TODO: support newlines in some cases
    const char * cursor = src;
    for (; *cursor == ' ' || *cursor == '\t'; ++cursor) {
    }
    return cursor;
}
// Parses a rule name (a run of word characters) starting at `src`.
// Returns {end of the name, position after any spaces/tabs following it}.
// Throws std::string if `src` does not start with a word character.
std::pair<const char *, const char *> parse_name(const char * src) {
    const char * name_end = src;
    for (; is_word_char(*name_end); ++name_end) {
    }
    if (name_end == src) {
        throw std::string("expecting name at ") + src;
    }
    const char * after_space = parse_space(name_end);
    return std::make_pair(name_end, after_space);
}
// Parses a single (possibly escaped) character of a literal or char range.
//
// Supported escapes: \xNN (two hex digits), \\, \", \[, \], \r, \n, \t.
// Returns {character value, position just past the consumed input}.
// Throws std::string on a malformed/unknown escape or at end of input.
//
// NOTE(review): an unescaped character is converted from `char`, so where
// `char` is signed, bytes >= 0x80 widen to 0xFF80..0xFFFF, while \xNN yields
// 0x80..0xFF for the same byte - confirm which encoding the matcher expects.
std::pair<uint16_t, const char *> parse_char(const char * src) {
    if (*src == '\\') {
        const char esc = src[1];
        switch (esc) {
            case 'x': {
                // src[3] is only read once src[2] is known to be a hex digit,
                // so this never reads past the terminating NUL
                const int first = hex_to_int(src[2]);
                if (first > -1) {
                    const int second = hex_to_int(src[3]);
                    if (second > -1) {
                        return std::make_pair((first << 4) + second, src + 4);
                    }
                }
                throw std::string("expecting \\xNN at ") + src;
            }
            case '\\': // escaped backslash: the only way to write a literal '\'
            case '"':
            case '[':
            case ']':
                return std::make_pair(esc, src + 2);
            case 'r':
                return std::make_pair('\r', src + 2);
            case 'n':
                return std::make_pair('\n', src + 2);
            case 't':
                return std::make_pair('\t', src + 2);
            default:
                break;
        }
        throw std::string("unknown escape at ") + src;
    } else if (*src) {
        return std::make_pair(*src, src + 1);
    }
    throw std::string("unexpected end of input");
}
// Forward declaration: parse_sequence parses a parenthesized group by calling
// parse_alternates, which in turn calls parse_sequence for each alternate.
const char * parse_alternates(
parse_state & state,
const char * src,
const std::string & rule_name,
uint16_t rule_id);
// Parses one alternate (a sequence of symbols) starting at `src` and appends
// its binary encoding to `outbuf`: a size slot, the encoded symbols, and a
// terminating 0. Parsing stops at the first character that cannot start a
// symbol (for example '|', ')', a newline or the end of input) and that
// position is returned.
//
// Rules synthesized for groups `( ... )` and for the repetition operators
// `*`, `+`, `?` are written directly to `state.out_grammar`; `rule_name` is
// the base for their generated names.
//
// Throws std::string on malformed input.
const char * parse_sequence(
parse_state & state,
const char * src,
const std::string & rule_name,
std::vector<uint16_t> & outbuf) {
size_t out_start = outbuf.size();
// sequence size, will be replaced at end when known
outbuf.push_back(0);
// index in outbuf where the most recently parsed symbol begins; a following
// repetition operator applies to everything from here to the end of outbuf
size_t last_sym_start = outbuf.size();
const char * pos = src;
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = outbuf.size();
// an unterminated literal ends in parse_char throwing at the NUL
while (*pos != '"') {
auto char_pair = parse_char(pos);
pos = char_pair.second;
// each char of a literal is encoded as a "range" of char - char
outbuf.push_back(2);
outbuf.push_back(char_pair.first);
outbuf.push_back(char_pair.first);
}
pos = parse_space(pos + 1);
} else if (*pos == '[') { // char range(s)
pos++;
last_sym_start = outbuf.size();
// num chars in range - replaced at end of loop
outbuf.push_back(0);
while (*pos != ']') {
auto char_pair = parse_char(pos);
pos = char_pair.second;
outbuf.push_back(char_pair.first);
// a '-' directly before the closing ']' is not a range separator
if (pos[0] == '-' && pos[1] != ']') {
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
outbuf.push_back(endchar_pair.first);
} else {
// chars that aren't part of a c1-c2 range are just doubled (i.e., c-c)
outbuf.push_back(char_pair.first);
}
}
// replace num chars with actual
outbuf[last_sym_start] = static_cast<uint16_t>(outbuf.size() - last_sym_start - 1);
pos = parse_space(pos + 1);
} else if (is_word_char(*pos)) { // rule reference
auto name_pair = parse_name(pos);
uint16_t ref_rule_id = get_symbol_id(state, pos, name_pair.first - pos);
pos = name_pair.second;
last_sym_start = outbuf.size();
// a symbol of size 1 is a rule reference followed by the rule id
outbuf.push_back(1);
outbuf.push_back(ref_rule_id);
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1);
uint16_t sub_rule_id = generate_symbol_id(state, rule_name);
pos = parse_alternates(state, pos, rule_name, sub_rule_id);
last_sym_start = outbuf.size();
// output reference to synthesized rule
outbuf.push_back(1);
outbuf.push_back(sub_rule_id);
if (*pos != ')') {
throw std::string("expecting ')' at ") + pos;
}
pos = parse_space(pos + 1);
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
// nothing has been emitted for this sequence yet, so there is no operand
if (outbuf.size() - out_start - 1 == 0) {
throw std::string("expecting preceeding item to */+/? at ") + pos;
}
std::vector<uint16_t> & out_grammar = state.out_grammar;
// apply transformation to previous symbol (last_sym_start -
// end) according to rewrite rules:
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |
uint16_t sub_rule_id = generate_symbol_id(state, rule_name);
out_grammar.push_back(sub_rule_id);
size_t sub_rule_start = out_grammar.size();
// placeholder for size of 1st alternate
out_grammar.push_back(0);
// add preceding symbol to generated rule
out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end());
if (*pos == '*' || *pos == '+') {
// cause generated rule to recurse
out_grammar.push_back(1);
out_grammar.push_back(sub_rule_id);
}
// apply actual size
out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start;
// mark end of 1st alternate
out_grammar.push_back(0);
sub_rule_start = out_grammar.size();
// placeholder for size of 2nd alternate
out_grammar.push_back(0);
if (*pos == '+') {
// add preceding symbol as alternate only for '+'
out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end());
}
// apply actual size of 2nd alternate
out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start;
// mark end of 2nd alternate, then end of rule
out_grammar.push_back(0);
out_grammar.push_back(0);
// in original rule, replace previous symbol with reference to generated rule
outbuf.resize(last_sym_start);
outbuf.push_back(1);
outbuf.push_back(sub_rule_id);
pos = parse_space(pos + 1);
} else {
// not the start of a symbol: leave it for the caller to interpret
break;
}
}
// apply actual size of this alternate sequence
outbuf[out_start] = static_cast<uint16_t>(outbuf.size() - out_start);
// mark end of alternate
outbuf.push_back(0);
return pos;
}
// Parses one or more '|'-separated alternates starting at `src` and appends
// the complete rule to `state.out_grammar`: `rule_id`, the encoded alternates,
// then a terminating 0. Returns the position where parsing stopped.
const char * parse_alternates(
        parse_state       & state,
        const char        * src,
        const std::string & rule_name,
        uint16_t            rule_id) {
    std::vector<uint16_t> encoded;
    const char * cursor = src;
    for (;;) {
        cursor = parse_sequence(state, cursor, rule_name, encoded);
        if (*cursor != '|') {
            break;
        }
        cursor = parse_space(cursor + 1);
    }
    // nested rules synthesized while parsing were already written to
    // out_grammar, so this rule is appended only once it is complete
    std::vector<uint16_t> & grammar = state.out_grammar;
    grammar.push_back(rule_id);
    grammar.insert(grammar.end(), encoded.begin(), encoded.end());
    grammar.push_back(0);
    return cursor;
}
// Parses one rule definition of the form `name ::= alternates` followed by a
// line ending (\r, \r\n or \n) or the end of input. Returns the position
// after the line ending and any spaces/tabs that follow it.
// Throws std::string on malformed input.
const char * parse_rule(parse_state & state, const char * src) {
    const auto name_pair = parse_name(src);
    const size_t name_len = name_pair.first - src;
    const char * pos = name_pair.second;
    const uint16_t rule_id = get_symbol_id(state, src, name_len);
    const std::string name(src, name_len);

    const bool has_assign = pos[0] == ':' && pos[1] == ':' && pos[2] == '=';
    if (!has_assign) {
        throw std::string("expecting ::= at ") + pos;
    }
    pos = parse_alternates(state, parse_space(pos + 3), name, rule_id);

    switch (*pos) {
        case '\0':
            break;
        case '\r':
            pos += (pos[1] == '\n') ? 2 : 1;
            break;
        case '\n':
            ++pos;
            break;
        default:
            throw std::string("expecting newline or end at ") + pos;
    }
    return parse_space(pos);
}
// Parses a whole grammar: leading spaces/tabs are skipped, then rules are
// parsed until the end of input. The binary grammar in the returned state is
// terminated with 0xffff. Errors propagate as thrown std::string.
parse_state parse(const char * src) {
    parse_state state;
    for (const char * pos = parse_space(src); *pos != '\0'; ) {
        pos = parse_rule(state, pos);
    }
    state.out_grammar.push_back(0xffff);
    return state;
}
// Prints one rule of a binary grammar in a human-readable form.
//
//   file            - output stream
//   base            - start of the whole binary grammar; printed offsets are
//                     relative to this pointer
//   src             - start of the rule to print (points at its rule id)
//   symbol_id_names - rule id -> name; `at()` throws std::out_of_range for an
//                     id that has no entry
//
// Each rule reference and char range is prefixed with its `<offset>`.
// Returns a pointer just past the rule's terminating 0.
const uint16_t * print_rule(
        FILE           * file,
        const uint16_t * base,
        const uint16_t * src,
        const std::map<uint16_t, std::string> & symbol_id_names) {
    const uint16_t rule_id = *src;
    // pointer differences are ptrdiff_t; cast so the argument matches %zu
    fprintf(file, "<%zu>%s ::= ", static_cast<size_t>(src - base), symbol_id_names.at(rule_id).c_str());
    const uint16_t * pos = src + 1;
    while (*pos) {
        // every alternate after the first is preceded by a separator
        if (pos - 1 > src) {
            fprintf(file, "| ");
        }
        pos++; // sequence size, not needed here
        while (*pos) {
            if (*pos == 1) {
                // symbol of size 1: rule reference, the id follows
                const uint16_t ref_rule_id = pos[1];
                fprintf(file, "<%zu>%s ", static_cast<size_t>(pos - base), symbol_id_names.at(ref_rule_id).c_str());
                pos += 2;
            } else {
                // symbol of size > 1: inclusive (low, high) char pairs
                fprintf(file, "<%zu>[", static_cast<size_t>(pos - base));
                const uint16_t num_chars = *pos;
                pos++;
                for (uint16_t i = 0; i < num_chars; i += 2) {
                    fprintf(file, "%lc-", static_cast<wint_t>(pos[i])); // REVIEW
                    if (i + 1 < num_chars) {
                        fprintf(file, "%lc", static_cast<wint_t>(pos[i + 1]));
                    }
                }
                fprintf(file, "] ");
                pos += num_chars;
            }
        }
        pos++; // skip the 0 that ends this alternate
    }
    fprintf(file, "\n");
    return pos + 1;
}
// Prints every rule of the parsed grammar in `state` to `file`, one per line.
// A state whose binary grammar is empty (e.g. default-constructed, never
// filled by parse()) prints nothing.
void print_grammar(FILE * file, const parse_state & state) {
    // without this guard the loop below would dereference the data pointer of
    // an empty vector while looking for the 0xffff terminator
    if (state.out_grammar.empty()) {
        return;
    }
    // invert name -> id into id -> name for lookups while printing
    std::map<uint16_t, std::string> symbol_id_names;
    for (const auto & kv : state.symbol_ids) {
        symbol_id_names[kv.second] = kv.first;
    }
    const uint16_t * const base = state.out_grammar.data();
    const uint16_t * pos = base;
    while (*pos != 0xffff) {
        pos = print_rule(file, base, pos, symbol_id_names);
    }
}
}

26
examples/grammar-parser.h Normal file
View file

@ -0,0 +1,26 @@
// Implements a parser for an extended Backus-Naur form (BNF), producing the
// binary context-free grammar format specified by llama.h. Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root ::= expr
// expr ::= term ([-+*/] term)*
// term ::= num | "(" space expr ")" space
// num ::= [0-9]+ space
// space ::= [ \t\n]*
#pragma once
#include <cstdint>
#include <cstdio> // FILE, used by print_grammar
#include <map>
#include <string>
#include <vector>
namespace grammar_parser {
    // Result of parsing a grammar.
    struct parse_state {
        // symbol (rule) name -> rule id
        std::map<std::string, uint16_t> symbol_ids;
        // binary grammar; parse() terminates it with 0xffff
        std::vector<uint16_t> out_grammar;
    };

    // Parses the NUL-terminated grammar text `src`.
    // Errors are reported by throwing std::string.
    parse_state parse(const char * src);

    // Prints a human-readable dump of the parsed grammar to `file`.
    void print_grammar(FILE * file, const parse_state & state);
}

View file

@ -6,6 +6,7 @@
#include "common.h"
#include "llama.h"
#include "build-info.h"
#include "grammar-parser.h"
#include <cassert>
#include <cinttypes>
@ -291,6 +292,17 @@ int main(int argc, char ** argv) {
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n");
grammar_parser::parse_state parsed_grammar;
llama_grammar * grammar = NULL;
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
fprintf(stderr, "%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, parsed_grammar);
fprintf(stderr, "\n");
grammar = llama_grammar_init(
parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root"));
}
// TODO: replace with ring-buffer
std::vector<llama_token> last_n_tokens(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@ -454,6 +466,10 @@ int main(int argc, char ** argv) {
logits[llama_token_nl()] = nl_logit;
}
if (grammar != NULL) {
llama_sample_grammar(ctx, &candidates_p, grammar);
}
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
@ -479,6 +495,10 @@ int main(int argc, char ** argv) {
}
// printf("`%d`", candidates_p.size);
if (grammar != NULL) {
id = llama_grammar_accept_token(ctx, grammar, id);
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
}
@ -609,6 +629,17 @@ int main(int argc, char ** argv) {
}
if (n_past > 0) {
if (is_interacting) {
// reset grammar state if we're restarting generation
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
if (grammar != NULL) {
llama_grammar_free(grammar);
}
grammar = llama_grammar_init(
parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root"));
}
}
is_interacting = false;
}
}
@ -638,5 +669,9 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
if (grammar != NULL) {
llama_grammar_free(grammar);
}
return 0;
}

240
llama.cpp
View file

@ -1821,6 +1821,168 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
return output;
}
//
// grammar - internal
//
// State of a grammar-constrained sampling session.
struct llama_grammar {
// pointer to the start of each rule inside the binary grammar, indexed by rule id
const std::vector<const uint16_t *> rules;
// candidate pushdown stacks; each entry points at a grammar element still to
// be matched, with the current element at the back. An empty stack means the
// grammar has been fully matched along that path.
std::vector<std::vector<const uint16_t *>> stacks;
};
// Expands `stack` until the element on top is a character range (terminal),
// appending every resulting stack to `new_stacks`. A rule reference on top is
// replaced by each alternate of the referenced rule in turn, recursively; an
// empty stack is appended unchanged.
static void llama_grammar_advance_stack(
        const std::vector<const uint16_t *>        & rules,
        const std::vector<const uint16_t *>        & stack,
        std::vector<std::vector<const uint16_t *>> & new_stacks) {
    if (stack.empty()) {
        new_stacks.push_back(stack);
        return;
    }

    const uint16_t * pos = stack.back();
    if (*pos != 1) {
        // rule element size > 1 -> character reference
        LLAMA_ASSERT(*pos);
        new_stacks.push_back(stack);
        return;
    }

    // rule reference: walk the alternates of the referenced rule, each one
    // being a size slot followed by its symbols and a 0 terminator
    for (const uint16_t * alt = rules[pos[1]] + 1; *alt; alt += 1 + *alt) {
        // start from the stack without its top (the rule reference)
        std::vector<const uint16_t *> expanded(stack.begin(), stack.end() - 1);
        if (pos[2]) {
            // the element following the rule reference is resumed afterwards
            expanded.push_back(pos + 2);
        }
        if (alt[1]) {
            // a non-empty alternate starts at its first symbol
            expanded.push_back(alt + 1);
        }
        llama_grammar_advance_stack(rules, expanded, new_stacks);
    }
}
// Given candidate stacks that are each positioned at a character range (see
// `llama_grammar_advance_stack`), returns the stacks that result from
// consuming `chr`. Stacks whose top range does not contain `chr`, and empty
// stacks, contribute nothing.
static std::vector<std::vector<const uint16_t *>> llama_grammar_accept(
        const std::vector<const uint16_t *>              & rules,
        const std::vector<std::vector<const uint16_t *>> & stacks,
        const uint16_t                                     chr) {
    std::vector<std::vector<const uint16_t *>> accepted;

    for (const auto & stack : stacks) {
        if (stack.empty()) {
            continue;
        }

        const uint16_t * const sym = stack.back();
        const uint16_t num_chars = *sym;
        LLAMA_ASSERT(num_chars > 1);
        const uint16_t * const chars = sym + 1; // past the size indicator

        // scan the inclusive (low, high) pairs; an unpaired trailing low bound
        // is treated as open-ended
        bool matched = false;
        for (int i = 0; i < num_chars && !matched; i += 2) {
            matched = chars[i] <= chr && (i + 1 == num_chars || chr <= chars[i + 1]);
        }
        if (!matched) {
            continue;
        }

        // replace the matched range on top with the element after it, if any
        const uint16_t * const following = chars + num_chars;
        std::vector<const uint16_t *> next_stack(stack.begin(), stack.end() - 1);
        if (*following) {
            next_stack.push_back(following);
        }
        llama_grammar_advance_stack(rules, next_stack, accepted);
    }

    return accepted;
}
// Returns `true` if at least one stack can accept `chr`. An empty stack
// accepts only `chr == 0`; a non-empty stack accepts `chr` when the character
// range on its top contains it.
static bool llama_grammar_peek(
        const std::vector<std::vector<const uint16_t *>> & stacks,
        const uint16_t                                     chr) {
    for (const auto & stack : stacks) {
        if (stack.empty()) {
            if (chr == 0) {
                return true;
            }
            continue;
        }

        const uint16_t * const sym = stack.back();
        const uint16_t num_chars = *sym;
        LLAMA_ASSERT(num_chars > 1);
        const uint16_t * const chars = sym + 1;

        for (int lo = 0; lo < num_chars; lo += 2) {
            if (chars[lo] > chr) {
                continue;
            }
            const int hi = lo + 1;
            // an unpaired trailing low bound has no upper limit
            if (hi == num_chars || chr <= chars[hi]) {
                return true;
            }
        }
    }
    return false;
}
//
// grammar - external
//
// Builds a grammar state from the binary grammar `src` (terminated by 0xffff),
// starting at rule `start_rule_id`. The returned object stores pointers into
// `src` and must be released with `llama_grammar_free`.
//
// The initial stacks are seeded from every alternate of the start rule.
// Asserts that the start rule exists in `src`.
struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id) {
    const uint16_t * pos = src;
    std::vector<const uint16_t *> rules;

    // build `rules` as list of pointers to rules embedded in binary grammar `src`
    while (*pos != 0xffff) {
        const uint16_t rule_id = *pos;
        if (rules.size() <= rule_id) {
            // ids that never get a definition stay null
            rules.resize(rule_id + 1);
        }
        rules[rule_id] = pos;
        // skip rule id
        pos++;
        // skip rule alternates
        while (*pos) {
            pos += 1 + *pos;
        }
        // skip 0 denoting end of rule
        pos++;
    }

    // the start rule must be defined; indexing `rules` unchecked would read
    // out of bounds for an empty grammar or an unknown id
    LLAMA_ASSERT(start_rule_id < rules.size());
    const uint16_t * start_rule = rules[start_rule_id];
    LLAMA_ASSERT(start_rule != nullptr);
    LLAMA_ASSERT(start_rule[0] == start_rule_id && start_rule[1]);

    // seed one stack per alternate of the start rule, so a start rule such as
    // `root ::= a | b` can match through any of its alternates
    std::vector<std::vector<const uint16_t *>> stacks;
    for (const uint16_t * alt = start_rule + 1; *alt; alt += 1 + *alt) {
        std::vector<const uint16_t *> stack;
        if (alt[1]) {
            // non-empty alternate: start at its first element
            stack.push_back(alt + 1);
        }
        llama_grammar_advance_stack(rules, stack, stacks);
    }

    return new llama_grammar{ rules, stacks };
}
// Releases a grammar created by `llama_grammar_init`. Passing NULL is a no-op.
void llama_grammar_free(struct llama_grammar * grammar) {
delete grammar;
}
//
// sampling
//
@ -2097,6 +2259,30 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
}
}
// Masks out (logit = -INFINITY) every candidate token that the grammar cannot
// accept next. Only the first character of each token is checked (the first
// two for tokens that start with a space); `llama_grammar_accept_token`
// later finds the full matching prefix of the token that gets selected.
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
    assert(ctx);
    const int64_t t_start_sample_us = ggml_time_us();
    const llama_token eos = llama_token_eos();

    // since many llama tokens are prefixed with a single space, special case a lookahead on ' '
    const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, ' ');

    for (size_t i = 0; i < candidates->size; ++i) {
        const llama_token id  = candidates->data[i].id;
        const char *      str = llama_token_to_str(ctx, id);

        bool valid;
        if (id == eos) {
            // end of stream is allowed only where the grammar may end
            valid = llama_grammar_peek(grammar->stacks, 0);
        } else if (str[0] == '\0') {
            // a non-EOS token with no text cannot advance the grammar; letting
            // it through would leave nothing to accept once it is selected
            valid = false;
        } else if (str[0] == ' ') {
            // a token that is exactly " " is valid whenever the space itself
            // is accepted; otherwise look at the character after the space
            valid = str[1] == '\0'
                ? !stacks_after_space.empty()
                : llama_grammar_peek(stacks_after_space, str[1]);
        } else {
            valid = llama_grammar_peek(grammar->stacks, str[0]);
        }

        if (!valid) {
            candidates->data[i].logit = -INFINITY;
        }
    }

    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
assert(ctx);
@ -2223,6 +2409,60 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
return result;
}
// Advances the grammar state by the sampled `token` and returns the token that
// was actually accepted.
//
// If the whole token text matches the grammar, `token` itself is returned.
// If only a prefix matches, that prefix is re-tokenized and its first token is
// accepted and returned instead. EOS is accepted only when at least one stack
// is empty, i.e. the grammar is complete along some path.
llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
    const int64_t t_start_sample_us = ggml_time_us();

    if (token == llama_token_eos()) {
        for (const auto & stack : grammar->stacks) {
            if (stack.empty()) {
                ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
                return token;
            }
        }
        // reached only after every stack was checked: none was complete, so
        // EOS should not have been sampled
        LLAMA_ASSERT(false);
    }

    const char * str    = llama_token_to_str(ctx, token);
    const char * suffix = str;

    // Find prefix of selected token that matches grammar, expecting at least 1 char
    auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *suffix);
    LLAMA_ASSERT(!new_stacks.empty());
    if (*suffix) {
        ++suffix;
        for ( ; *suffix; ++suffix) {
            new_stacks = llama_grammar_accept(grammar->rules, new_stacks, *suffix);
            if (new_stacks.empty()) {
                break;
            }
        }
    }

    // if full token is matched, accept new stacks
    if (!(*suffix)) {
        grammar->stacks = new_stacks;
        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
        return token;
    }

    // otherwise, tokenize the string prefix that did match
    llama_token tokens[32]; // TODO - determine actual max token size
    const std::string prefix_str(str, suffix - str);
    int n_tokens = llama_tokenize(ctx, prefix_str.c_str(), tokens, 32, false);
    if (n_tokens < 1) {
        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
        return token; // REVIEW
    }

    // accept the first token of the matching prefix into the grammar
    llama_token first_prefix_token = tokens[0];
    const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token);
    for ( ; *first_prefix_str; ++first_prefix_str) {
        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *first_prefix_str);
        LLAMA_ASSERT(!grammar->stacks.empty());
    }

    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
    return first_prefix_token;
}
//
// quantization
//

32
llama.h
View file

@ -55,6 +55,8 @@ extern "C" {
struct llama_context;
struct llama_grammar;
typedef int llama_token;
typedef struct llama_token_data {
@ -233,6 +235,30 @@ extern "C" {
LLAMA_API llama_token llama_token_eos();
LLAMA_API llama_token llama_token_nl();
// Grammar
//
// Accepts a binary encoding of a context-free grammar. The returned struct can be used to
// constrain sampled tokens (see below).
//
// The binary format represents one or more production rules, each with one or more alternate
// definitions:
//
// (<rule_id: u16> (<alt_size: u16> <alt_size * u16>)+ 0000)+ FFFF
//
// rule_ids should be assigned sequentially from zero but may appear out of order. Each
// rule alternate is a sequence of zero or more symbols, each prefixed with size:
//
// (<sym_size: u16> <sym_size * u16>)* 0000
//
// A symbol of size 1 is interpreted as a rule reference (whose value is the single following
// u16). Symbols sized greater than 1 are interpreted as inclusive pairs of 16-bit chars to
// match. Note that symbol sizes greater than 7FFF are reserved for future use.
//
// The provided `src` must be kept valid for the lifetime of the `llama_grammar`.
//
LLAMA_API struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id);
LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
// Sampling functions
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@ -257,6 +283,9 @@ extern "C" {
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
/// @details Apply constraints from grammar
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@ -278,6 +307,9 @@ extern "C" {
/// @details Randomly selects a token from the candidates based on their probabilities.
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
/// @details Accepts the sampled token into the grammar, possibly transforming to a new token
LLAMA_API llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
// Performance information
LLAMA_API void llama_print_timings(struct llama_context * ctx);
LLAMA_API void llama_reset_timings(struct llama_context * ctx);