speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
This commit is contained in:
parent
c67fe68e41
commit
0e89203b51
21 changed files with 737 additions and 578 deletions
|
@ -1,7 +1,6 @@
|
|||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "build-info.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#ifndef NDEBUG
|
||||
// crash the server in debug mode, otherwise send an http 500 error
|
||||
|
@ -195,17 +194,13 @@ struct llama_server_context
|
|||
|
||||
json prompt;
|
||||
std::vector<llama_token> embd;
|
||||
std::vector<llama_token> last_n_tokens;
|
||||
|
||||
llama_model *model = nullptr;
|
||||
llama_context *ctx = nullptr;
|
||||
gpt_params params;
|
||||
llama_sampling_context ctx_sampling;
|
||||
llama_sampling_context *ctx_sampling;
|
||||
int n_ctx;
|
||||
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
llama_grammar *grammar = nullptr;
|
||||
|
||||
bool truncated = false;
|
||||
bool stopped_eos = false;
|
||||
bool stopped_word = false;
|
||||
|
@ -252,11 +247,10 @@ struct llama_server_context
|
|||
n_remain = 0;
|
||||
n_past = 0;
|
||||
|
||||
if (grammar != nullptr) {
|
||||
llama_grammar_free(grammar);
|
||||
grammar = nullptr;
|
||||
ctx_sampling = llama_sampling_context_init(params, NULL);
|
||||
if (ctx_sampling != nullptr) {
|
||||
llama_sampling_free(ctx_sampling);
|
||||
}
|
||||
ctx_sampling = llama_sampling_init(params);
|
||||
}
|
||||
|
||||
bool loadModel(const gpt_params ¶ms_)
|
||||
|
@ -269,8 +263,6 @@ struct llama_server_context
|
|||
return false;
|
||||
}
|
||||
n_ctx = llama_n_ctx(ctx);
|
||||
last_n_tokens.resize(n_ctx);
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -321,27 +313,7 @@ struct llama_server_context
|
|||
|
||||
bool loadGrammar()
|
||||
{
|
||||
if (!params.grammar.empty()) {
|
||||
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
// will be empty (default) if there are parse errors
|
||||
if (parsed_grammar.rules.empty()) {
|
||||
LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
|
||||
return false;
|
||||
}
|
||||
grammar_parser::print_grammar(stderr, parsed_grammar);
|
||||
|
||||
{
|
||||
auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
|
||||
if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
|
||||
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||
grammar = llama_grammar_init(
|
||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
}
|
||||
ctx_sampling = llama_sampling_context_init(params, grammar);
|
||||
ctx_sampling = llama_sampling_init(params);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -383,7 +355,7 @@ struct llama_server_context
|
|||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
||||
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
||||
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
|
||||
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
|
||||
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
|
||||
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", params.n_ctx},
|
||||
|
@ -398,8 +370,8 @@ struct llama_server_context
|
|||
else
|
||||
{
|
||||
const size_t ps = num_prompt_tokens;
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
|
||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
|
||||
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
|
||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
|
||||
}
|
||||
|
||||
// compare the evaluated prompt with the new prompt
|
||||
|
@ -443,7 +415,7 @@ struct llama_server_context
|
|||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
||||
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
||||
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
|
||||
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
|
||||
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
|
||||
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", n_ctx},
|
||||
|
@ -458,8 +430,8 @@ struct llama_server_context
|
|||
else
|
||||
{
|
||||
const size_t ps = num_prompt_tokens;
|
||||
std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
|
||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
|
||||
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
|
||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
|
||||
}
|
||||
|
||||
// compare the evaluated prompt with the new prompt
|
||||
|
@ -554,27 +526,24 @@ struct llama_server_context
|
|||
|
||||
{
|
||||
// out of user input, sample next token
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(llama_n_vocab(model));
|
||||
result.tok = llama_sampling_sample(ctx_sampling, ctx, NULL);
|
||||
|
||||
result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
|
||||
|
||||
const int32_t n_probs = params.sampling_params.n_probs;
|
||||
if (params.sampling_params.temp <= 0 && n_probs > 0)
|
||||
{
|
||||
// For llama_sample_token_greedy we need to sort candidates
|
||||
llama_sample_softmax(ctx, &candidates_p);
|
||||
llama_sample_softmax(ctx, &cur_p);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
|
||||
for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
|
||||
{
|
||||
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
|
||||
result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
|
||||
}
|
||||
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(result.tok);
|
||||
llama_sampling_accept(ctx_sampling, ctx, result.tok);
|
||||
|
||||
if (tg) {
|
||||
num_tokens_predicted++;
|
||||
}
|
||||
|
@ -1235,7 +1204,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
|
|||
}
|
||||
}
|
||||
|
||||
llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
|
||||
llama.ctx_sampling = llama_sampling_init(llama.params);
|
||||
|
||||
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
|
||||
}
|
||||
|
@ -1793,9 +1762,7 @@ int main(int argc, char **argv)
|
|||
return 1;
|
||||
}
|
||||
|
||||
if (llama.grammar != nullptr) {
|
||||
llama_grammar_free(llama.grammar);
|
||||
}
|
||||
llama_sampling_free(llama.ctx_sampling);
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue