speculative : manage context in common_speculative

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-21 21:27:14 +02:00
parent fe043ff1ff
commit 0f878a657c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 188 additions and 144 deletions

View file

@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
[](const unsigned char c) { return !std::isprint(c); }), [](const unsigned char c) { return !std::isprint(c); }),
detokenized.end()); detokenized.end());
buf << "\n" << std::to_string(i) buf << "\n" << std::to_string(i)
<< ":token '" << detokenized << "'" << ", token '" << detokenized << "'"
<< ":pos " << std::to_string(batch.pos[i]) << ", pos " << std::to_string(batch.pos[i])
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i]) << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ":seq_id " << std::to_string(batch.seq_id[i][0]) << ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ":logits " << std::to_string(batch.logits[i]); << ", logits " << std::to_string(batch.logits[i]);
} }
buf << " ]"; buf << " ]";
@ -1490,6 +1490,66 @@ void common_batch_add(
batch.n_tokens++; batch.n_tokens++;
} }
//
// Token utils
//
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
// //
// Vocab utils // Vocab utils
// //

View file

@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
struct llama_lora_adapter * adapter; struct llama_lora_adapter * adapter;
}; };
using llama_tokens = std::vector<llama_token>;
// build info // build info
extern int LLAMA_BUILD_NUMBER; extern int LLAMA_BUILD_NUMBER;
extern char const * LLAMA_COMMIT; extern char const * LLAMA_COMMIT;
@ -461,7 +463,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
// clear LoRA adapters from context, then apply new list of adapters // clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters); void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
//
// Batch utils // Batch utils
//
void common_batch_clear(struct llama_batch & batch); void common_batch_clear(struct llama_batch & batch);
@ -472,6 +476,16 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids, const std::vector<llama_seq_id> & seq_ids,
bool logits); bool logits);
//
// Token utils
//
// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
// longet common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
// //
// Vocab utils // Vocab utils
// //

View file

@ -342,6 +342,28 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
return result; return result;
} }
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) {
std::vector<int> idxs;
idxs.reserve(batch.n_tokens);
std::vector<llama_token> draft;
draft.reserve(batch.n_tokens);
for (int i = 0; i < batch.n_tokens; i++) {
if (batch.logits[i] == 0) {
continue;
}
if (idxs.size() > 0) {
GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]);
draft.push_back(batch.token[i]);
}
idxs.push_back(i);
}
return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);
}
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain); return llama_sampler_get_seed(gsmpl->chain);
} }

View file

@ -73,6 +73,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
// //
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false); std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers // helpers

View file

@ -11,9 +11,7 @@ struct common_speculative {
struct common_sampler * smpl; struct common_sampler * smpl;
std::vector<int> i_batch_tgt; llama_tokens prompt_last;
std::vector<llama_token> tokens;
}; };
struct common_speculative * common_speculative_init(struct common_speculative_params params) { struct common_speculative * common_speculative_init(struct common_speculative_params params) {
@ -21,12 +19,10 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
/* .params = */ params, /* .params = */ params,
/* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1), /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
/* .smpl = */ nullptr, /* .smpl = */ nullptr,
/* .i_batch_tgt = */ {},
/* .tokens = */ {},
}; };
// TODO: optimize or pass from outside? // TODO: optimize or pass from outside?
#if 0 #if 1
{ {
common_sampler_params sparams; common_sampler_params sparams;
sparams.no_perf = false; sparams.no_perf = false;
@ -70,30 +66,79 @@ void common_speculative_free(struct common_speculative * spec) {
delete spec; delete spec;
} }
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
llama_kv_cache_clear(spec->params.ctx_dft);
// TODO: error handling
llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
}
void common_speculative_add_draft( void common_speculative_add_draft(
struct common_speculative * spec, struct common_speculative * spec,
struct llama_batch & batch_tgt, struct llama_batch & batch_tgt,
const llama_tokens & prompt,
llama_token id_last, llama_token id_last,
int n_past) { llama_token n_past_tgt) {
spec->tokens.clear();
spec->i_batch_tgt.clear(); int reuse_i = 0;
spec->i_batch_tgt.push_back(0); int reuse_n = 0;
common_sampler_reset(spec->smpl); const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
const int i_start = std::max<int>(0, (int) prompt.size() - n_ctx);
for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt.size() &&
i + cur < (int) spec->prompt_last.size() &&
prompt[i_start + cur] == spec->prompt_last[i + cur]) {
cur++;
}
if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
reuse_i = i;
reuse_n = cur;
}
}
LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);
if (reuse_n == 0) {
llama_kv_cache_clear(spec->params.ctx_dft);
spec->prompt_last.clear();
} else {
llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
}
common_batch_clear(spec->batch_dft);
for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
spec->prompt_last.push_back(prompt[i]);
}
const llama_pos n_past = prompt.size() - i_start;
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
if (spec->batch_dft.n_tokens > 0) {
LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
llama_decode(spec->params.ctx_dft, spec->batch_dft);
}
common_batch_clear(spec->batch_dft); common_batch_clear(spec->batch_dft);
common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true); common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
spec->prompt_last.push_back(id_last);
LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
llama_decode(spec->params.ctx_dft, spec->batch_dft); llama_decode(spec->params.ctx_dft, spec->batch_dft);
common_sampler_reset(spec->smpl);
// sample n_draft tokens from the draft model // sample n_draft tokens from the draft model
for (int i = 0; i < spec->params.n_draft; ++i) { for (int i = 0; i < spec->params.n_draft; ++i) {
common_batch_clear(spec->batch_dft); common_batch_clear(spec->batch_dft);
@ -111,18 +156,13 @@ void common_speculative_add_draft(
const llama_token id = cur_p->data[0].id; const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens // only collect very high-confidence draft tokens
if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) { if (cur_p->data[0].p < spec->params.p_min) {
break; break;
} }
common_sampler_accept(spec->smpl, id, true); common_sampler_accept(spec->smpl, id, true);
spec->tokens.push_back(id); common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
// add unique drafted tokens to the target batch
spec->i_batch_tgt.push_back(batch_tgt.n_tokens);
common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);
if (batch_tgt.n_tokens > spec->params.n_draft) { if (batch_tgt.n_tokens > spec->params.n_draft) {
break; break;
@ -132,23 +172,13 @@ void common_speculative_add_draft(
// evaluate the drafted tokens on the draft model // evaluate the drafted tokens on the draft model
llama_decode(spec->params.ctx_dft, spec->batch_dft); llama_decode(spec->params.ctx_dft, spec->batch_dft);
spec->prompt_last.push_back(id);
} }
// don't waste time on small batches // don't waste time on small batches
// TODO: do not evaluate the draft model for that many rounds // TODO: do not evaluate the draft model for that many rounds
if (batch_tgt.n_tokens < spec->params.n_min) { if (batch_tgt.n_tokens < spec->params.n_min) {
batch_tgt.n_tokens = 1; batch_tgt.n_tokens = 1;
spec->tokens.resize(0);
spec->i_batch_tgt.resize(1);
} }
// print current draft sequences
LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
}
std::vector<llama_token> common_speculative_sample(
struct common_speculative * spec,
struct common_sampler * smpl,
struct llama_context * ctx_tgt) {
return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
} }

View file

@ -1,14 +1,16 @@
#pragma once #pragma once
#include "llama.h" #include "llama.h"
#include "common.h"
#include <vector>
struct common_speculative; struct common_speculative;
struct common_speculative_params { struct common_speculative_params {
int n_draft = 16; int n_draft = 16;
int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user? int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user?
int n_reuse = 256;
float p_min = 0.9f;
struct llama_model * model_dft = nullptr; struct llama_model * model_dft = nullptr;
@ -19,28 +21,11 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
void common_speculative_free(struct common_speculative * spec); void common_speculative_free(struct common_speculative * spec);
// TODO: remove
void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);
// sample up to n_draft tokens and add them to the batch using the draft model // sample up to n_draft tokens and add them to the batch using the draft model
// //
// TODO: change to:
//
// void common_speculative_add_draft(
// struct common_speculative * spec,
// struct llama_batch & batch_tgt,
// llama_token * tokens,
// int32_t n_tokens);
//
// and update the internal logic to compute only the new tokens
//
void common_speculative_add_draft( void common_speculative_add_draft(
struct common_speculative * spec, struct common_speculative * spec,
struct llama_batch & batch_tgt, struct llama_batch & batch_tgt,
const llama_tokens & prompt,
llama_token id_last, llama_token id_last,
int n_past); llama_token n_past_tgt);
std::vector<llama_token> common_speculative_sample(
struct common_speculative * spec,
struct common_sampler * smpl,
struct llama_context * ctx_tgt);

View file

@ -743,7 +743,7 @@ struct server_context {
} }
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens); int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
// fraction of the common subsequence length compared to the current slot's prompt length // fraction of the common subsequence length compared to the current slot's prompt length
float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size()); float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
@ -1960,7 +1960,7 @@ struct server_context {
if (slot.params.cache_prompt) { if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt // reuse any previously computed tokens that are common with the new prompt
slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens); slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
// reuse chunks from the cached prompt by shifting their KV cache in the new position // reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params.n_cache_reuse > 0) { if (params.n_cache_reuse > 0) {

View file

@ -24,7 +24,6 @@
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
using llama_tokens = std::vector<llama_token>;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
@ -439,62 +438,6 @@ static std::string gen_chatcmplid() {
// other common utils // other common utils
// //
static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
static bool ends_with(const std::string & str, const std::string & suffix) { static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
} }

View file

@ -14,14 +14,6 @@
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
struct seq_draft {
std::vector<int> i_batch_tgt;
std::vector<llama_token> tokens;
struct common_sampler * smpl = nullptr;
};
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
common_params params; common_params params;
@ -165,27 +157,21 @@ int main(int argc, char ** argv) {
// note: keep the last token separate! // note: keep the last token separate!
llama_token id_last = inp.back(); llama_token id_last = inp.back();
auto prompt_dft = std::vector<llama_token>(inp.begin(), inp.end() - 1);
int n_past = inp.size() - 1; int n_past = inp.size() - 1;
// init the speculator // init the speculator
struct common_speculative_params params_spec; struct common_speculative_params params_spec;
params_spec.n_draft = n_draft; params_spec.n_draft = n_draft;
params_spec.n_min = 5; params_spec.n_min = 5;
params_spec.n_reuse = 256;
params_spec.p_min = 0.9f;
params_spec.model_dft = model_dft; params_spec.model_dft = model_dft;
params_spec.ctx_dft = ctx_dft; params_spec.ctx_dft = ctx_dft;
struct common_speculative * spec = common_speculative_init(params_spec); struct common_speculative * spec = common_speculative_init(params_spec);
// feed the prompt to the speculator
//
// this has to be kept synchronized with the target context
//
// TODO: simplify this by moving the context management logic in the common_speculative instance
// for example, the common_speculative_add_draft can pass the entire context (or part of it) and the
// speculator will automatically compute any new tokens that are not present in its context
//
common_speculative_set_prompt(spec, inp.data(), n_input - 1);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
const auto t_enc_end = ggml_time_us(); const auto t_enc_end = ggml_time_us();
@ -204,7 +190,7 @@ int main(int argc, char ** argv) {
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
// from a cache or lookup tables. // from a cache or lookup tables.
// //
common_speculative_add_draft(spec, batch_tgt, id_last, n_past); common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{ {
@ -220,7 +206,7 @@ int main(int argc, char ** argv) {
// available logits from the batch and sample the next token until we run out of logits or the sampler // available logits from the batch and sample the next token until we run out of logits or the sampler
// disagrees with the draft // disagrees with the draft
// //
const auto ids = common_speculative_sample(spec, smpl, ctx_tgt); const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
@ -266,9 +252,11 @@ int main(int argc, char ** argv) {
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1); llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);
} }
prompt_dft.push_back(id_last);
prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1);
// remember the last accepted token for the next iteration // remember the last accepted token for the next iteration
id_last = id; id_last = id;
} }