mamba : make the server and parallel examples work with whole sequences
A seq_id is dedicated to the system prompt in both cases. * llama : make llama_kv_cache_seq_rm return whether it succeeded or not
This commit is contained in:
parent
3dcf79824d
commit
34e2fca8eb
4 changed files with 70 additions and 33 deletions
|
@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
|
|||
// number of simultaneous "clients" to simulate
|
||||
const int32_t n_clients = params.n_parallel;
|
||||
|
||||
// dedicate one sequence to the system prompt
|
||||
params.n_parallel += 1;
|
||||
|
||||
// requests to simulate
|
||||
const int32_t n_seq = params.n_sequences;
|
||||
|
||||
|
@ -196,8 +199,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i < n_clients; ++i) {
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
|
||||
for (int32_t i = 1; i <= n_clients; ++i) {
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
@ -221,15 +224,17 @@ int main(int argc, char ** argv) {
|
|||
|
||||
client.i_batch = batch.n_tokens;
|
||||
|
||||
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
|
||||
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
|
||||
|
||||
client.n_decoded += 1;
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
// all sequences have ended - clear the entire KV cache
|
||||
for (int i = 0; i < n_clients; ++i) {
|
||||
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
|
||||
for (int i = 1; i <= n_clients; ++i) {
|
||||
llama_kv_cache_seq_rm(ctx, i, -1, -1);
|
||||
// but keep the system prompt
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
||||
}
|
||||
|
||||
LOG_TEE("%s: clearing the KV cache\n", __func__);
|
||||
|
@ -255,7 +260,7 @@ int main(int argc, char ** argv) {
|
|||
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
|
||||
|
||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
|
||||
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
|
@ -366,7 +371,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
|
||||
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
|
||||
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
|
|
|
@ -377,12 +377,16 @@ struct llama_server_context
|
|||
return false;
|
||||
}
|
||||
|
||||
if (params.n_ctx < 2048) { // request larger context for the image embedding
|
||||
if (params.n_ctx != 0 && params.n_ctx < 2048) { // request larger context for the image embedding
|
||||
params.n_ctx = 2048;
|
||||
}
|
||||
}
|
||||
|
||||
// dedicate one sequence to the system prompt
|
||||
params.n_parallel += 1;
|
||||
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||
params.n_parallel -= 1; // but be sneaky about it
|
||||
if (model == nullptr)
|
||||
{
|
||||
LOG_ERROR("unable to load model", {{"model", params.model}});
|
||||
|
@ -862,9 +866,9 @@ struct llama_server_context
|
|||
}
|
||||
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i < params.n_parallel; ++i)
|
||||
for (int32_t i = 1; i <= params.n_parallel; ++i)
|
||||
{
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1351,7 +1355,7 @@ struct llama_server_context
|
|||
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
|
||||
for (int i = 0; i < (int) append_tokens.size(); ++i)
|
||||
{
|
||||
llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
|
||||
llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id + 1 }, true);
|
||||
slot.n_past += 1;
|
||||
}
|
||||
}
|
||||
|
@ -1587,8 +1591,8 @@ struct llama_server_context
|
|||
{"n_system_tokens", system_tokens.size()},
|
||||
{"n_cache_tokens", slot.cache_tokens.size()}
|
||||
});
|
||||
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||
|
||||
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
|
||||
{
|
||||
|
@ -1640,7 +1644,7 @@ struct llama_server_context
|
|||
|
||||
// TODO: we always have to take into account the "system_tokens"
|
||||
// this is not great and needs to be improved somehow
|
||||
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
|
||||
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
|
||||
slot.n_past += 1;
|
||||
}
|
||||
|
||||
|
@ -1810,13 +1814,28 @@ struct llama_server_context
|
|||
}
|
||||
}
|
||||
|
||||
// keep only the common part
|
||||
int p0 = (int) system_tokens.size() + slot.n_past;
|
||||
LOG_INFO("kv cache rm [p0, end)", {
|
||||
{ "slot_id", slot.id },
|
||||
{ "task_id", slot.task_id },
|
||||
{ "p0", p0 }
|
||||
});
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
|
||||
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
|
||||
// could not partially delete (likely using a non-Transformer model)
|
||||
// TODO: logging
|
||||
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
|
||||
|
||||
// there is no common part left (except for the system prompt)
|
||||
// TODO: maybe find a way to refactor this to reuse the !cache_prompt case above
|
||||
slot.n_past = 0;
|
||||
slot.n_past_se = 0;
|
||||
slot.ga_i = 0;
|
||||
slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
|
||||
// TODO: is the system prompt ever in the sampling context?
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
}
|
||||
|
||||
LOG_VERBOSE("prompt ingested", {
|
||||
{"n_past", slot.n_past},
|
||||
|
@ -1845,7 +1864,7 @@ struct llama_server_context
|
|||
ga_i += ga_w/ga_n;
|
||||
}
|
||||
}
|
||||
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
|
||||
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
|
||||
slot_npast++;
|
||||
}
|
||||
|
||||
|
@ -1899,9 +1918,9 @@ struct llama_server_context
|
|||
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
|
||||
|
||||
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
|
||||
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
||||
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
|
||||
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
|
||||
|
||||
slot.n_past_se -= bd;
|
||||
|
||||
|
|
38
llama.cpp
38
llama.cpp
|
@ -2275,7 +2275,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
|
|||
cache.used = 0;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_seq_rm(
|
||||
static bool llama_kv_cache_seq_rm(
|
||||
struct llama_kv_cache & cache,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
|
@ -2285,11 +2285,23 @@ static void llama_kv_cache_seq_rm(
|
|||
if (p0 < 0) p0 = 0;
|
||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
// models like Mamba can't have a state partially erased
|
||||
if (cache.unlimited) {
|
||||
// can only remove whole sequences for models like Mamba
|
||||
GGML_ASSERT(p0 == 0);
|
||||
GGML_ASSERT((uint32_t)seq_id < cache.size);
|
||||
GGML_ASSERT(cache.cells[seq_id].pos < p1);
|
||||
if (seq_id >= (int64_t) cache.size) {
|
||||
// could be fatal
|
||||
return false;
|
||||
}
|
||||
if (0 <= seq_id) {
|
||||
// partial intersection is invalid
|
||||
if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// seq_id is negative, then the range should include everything or nothing
|
||||
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||
|
@ -2313,6 +2325,8 @@ static void llama_kv_cache_seq_rm(
|
|||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_seq_cp(
|
||||
|
@ -12514,13 +12528,11 @@ struct llama_context * llama_new_context_with_model(
|
|||
|
||||
// Mamba only needs a constant number of KV cache cells per sequence
|
||||
if (model->arch == LLM_ARCH_MAMBA) {
|
||||
// Mamba needs as many KV cells as there are sequences kept at any time
|
||||
// The extra cell allows dedicating a sequence id to the system prompt
|
||||
// TODO: find a better way to get the max number of parallel sequences
|
||||
kv_size = params.n_parallel + 1;
|
||||
// Mamba needs at least as many KV cells as there are sequences kept at any time
|
||||
kv_size = std::max((uint32_t) 1, params.n_parallel);
|
||||
// it's probably best to keep as much precision as possible for the states
|
||||
type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
|
||||
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_state
|
||||
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
|
||||
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
|
||||
}
|
||||
|
||||
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
|
||||
|
@ -13039,8 +13051,8 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
|
|||
llama_kv_cache_clear(ctx->kv_self);
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
|
||||
bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
|
|
2
llama.h
2
llama.h
|
@ -503,7 +503,7 @@ extern "C" {
|
|||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_kv_cache_seq_rm(
|
||||
LLAMA_API bool llama_kv_cache_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue