llama, main : constrain sampling to grammar
commit fd0eb663ce (parent 72ff5282bf)
9 changed files with 662 additions and 1 deletion
Makefile (5 changes)

@@ -250,6 +250,9 @@ llama.o: llama.cpp ggml.h ggml-cuda.h llama.h llama-util.h
 common.o: examples/common.cpp examples/common.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

+grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 libllama.so: llama.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)

@@ -260,7 +263,7 @@ clean:
 # Examples
 #

-main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 	@echo
 	@echo '==== Run ./main -h for help. ===='
examples/CMakeLists.txt

@@ -13,6 +13,8 @@ set(TARGET common)
 add_library(${TARGET} OBJECT
     common.h
     common.cpp
+    grammar-parser.h
+    grammar-parser.cpp
     )

 if (BUILD_SHARED_LIBS)
examples/common.cpp

@@ -388,6 +388,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.input_suffix = argv[i];
+        } else if (arg == "--grammar") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.grammar = argv[i];
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, default_params);
@@ -458,6 +464,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
     fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
     fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
+    fprintf(stderr, " --grammar GRAMMAR BNF-like grammar (TODO explain) to constrain generations\n");
     fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
examples/common.h

@@ -52,6 +52,7 @@ struct gpt_params {
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
     std::string input_prefix = "";      // string to prefix user inputs with
     std::string input_suffix = "";      // string to suffix user inputs with
+    std::string grammar = "";           // optional BNF-like grammar to constrain sampling
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

     std::string lora_adapter = ""; // lora adapter path
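An illustrative aside, not part of the diff: gpt_params::grammar holds the raw text given via --grammar, and (as the examples/main changes below show) it is parsed with grammar-parser and used to constrain sampling whenever it is non-empty. A minimal grammar that restricts output to a bare yes/no answer might look like the following; the accepted syntax is defined by examples/grammar-parser.cpp, and main looks up the start rule by the name "root":

    root ::= "yes" | "no"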
examples/grammar-parser.cpp (new file, 315 lines)

@@ -0,0 +1,315 @@
#include "grammar-parser.h"
#include <cstdint>
#include <cwchar>
#include <string>
#include <utility>

namespace grammar_parser {
    uint16_t get_symbol_id(parse_state & state, const char * src, size_t len) {
        uint16_t next_id = static_cast<uint16_t>(state.symbol_ids.size());
        auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
        return result.first->second;
    }

    uint16_t generate_symbol_id(parse_state & state, const std::string & base_name) {
        uint16_t next_id = static_cast<uint16_t>(state.symbol_ids.size());
        state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
        return next_id;
    }

    bool is_word_char(char c) {
        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
    }

    int hex_to_int(char c) {
        if ('a' <= c && c <= 'f') {
            return c - 'a' + 10;
        } else if ('A' <= c && c <= 'F') {
            return c - 'A' + 10;
        } else if ('0' <= c && c <= '9') {
            return c - '0';
        }
        return -1;
    }

    const char * parse_space(const char * src) {
        const char * pos = src;
        // TODO: support newlines in some cases
        while (*pos == ' ' || *pos == '\t') {
            pos++;
        }
        return pos;
    }

    std::pair<const char *, const char *> parse_name(const char * src) {
        const char * pos = src;
        while (is_word_char(*pos)) {
            pos++;
        }
        if (pos == src) {
            throw std::string("expecting name at ") + src;
        }
        return std::make_pair(pos, parse_space(pos));
    }

    std::pair<uint16_t, const char *> parse_char(const char * src) {
        if (*src == '\\') {
            char esc = src[1];
            if (esc == 'x') {
                int first = hex_to_int(src[2]);
                if (first > -1) {
                    int second = hex_to_int(src[3]);
                    if (second > -1) {
                        return std::make_pair((first << 4) + second, src + 4);
                    }
                }
                throw std::string("expecting \\xNN at ") + src;
            } else if (esc == '"' || esc == '[' || esc == ']') {
                return std::make_pair(esc, src + 2);
            } else if (esc == 'r') {
                return std::make_pair('\r', src + 2);
            } else if (esc == 'n') {
                return std::make_pair('\n', src + 2);
            } else if (esc == 't') {
                return std::make_pair('\t', src + 2);
            }
            throw std::string("unknown escape at ") + src;
        } else if (*src) {
            return std::make_pair(*src, src + 1);
        }
        throw std::string("unexpected end of input");
    }

    const char * parse_alternates(
            parse_state & state,
            const char * src,
            const std::string & rule_name,
            uint16_t rule_id);

    const char * parse_sequence(
            parse_state & state,
            const char * src,
            const std::string & rule_name,
            std::vector<uint16_t> & outbuf) {
        size_t out_start = outbuf.size();

        // sequence size, will be replaced at end when known
        outbuf.push_back(0);

        size_t last_sym_start = outbuf.size();
        const char * pos = src;
        while (*pos) {
            if (*pos == '"') { // literal string
                pos++;
                last_sym_start = outbuf.size();
                while (*pos != '"') {
                    auto char_pair = parse_char(pos);
                    pos = char_pair.second;

                    // each char of a literal is encoded as a "range" of char - char
                    outbuf.push_back(2);
                    outbuf.push_back(char_pair.first);
                    outbuf.push_back(char_pair.first);
                }
                pos = parse_space(pos + 1);
            } else if (*pos == '[') { // char range(s)
                pos++;
                last_sym_start = outbuf.size();
                // num chars in range - replaced at end of loop
                outbuf.push_back(0);
                while (*pos != ']') {
                    auto char_pair = parse_char(pos);
                    pos = char_pair.second;

                    outbuf.push_back(char_pair.first);
                    if (pos[0] == '-' && pos[1] != ']') {
                        auto endchar_pair = parse_char(pos + 1);
                        pos = endchar_pair.second;
                        outbuf.push_back(endchar_pair.first);
                    } else {
                        // chars that aren't part of a c1-c2 range are just doubled (i.e., c-c)
                        outbuf.push_back(char_pair.first);
                    }
                }
                // replace num chars with actual
                outbuf[last_sym_start] = static_cast<uint16_t>(outbuf.size() - last_sym_start - 1);
                pos = parse_space(pos + 1);
            } else if (is_word_char(*pos)) { // rule reference
                auto name_pair = parse_name(pos);
                uint16_t ref_rule_id = get_symbol_id(state, pos, name_pair.first - pos);
                pos = name_pair.second;
                last_sym_start = outbuf.size();
                outbuf.push_back(1);
                outbuf.push_back(ref_rule_id);
            } else if (*pos == '(') { // grouping
                // parse nested alternates into synthesized rule
                pos = parse_space(pos + 1);
                uint16_t sub_rule_id = generate_symbol_id(state, rule_name);
                pos = parse_alternates(state, pos, rule_name, sub_rule_id);
                last_sym_start = outbuf.size();
                // output reference to synthesized rule
                outbuf.push_back(1);
                outbuf.push_back(sub_rule_id);
                if (*pos != ')') {
                    throw std::string("expecting ')' at ") + pos;
                }
                pos = parse_space(pos + 1);
            } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
                if (outbuf.size() - out_start - 1 == 0) {
                    throw std::string("expecting preceeding item to */+/? at ") + pos;
                }
                std::vector<uint16_t> & out_grammar = state.out_grammar;

                // apply transformation to previous symbol (last_sym_start -
                // end) according to rewrite rules:
                // S* --> S' ::= S S' |
                // S+ --> S' ::= S S' | S
                // S? --> S' ::= S |
                uint16_t sub_rule_id = generate_symbol_id(state, rule_name);
                out_grammar.push_back(sub_rule_id);
                size_t sub_rule_start = out_grammar.size();
                // placeholder for size of 1st alternate
                out_grammar.push_back(0);
                // add preceding symbol to generated rule
                out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end());
                if (*pos == '*' || *pos == '+') {
                    // cause generated rule to recurse
                    out_grammar.push_back(1);
                    out_grammar.push_back(sub_rule_id);
                }
                // apply actual size
                out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start;
                // mark end of 1st alternate
                out_grammar.push_back(0);
                sub_rule_start = out_grammar.size();
                // placeholder for size of 2nd alternate
                out_grammar.push_back(0);
                if (*pos == '+') {
                    // add preceding symbol as alternate only for '+'
                    out_grammar.insert(out_grammar.end(), outbuf.begin() + last_sym_start, outbuf.end());
                }
                // apply actual size of 2nd alternate
                out_grammar[sub_rule_start] = out_grammar.size() - sub_rule_start;
                // mark end of 2nd alternate, then end of rule
                out_grammar.push_back(0);
                out_grammar.push_back(0);

                // in original rule, replace previous symbol with reference to generated rule
                outbuf.resize(last_sym_start);
                outbuf.push_back(1);
                outbuf.push_back(sub_rule_id);

                pos = parse_space(pos + 1);
            } else {
                break;
            }
        }
        // apply actual size of this alternate sequence
        outbuf[out_start] = static_cast<uint16_t>(outbuf.size() - out_start);
        // mark end of alternate
        outbuf.push_back(0);
        return pos;
    }

    const char * parse_alternates(
            parse_state & state,
            const char * src,
            const std::string & rule_name,
            uint16_t rule_id) {
        std::vector<uint16_t> outbuf;
        const char * pos = parse_sequence(state, src, rule_name, outbuf);
        while (*pos == '|') {
            pos = parse_space(pos + 1);
            pos = parse_sequence(state, pos, rule_name, outbuf);
        }
        state.out_grammar.push_back(rule_id);
        state.out_grammar.insert(state.out_grammar.end(), outbuf.begin(), outbuf.end());
        state.out_grammar.push_back(0);
        return pos;
    }

    const char * parse_rule(parse_state & state, const char * src) {
        auto name_pair = parse_name(src);
        const char * pos = name_pair.second;
        size_t name_len = name_pair.first - src;
        uint16_t rule_id = get_symbol_id(state, src, name_len);
        const std::string name(src, name_len);

        if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
            throw std::string("expecting ::= at ") + pos;
        }
        pos = parse_space(pos + 3);

        pos = parse_alternates(state, pos, name, rule_id);

        if (*pos == '\r') {
            pos += pos[1] == '\n' ? 2 : 1;
        } else if (*pos == '\n') {
            pos++;
        } else if (*pos) {
            throw std::string("expecting newline or end at ") + pos;
        }
        return parse_space(pos);
    }

    parse_state parse(const char * src) {
        parse_state state;
        const char * pos = parse_space(src);
        while (*pos) {
            pos = parse_rule(state, pos);
        }
        state.out_grammar.push_back(0xffff);
        return state;
    }

    const uint16_t * print_rule(
            FILE * file,
            const uint16_t * base,
            const uint16_t * src,
            const std::map<uint16_t, std::string> & symbol_id_names) {
        uint16_t rule_id = *src;
        fprintf(file, "<%zu>%s ::= ", src - base, symbol_id_names.at(rule_id).c_str());
        const uint16_t * pos = src + 1;
        while (*pos) {
            if (pos - 1 > src) {
                fprintf(file, "| ");
            }
            pos++; // sequence size, not needed here
            while (*pos) {
                if (*pos == 1) {
                    uint16_t ref_rule_id = pos[1];
                    fprintf(file, "<%zu>%s ", pos - base, symbol_id_names.at(ref_rule_id).c_str());
                    pos += 2;
                } else {
                    fprintf(file, "<%zu>[", pos - base);
                    uint16_t num_chars = *pos;
                    pos++;

                    for (uint16_t i = 0; i < num_chars; i += 2) {
                        fprintf(file, "%lc-", static_cast<wint_t>(pos[i])); // REVIEW
                        if (i + 1 < num_chars) {
                            fprintf(file, "%lc", static_cast<wint_t>(pos[i + 1]));
                        }
                    }
                    fprintf(file, "] ");
                    pos += num_chars;
                }
            }
            pos++;
        }
        fprintf(file, "\n");
        return pos + 1;
    }

    void print_grammar(FILE * file, const parse_state & state) {
        std::map<uint16_t, std::string> symbol_id_names;
        for (auto kv : state.symbol_ids) {
            symbol_id_names[kv.second] = kv.first;
        }
        const uint16_t * pos = state.out_grammar.data();
        while (*pos != 0xffff) {
            pos = print_rule(file, state.out_grammar.data(), pos, symbol_id_names);
        }
    }
}
examples/grammar-parser.h (new file, 26 lines)

@@ -0,0 +1,26 @@
// Implements a parser for an extended Backus-Naur form (BNF), producing the
// binary context-free grammar format specified by llama.h. Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root  ::= expr
// expr  ::= term ([-+*/] term)*
// term  ::= num | "(" space expr ")" space
// num   ::= [0-9]+ space
// space ::= [ \t\n]*

#pragma once
#include <vector>
#include <map>
#include <cstdint>
#include <string>

namespace grammar_parser {
    struct parse_state {
        std::map<std::string, uint16_t> symbol_ids;
        std::vector<uint16_t> out_grammar;
    };

    parse_state parse(const char * src);
    void print_grammar(FILE * file, const parse_state & state);
}
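As a minimal usage sketch (mine, not part of the diff), assuming it is compiled and linked together with examples/grammar-parser.cpp: the arithmetic grammar from the header comment above is parsed into the binary form and dumped back in human-readable form with print_grammar. parse throws a std::string on malformed input.

    #include <cstdio>
    #include <string>
    #include "grammar-parser.h"

    int main() {
        // the arithmetic grammar from the examples/grammar-parser.h header comment
        const char * src =
            "root  ::= expr\n"
            "expr  ::= term ([-+*/] term)*\n"
            "term  ::= num | \"(\" space expr \")\" space\n"
            "num   ::= [0-9]+ space\n"
            "space ::= [ \\t\\n]*\n";
        try {
            grammar_parser::parse_state state = grammar_parser::parse(src);
            // dump the binary grammar back in a human-readable form
            grammar_parser::print_grammar(stderr, state);
        } catch (const std::string & err) {
            fprintf(stderr, "grammar parse error: %s\n", err.c_str());
            return 1;
        }
        return 0;
    }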
examples/main/main.cpp

@@ -6,6 +6,7 @@
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"
+#include "grammar-parser.h"

 #include <cassert>
 #include <cinttypes>
@@ -291,6 +292,17 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");

+    grammar_parser::parse_state parsed_grammar;
+    llama_grammar * grammar = NULL;
+    if (!params.grammar.empty()) {
+        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+        fprintf(stderr, "%s: grammar:\n", __func__);
+        grammar_parser::print_grammar(stderr, parsed_grammar);
+        fprintf(stderr, "\n");
+        grammar = llama_grammar_init(
+            parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root"));
+    }
+
     // TODO: replace with ring-buffer
     std::vector<llama_token> last_n_tokens(n_ctx);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -454,6 +466,10 @@ int main(int argc, char ** argv) {
                 logits[llama_token_nl()] = nl_logit;
             }

+            if (grammar != NULL) {
+                llama_sample_grammar(ctx, &candidates_p, grammar);
+            }
+
             if (temp <= 0) {
                 // Greedy sampling
                 id = llama_sample_token_greedy(ctx, &candidates_p);
@@ -479,6 +495,10 @@ int main(int argc, char ** argv) {
             }
             // printf("`%d`", candidates_p.size);

+            if (grammar != NULL) {
+                id = llama_grammar_accept_token(ctx, grammar, id);
+            }
+
             last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(id);
         }
@@ -609,6 +629,17 @@ int main(int argc, char ** argv) {
         }

         if (n_past > 0) {
+            if (is_interacting) {
+                // reset grammar state if we're restarting generation
+                if (!params.grammar.empty()) {
+                    parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+                    if (grammar != NULL) {
+                        llama_grammar_free(grammar);
+                    }
+                    grammar = llama_grammar_init(
+                        parsed_grammar.out_grammar.data(), parsed_grammar.symbol_ids.at("root"));
+                }
+            }
             is_interacting = false;
         }
     }
@@ -638,5 +669,9 @@ int main(int argc, char ** argv) {
     llama_print_timings(ctx);
     llama_free(ctx);

+    if (grammar != NULL) {
+        llama_grammar_free(grammar);
+    }
+
     return 0;
 }
llama.cpp (240 changes)

@@ -1821,6 +1821,168 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
     return output;
 }

+//
+// grammar - internal
+//
+
+struct llama_grammar {
+    const std::vector<const uint16_t *> rules;
+    std::vector<std::vector<const uint16_t *>> stacks;
+};
+
+// transforms a grammar pushdown stack into N possible stacks, all terminating
+// at a character range (terminal element)
+static void llama_grammar_advance_stack(
+        const std::vector<const uint16_t *> & rules,
+        const std::vector<const uint16_t *> & stack,
+        std::vector<std::vector<const uint16_t *>> & new_stacks) {
+
+    if (stack.empty()) {
+        new_stacks.push_back(stack);
+        return;
+    }
+
+    const uint16_t * pos = stack.back();
+
+    if (*pos == 1) {
+        // rule reference, apply rule to stack
+        const uint16_t * subpos = rules[pos[1]] + 1;
+        while (*subpos) {
+            // init new stack without the top (pos)
+            std::vector<const uint16_t *> new_stack(stack.begin(), stack.end() - 1);
+            if (pos[2]) {
+                // if the rule ref is followed by another element, add that to stack
+                new_stack.push_back(pos + 2);
+            }
+            if (subpos[1]) {
+                // if the referenced rule is nonempty, add that to the stack
+                new_stack.push_back(subpos + 1);
+            }
+            llama_grammar_advance_stack(rules, new_stack, new_stacks);
+            subpos += 1 + *subpos;
+        }
+    } else {
+        // rule element size > 1 -> character reference
+        LLAMA_ASSERT(*pos);
+        new_stacks.push_back(stack);
+    }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `llama_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const uint16_t *>> llama_grammar_accept(
+        const std::vector<const uint16_t *> & rules,
+        const std::vector<std::vector<const uint16_t *>> & stacks,
+        const uint16_t chr) {

+    std::vector<std::vector<const uint16_t *>> new_stacks;
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        const uint16_t * pos = stack.back();
+        const uint16_t num_chars = *pos;
+        LLAMA_ASSERT(num_chars > 1);
+
+        pos++; // skip num chars indicator
+        bool found = false;
+        // loop over the inclusive char pairs to find a match on the given char
+        for (int i = 0; i < num_chars; i += 2) {
+            if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) {
+                found = true;
+                break;
+            }
+        }
+        if (!found) {
+            continue;
+        }
+
+        // advance past char range, updating top of stack to next element, if any
+        pos += num_chars;
+        std::vector<const uint16_t *> new_stack(stack.begin(), stack.end() - 1);
+        if (*pos) {
+            new_stack.push_back(pos);
+        }
+        llama_grammar_advance_stack(rules, new_stack, new_stacks);
+    }
+
+    return new_stacks;
+}
+
+// returns `true` if one of the pushdown stacks can accept the given char.
+static bool llama_grammar_peek(
+        const std::vector<std::vector<const uint16_t *>> & stacks,
+        const uint16_t chr) {
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            if (!chr) {
+                return true;
+            }
+        } else {
+            const uint16_t * pos = stack.back();
+            const uint16_t num_chars = *pos;
+            LLAMA_ASSERT(num_chars > 1);
+
+            pos++;
+            for (int i = 0; i < num_chars; i += 2) {
+                if (pos[i] <= chr && (i + 1 == num_chars || chr <= pos[i + 1])) {
+                    return true;
+                }
+            }
+        }
+    }
+    return false;
+}
+
+
+//
+// grammar - external
+//
+
+struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id) {
+    const uint16_t * pos = src;
+    std::vector<const uint16_t *> rules;
+
+    // build `rules` as list of pointers to rules embedded in binary grammar `src`
+    while (*pos != 0xffff) {
+        uint16_t rule_id = *pos;
+        if (rules.size() <= rule_id) {
+            rules.resize(rule_id + 1);
+        }
+        rules[rule_id] = pos;
+        // skip rule id
+        pos++;
+        // skip rule alternates
+        while (*pos) {
+            pos += 1 + *pos;
+        }
+        // skip 0 denoting end of rule
+        pos++;
+    }
+
+    // TODO: handle if start rule has alternates
+    const uint16_t * start_rule = rules[start_rule_id];
+
+    // rule starts with rule id and 1st alternate's size; skip that so initial
+    // stack starts at 1st element in 1st alternate
+    LLAMA_ASSERT(start_rule[0] == start_rule_id && start_rule[1]);
+    const std::vector<const uint16_t *> stack = { start_rule + 2 };
+
+    std::vector<std::vector<const uint16_t *>> stacks;
+    llama_grammar_advance_stack(rules, stack, stacks);
+
+    return new llama_grammar{ rules, stacks };
+}
+
+void llama_grammar_free(struct llama_grammar * grammar) {
+    delete grammar;
+}
+
 //
 // sampling
 //
@@ -2097,6 +2259,30 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
     }
 }

+void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
+    assert(ctx);
+    const int64_t t_start_sample_us = ggml_time_us();
+    const llama_token eos = llama_token_eos();
+    // since many llama tokens are prefixed with a single space, special case a lookahead on ' '
+    const auto stacks_after_space = llama_grammar_accept(grammar->rules, grammar->stacks, ' ');
+
+    for (size_t i = 0; i < candidates->size; ++i) {
+        const llama_token id = candidates->data[i].id;
+        const char * str = llama_token_to_str(ctx, id);
+
+        // prune tokens based on first char only - in `llama_grammar_accept_token` we will find the
+        // full matching prefix of the selected token
+        const bool valid = str[0] == ' '
+            ? llama_grammar_peek(stacks_after_space, str[1])
+            : llama_grammar_peek(grammar->stacks, id == eos ? 0 : str[0]);
+
+        if (!valid) {
+            candidates->data[i].logit = -INFINITY;
+        }
+    }
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+}
+
 llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
     assert(ctx);
@@ -2223,6 +2409,60 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     return result;
 }

+
+llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    if (token == llama_token_eos()) {
+        for (const auto & stack : grammar->stacks) {
+            if (stack.empty()) {
+                return token;
+            }
+            LLAMA_ASSERT(false);
+        }
+    }
+
+    const char * str = llama_token_to_str(ctx, token);
+    const char * suffix = str;
+
+    // Find prefix of selected token that matches grammar, expecting at least 1 char
+    auto new_stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *suffix);
+    LLAMA_ASSERT(!new_stacks.empty());
+    if (*suffix) {
+        ++suffix;
+        for ( ; *suffix; ++suffix) {
+            new_stacks = llama_grammar_accept(grammar->rules, new_stacks, *suffix);
+            if (new_stacks.empty()) {
+                break;
+            }
+        }
+    }
+
+    // if full token is matched, accept new stacks
+    if (!(*suffix)) {
+        grammar->stacks = new_stacks;
+        return token;
+    }
+
+    // otherwise, tokenize the string prefix that did match
+    llama_token tokens[32]; // TODO - determine actual max token size
+    const std::string prefix_str(str, suffix - str);
+    int n_tokens = llama_tokenize(ctx, prefix_str.c_str(), tokens, 32, false);
+    if (n_tokens < 1) {
+        return token; // REVIEW
+    }
+
+    // accept the first token of the matching prefix into the grammar
+    llama_token first_prefix_token = tokens[0];
+    const char * first_prefix_str = llama_token_to_str(ctx, first_prefix_token);
+    for ( ; *first_prefix_str; ++first_prefix_str) {
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *first_prefix_str);
+        LLAMA_ASSERT(!grammar->stacks.empty());
+    }
+
+    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+    return first_prefix_token;
+}
+
 //
 // quantization
 //
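Aside: a condensed sketch of my own (not code from the diff) of how these three entry points are meant to compose inside a sampling loop, mirroring the examples/main changes shown earlier; ctx, candidates_p, and grammar are assumed to be set up as in main.cpp, and the surrounding generation loop is omitted.

    // grammar is non-NULL only when the user supplied --grammar
    if (grammar != NULL) {
        // mask out (logit = -INFINITY) candidates whose first character cannot
        // be accepted by any of the grammar's current pushdown stacks
        llama_sample_grammar(ctx, &candidates_p, grammar);
    }

    llama_token id = llama_sample_token(ctx, &candidates_p);

    if (grammar != NULL) {
        // advance the grammar by the sampled token; if only a prefix of the
        // token's text matches, a token for that matching prefix is returned instead
        id = llama_grammar_accept_token(ctx, grammar, id);
    }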
llama.h (32 changes)

@@ -55,6 +55,8 @@ extern "C" {

     struct llama_context;

+    struct llama_grammar;
+
     typedef int llama_token;

     typedef struct llama_token_data {
@@ -233,6 +235,30 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos();
     LLAMA_API llama_token llama_token_nl();
+
+    // Grammar
+    //
+    // Accepts a binary encoding of a context-free grammar. The returned struct can be used to
+    // constrain sampled tokens (see below).
+    //
+    // The binary format represents one or more production rules, each with one or more alternate
+    // defininitions:
+    //
+    //     (<rule_id: u16> (<alt_size: u16> <alt_size * u16>)+ 0000)+ FFFF
+    //
+    // rule_ids should be assigned sequentially from zero but may appear out of order. Each
+    // rule alternate is a sequence of zero or more symbols, each prefixed with size:
+    //
+    //     (<sym_size: u16> <sym_size * u16>)* 0000
+    //
+    // A symbol of size 1 is interpreted as a rule reference (whose value is the single following
+    // u16). Symbols sized greater than 1 are interpreted as inclusive pairs of 16-bit chars to
+    // match. Note that symbol sizes greater than 7FFF are reserved for future use.
+    //
+    // The provided `src` must be kept valid for the lifetime of the `llama_grammar`.
+    //
+    LLAMA_API struct llama_grammar * llama_grammar_init(const uint16_t * src, uint16_t start_rule_id);
+    LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);

     // Sampling functions

     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
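A worked example of my own (an aside, not part of llama.h): hand-encoding the single rule `root ::= [0-9]` under the format described above, following the layout that examples/grammar-parser.cpp emits, and passing it to llama_grammar_init. Linking against libllama built from this commit is assumed.

    #include <cstdint>
    #include "llama.h"

    int main() {
        // root ::= [0-9]
        static const uint16_t digits[] = {
            0,            // rule_id 0 ("root")
            4,            // only alternate: 4 u16 words follow
            2, '0', '9',  //   one symbol of size 2 -> inclusive char pair '0'..'9'
            0,            //   0000 terminating the alternate's symbol list
            0,            // 0000 terminating rule 0
            0xFFFF,       // FFFF terminating the grammar
        };

        struct llama_grammar * grammar = llama_grammar_init(digits, /*start_rule_id=*/0);
        llama_grammar_free(grammar);
        return 0;
    }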
@@ -257,6 +283,9 @@ extern "C" {
     LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
     LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
+
+    /// @details Apply constraints from grammar
+    LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);

     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -278,6 +307,9 @@ extern "C" {
     /// @details Randomly selects a token from the candidates based on their probabilities.
     LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
+
+    /// @details Accepts the sampled token into the grammar, possibly transforming to a new token
+    LLAMA_API llama_token llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);

     // Performance information
     LLAMA_API void llama_print_timings(struct llama_context * ctx);
     LLAMA_API void llama_reset_timings(struct llama_context * ctx);