speculative : add tree-based sampling example (#3624)

* sampling : one sequence per sampling context

ggml-ci

* speculative : add tree-based sampling support

ggml-ci

* speculative : reuse the n_parallel CLI param

* speculative : refactor sampling

* examples : fix build after sampling refactoring

ggml-ci

* batched : fix n_seq_id

* sampling : fix malloc

ggml-ci

* swift : fix build

ggml-ci

* swift : try to fix build

ggml-ci

* prompts : add assistant.txt

* common : add llama_batch_add() and llama_batch_clear() helpers

* speculative : minor refactor

ggml-ci

* minor : comments + rename

ggml-ci

* speculative : fix off-by-one for n_drafted

* speculative : fix the n_drafted fix + p constants
This commit is contained in:
Georgi Gerganov 2023-10-18 16:21:57 +03:00 committed by GitHub
parent c67fe68e41
commit 0e89203b51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 737 additions and 578 deletions

View file

@ -257,12 +257,12 @@ int main(int argc, char ** argv) {
LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
// Tokenize negative prompt
@ -273,10 +273,10 @@ int main(int argc, char ** argv) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@ -294,8 +294,8 @@ int main(int argc, char ** argv) {
params.n_keep = (int)embd_inp.size();
}
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
// enable interactive mode if interactive start is specified
@ -388,9 +388,6 @@ int main(int argc, char ** argv) {
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
// TODO: replace with ring-buffer
std::vector<llama_token> last_tokens(n_ctx);
std::fill(last_tokens.begin(), last_tokens.end(), 0);
LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
@ -433,11 +430,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
const int n_vocab = llama_n_vocab(model);
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
while (n_remain != 0 || params.interactive) {
// predict
@ -484,7 +477,7 @@ int main(int argc, char ** argv) {
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
}
@ -512,7 +505,7 @@ int main(int argc, char ** argv) {
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
@ -535,7 +528,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch;
}
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
@ -554,12 +547,11 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
llama_sampling_accept(ctx_sampling, ctx, id);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
embd.push_back(id);
@ -575,8 +567,8 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(embd_inp[n_consumed]);
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@ -608,7 +600,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@ -675,7 +667,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of text token in interactive mode
else if (last_tokens.back() == llama_token_eos(ctx)) {
else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {
@ -727,7 +719,7 @@ int main(int argc, char ** argv) {
const size_t original_size = embd_inp.size();
const auto line_inp = ::llama_tokenize(ctx, buffer, false);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());