speculative : add tree-based sampling example (#3624)

* sampling : one sequence per sampling context

ggml-ci

* speculative : add tree-based sampling support

ggml-ci

* speculative : reuse the n_parallel CLI param

* speculative : refactor sampling

* examples : fix build after sampling refactoring

ggml-ci

* batched : fix n_seq_id

* sampling : fix malloc

ggml-ci

* swift : fix build

ggml-ci

* swift : try to fix build

ggml-ci

* prompts : add assistant.txt

* common : add llama_batch_add() and llama_batch_clear() helpers

* speculative : minor refactor

ggml-ci

* minor : comments + rename

ggml-ci

* speculative : fix off-by-one for n_drafted

* speculative : fix the n_drafted fix + p constants
This commit is contained in:
Georgi Gerganov 2023-10-18 16:21:57 +03:00 committed by GitHub
parent c67fe68e41
commit 0e89203b51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 737 additions and 578 deletions

View file

@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
return cparams;
}
void llama_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
}
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos,
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;
batch.n_tokens++;
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);

View file

@ -70,6 +70,7 @@ struct gpt_params {
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
// TODO: avoid tuple, use struct
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter
@ -124,10 +125,23 @@ void process_escapes(std::string& input);
// Model utils
//
// TODO: avoid tuplue, use struct
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
// Batch utils
void llama_batch_clear(struct llama_batch & batch);
void llama_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Vocab utils
//

View file

@ -579,38 +579,75 @@ inline std::string log_var_to_string_impl(const std::vector<int> & var)
return buf.str();
}
#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens) \
[&tokens, &ctx]() \
{ \
std::stringstream buf; \
buf << "[ "; \
\
bool first = true; \
for (const auto &token : tokens) \
{ \
if (!first) \
buf << ", "; \
else \
first = false; \
\
auto detokenized = llama_token_to_piece(ctx, token); \
\
detokenized.erase( \
std::remove_if( \
detokenized.begin(), \
detokenized.end(), \
[](const unsigned char c) { return !std::isprint(c); }), \
detokenized.end()); \
\
buf \
<< "'" << detokenized << "'" \
<< ":" << std::to_string(token); \
} \
buf << " ]"; \
\
return buf.str(); \
}() \
.c_str()
template <typename C, typename T>
inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
{
std::stringstream buf;
buf << "[ ";
bool first = true;
for (const auto &token : tokens)
{
if (!first) {
buf << ", ";
} else {
first = false;
}
auto detokenized = llama_token_to_piece(ctx, token);
detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());
buf
<< "'" << detokenized << "'"
<< ":" << std::to_string(token);
}
buf << " ]";
return buf.str();
}
template <typename C, typename B>
inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
{
std::stringstream buf;
buf << "[ ";
bool first = true;
for (int i = 0; i < batch.n_tokens; ++i)
{
if (!first) {
buf << ", ";
} else {
first = false;
}
auto detokenized = llama_token_to_piece(ctx, batch.token[i]);
detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());
buf
<< "\n" << std::to_string(i)
<< ":token '" << detokenized << "'"
<< ":pos " << std::to_string(batch.pos[i])
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
<< ":logits " << std::to_string(batch.logits[i]);
}
buf << " ]";
return buf.str();
}
#ifdef LOG_DISABLE_LOGS

View file

@ -1,64 +1,81 @@
#include "sampling.h"
llama_sampling_context::~llama_sampling_context() {
for (auto & it : sequence_contexts) {
if (it.second.grammar != NULL) {
llama_grammar_free(it.second.grammar);
it.second.grammar = NULL;
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
result->params = params.sampling_params;
result->grammar = nullptr;
// if there is a grammar, parse it
if (!params.grammar.empty()) {
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (result->parsed_grammar.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
return nullptr;
}
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
result->grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
}
result->prev.resize(params.n_ctx);
return result;
}
llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar) {
llama_sampling_context result;
void llama_sampling_free(struct llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
}
result.params = params.sampling_params;
result.grammar = grammar;
return result;
delete ctx;
}
// Note: Creates the context if it doesn't exist, so this always return something.
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it != ctx_sampling.sequence_contexts.end()) {
return it->second;
void llama_sampling_reset(llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
}
llama_sampler_sequence_context new_ctx = {
2.0f * ctx_sampling.params.mirostat_tau,
ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
};
return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
if (!ctx->parsed_grammar.rules.empty()) {
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
ctx->grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
}
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
}
bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it == ctx_sampling.sequence_contexts.end()) return false;
if (it->second.grammar != NULL) {
llama_grammar_free(it->second.grammar);
it->second.grammar = NULL;
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->grammar) {
llama_grammar_free(dst->grammar);
dst->grammar = nullptr;
}
ctx_sampling.sequence_contexts.erase(it);
return true;
if (src->grammar) {
dst->grammar = llama_grammar_copy(src->grammar);
}
dst->prev = src->prev;
}
llama_token llama_sampling_sample(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx,
llama_seq_id seq) {
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const int n_ctx = llama_n_ctx(ctx_main);
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const llama_sampling_params & params = ctx_sampling->params;
const llama_sampling_params & params = ctx_sampling.params;
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
@ -73,41 +90,45 @@ llama_token llama_sampling_sample(
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
llama_token id = 0;
float * logits = llama_get_logits_ith(ctx, idx);
float * logits = llama_get_logits_ith(ctx_main, idx);
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
candidates.clear();
cur.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
if (ctx_guidance) {
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
if (ctx_cfg) {
llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_cfg, params.cfg_scale);
}
// apply penalties
if (!last_tokens.empty()) {
const float nl_logit = logits[llama_token_nl(ctx)];
const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
if (!prev.empty()) {
const float nl_logit = logits[llama_token_nl(ctx_main)];
const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
llama_sample_repetition_penalty(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
cur_p.data[idx].logit = nl_logit;
break;
}
@ -115,52 +136,58 @@ llama_token llama_sampling_sample(
}
}
llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
if (ctx_seq.grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
if (ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &cur_p);
id = llama_sample_token_greedy(ctx_main, &cur_p);
} else {
if (mirostat == 1) {
const int mirostat_m = 100;
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
} else if (mirostat == 2) {
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else {
// Temperature sampling
size_t min_keep = std::max(1, params.n_probs);
llama_sample_top_k (ctx, &cur_p, top_k, min_keep);
llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
llama_sample_temp(ctx, &cur_p, temp);
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
llama_sample_temp (ctx_main, &cur_p, temp);
{
const int n_top = 10;
LOG("top %d candidates:\n", n_top);
id = llama_sample_token(ctx_main, &cur_p);
for (int i = 0; i < n_top; i++) {
const llama_token id = cur_p.data[i].id;
(void)id; // To avoid a warning that id is unused when logging is disabled.
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
}
}
//{
// const int n_top = 10;
// LOG("top %d candidates:\n", n_top);
id = llama_sample_token(ctx, &cur_p);
// for (int i = 0; i < n_top; i++) {
// const llama_token id = cur_p.data[i].id;
// (void)id; // To avoid a warning that id is unused when logging is disabled.
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
// }
//}
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
}
}
if (ctx_seq.grammar != NULL) {
llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
}
return id;
}
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id) {
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(id);
if (ctx_sampling->grammar != NULL) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}
}

View file

@ -2,6 +2,8 @@
#include "llama.h"
#include "grammar-parser.h"
#include <string>
#include <vector>
#include <unordered_map>
@ -34,75 +36,64 @@ typedef struct llama_sampling_params {
} llama_sampling_params;
// per-sequence sampler context
typedef struct llama_sampler_sequence_context {
float mirostat_mu; // mirostat sampler state
llama_grammar * grammar;
} llama_sampler_sequence_context;
// general sampler context
typedef struct llama_sampling_context {
~llama_sampling_context();
// parameters that will be used for sampling and when creating
// new llama_sampler_sequence_context instances
// TODO: move to llama.h
struct llama_sampling_context {
// parameters that will be used for sampling
llama_sampling_params params;
// map of sequence ids to sampler contexts
std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
// mirostat sampler state
float mirostat_mu;
// when non-NULL, new instances of llama_sampler_sequence_context
// will get a copy of the grammar here
// note: only the pointer is stored here, it is not a copy of
// the grammar and shouldn't be freed
llama_grammar * grammar;
} llama_sampling_context;
// internal
grammar_parser::parse_state parsed_grammar;
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
};
#include "common.h"
// Create a new sampling context instance.
llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar = NULL);
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
// Fetches the sampler context for the specified sequence id (defaults to 0).
// If the context for that sequence id doesn't already exist, it will be created with
// default values based on the parameters in the ctx_sampling argument.
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq = 0);
void llama_sampling_free(struct llama_sampling_context * ctx);
// Reset the sampler context for the supplied sequence id (defaults to 0).
// This is necessary to reuse a sequence id or free memory used by sequences
// that are no longer required.
bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq = 0);
// Reset the sampler context
// - clear prev tokens
// - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx);
// Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
// llama_sampling_context_reset when a sequence ends
// llama_sampling_reset when a sequence ends
//
// required:
// - ctx: context to use for sampling
// - ctx_main: context to use for sampling
// - ctx_sampling: sampling-specific context
//
// optional:
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty
// - idx: sample from llama_get_logits_ith(ctx, idx)
// - seq: sequence id to associate sampler state with
// - ctx_cfg: context to use for classifier-free guidance
// - idx: sample from llama_get_logits_ith(ctx, idx)
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
llama_token llama_sampling_sample(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx = 0,
llama_seq_id seq = 0);
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id);