speculative : add tree-based sampling example (#3624)

* sampling : one sequence per sampling context

ggml-ci

* speculative : add tree-based sampling support

ggml-ci

* speculative : reuse the n_parallel CLI param

* speculative : refactor sampling

* examples : fix build after sampling refactoring

ggml-ci

* batched : fix n_seq_id

* sampling : fix malloc

ggml-ci

* swift : fix build

ggml-ci

* swift : try to fix build

ggml-ci

* prompts : add assistant.txt

* common : add llama_batch_add() and llama_batch_clear() helpers

* speculative : minor refactor

ggml-ci

* minor : comments + rename

ggml-ci

* speculative : fix off-by-one for n_drafted

* speculative : fix the n_drafted fix + p constants
This commit is contained in:
Georgi Gerganov 2023-10-18 16:21:57 +03:00 committed by GitHub
parent c67fe68e41
commit 0e89203b51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 737 additions and 578 deletions

View file

@ -1,7 +1,6 @@
#include "common.h"
#include "llama.h"
#include "build-info.h"
#include "grammar-parser.h"
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
@ -195,17 +194,13 @@ struct llama_server_context
json prompt;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
llama_model *model = nullptr;
llama_context *ctx = nullptr;
gpt_params params;
llama_sampling_context ctx_sampling;
llama_sampling_context *ctx_sampling;
int n_ctx;
grammar_parser::parse_state parsed_grammar;
llama_grammar *grammar = nullptr;
bool truncated = false;
bool stopped_eos = false;
bool stopped_word = false;
@ -252,11 +247,10 @@ struct llama_server_context
n_remain = 0;
n_past = 0;
if (grammar != nullptr) {
llama_grammar_free(grammar);
grammar = nullptr;
ctx_sampling = llama_sampling_context_init(params, NULL);
if (ctx_sampling != nullptr) {
llama_sampling_free(ctx_sampling);
}
ctx_sampling = llama_sampling_init(params);
}
bool loadModel(const gpt_params &params_)
@ -269,8 +263,6 @@ struct llama_server_context
return false;
}
n_ctx = llama_n_ctx(ctx);
last_n_tokens.resize(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
}
@ -321,27 +313,7 @@ struct llama_server_context
bool loadGrammar()
{
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
return false;
}
grammar_parser::print_grammar(stderr, parsed_grammar);
{
auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
ctx_sampling = llama_sampling_context_init(params, grammar);
ctx_sampling = llama_sampling_init(params);
return true;
}
@ -383,7 +355,7 @@ struct llama_server_context
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
@ -398,8 +370,8 @@ struct llama_server_context
else
{
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
}
// compare the evaluated prompt with the new prompt
@ -443,7 +415,7 @@ struct llama_server_context
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
@ -458,8 +430,8 @@ struct llama_server_context
else
{
const size_t ps = num_prompt_tokens;
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
}
// compare the evaluated prompt with the new prompt
@ -554,27 +526,24 @@ struct llama_server_context
{
// out of user input, sample next token
std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(model));
result.tok = llama_sampling_sample(ctx_sampling, ctx, NULL);
result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
const int32_t n_probs = params.sampling_params.n_probs;
if (params.sampling_params.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &candidates_p);
llama_sample_softmax(ctx, &cur_p);
}
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
{
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(result.tok);
llama_sampling_accept(ctx_sampling, ctx, result.tok);
if (tg) {
num_tokens_predicted++;
}
@ -1235,7 +1204,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
}
}
llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
llama.ctx_sampling = llama_sampling_init(llama.params);
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}
@ -1793,9 +1762,7 @@ int main(int argc, char **argv)
return 1;
}
if (llama.grammar != nullptr) {
llama_grammar_free(llama.grammar);
}
llama_sampling_free(llama.ctx_sampling);
llama_backend_free();
return 0;