From 3dcf79824d33806b3462aa4b55ebef5477aa1a81 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sun, 25 Feb 2024 09:51:49 -0500
Subject: [PATCH] mamba : support llama_kv_cache_seq_cp copy chains

* mamba : support shifting and dividing the kv cache pos
---
 llama.cpp | 69 +++++++++++++++++++++++++++----------------------------
 1 file changed, 34 insertions(+), 35 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index ad0322674..7281972b3 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2184,18 +2184,17 @@ static bool llama_kv_cache_find_slot(
                 }
                 // Assuming the tokens are in-order
                 if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
-                    // What should happen when the pos backtracks?
+                    // What should happen when the pos backtracks or skips a value?
                     // Clearing the state mid-batch would require special-casing which isn't done.
-                    LLAMA_LOG_ERROR("%s: non-consecutive token position %d after %d for sequence %d\n",
+                    LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
                         __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
-                    return false;
                 }
                 cache.cells[seq_id].pos = batch.pos[i];
-                // NOTE: seq_ids are not inserted here, because they are handled when the graph is built.
+                // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
             } else {
                 // too big seq_id
                 // TODO: would it be possible to resize the KV cache size instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d\n", __func__, seq_id, cache.size);
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
                 return false;
             }
         }
@@ -2326,24 +2325,26 @@ static void llama_kv_cache_seq_cp(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     if (cache.unlimited) {
-        if ((uint32_t)seq_id_dst < cache.size && (uint32_t)seq_id_src < cache.size) {
-            // intent to "copy from" (does not support copy chains)
+        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
+            seq_id_src = cache.cells[seq_id_src].delta;
+            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
+            // intent to "copy from"
+            // supports copy chains thanks to taking the source of the source
             cache.cells[seq_id_dst].delta = seq_id_src;
-            // NOTE: a sequence can't have multiple sources, but can have multiple destinations.
-            // For compatibility with the other KV cache API functions,
-            // the seq_id(s) of a cell suggests an intent to "copy to" those id(s),
-            // so that when a sequence is copied, it can initially be found from the source cell.
-            cache.cells[seq_id_src].seq_id.insert(seq_id_dst);
-            // prevent the destination from getting cleared
-            cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+
+            // prevent the destination from getting cleared if the source is not empty
+            if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
+                cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            }
             // repurposed as a "need copy" flag
             // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = seq_id_src != seq_id_dst;
-            // NOTE: this is not correct for sequence swaps (which aren't a thing in the KV cache API yet)
+            cache.has_shift = true;
+            cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
         return;
     }
+    // otherwise, this is the KV cache of a Transformer-like model
 
     cache.head = 0;
 
@@ -2385,7 +2386,14 @@ static void llama_kv_cache_seq_add(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be shifted
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos += delta;
+            }
+        }
+        return;
     }
 
     for (uint32_t i = 0; i < cache.size; ++i) {
@@ -2422,7 +2430,14 @@ static void llama_kv_cache_seq_div(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be changed
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos /= d;
+            }
+        }
+        return;
     }
 
     for (uint32_t i = 0; i < cache.size; ++i) {
@@ -8435,7 +8450,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (kv_self.unlimited) {
-        const int64_t kv_size = kv_self.size;
         const int64_t n_kv = kv_self.n;
 
         {
@@ -8451,7 +8465,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 data[i] = (float) has_self_seq;
 
                 // ensure current sequences will be kept
-                if (!has_self_seq) {
+                if (!has_self_seq && kv_cell.pos >= 0) {
                     kv_cell.seq_id.insert(seq_id);
                 }
             }
@@ -8480,21 +8494,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 }
             }
         }
-        // remove extraneous seq_ids when state copies are made
-        {
-            for (int i = 0; i < kv_size; ++i) {
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
-                uint32_t n_seqs = kv_cell.seq_id.size();
-                bool has_self_seq = kv_cell.has_seq_id(i);
-
-                if (has_self_seq && n_seqs > 1) {
-                    kv_cell.seq_id.clear();
-                    kv_cell.seq_id.insert(i);
-                } else if (!has_self_seq && n_seqs > 0) {
-                    kv_cell.seq_id.clear();
-                }
-            }
-        }
     }
 }
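Usage sketch (assumptions: the public llama.h wrappers llama_kv_cache_seq_cp,
llama_kv_cache_seq_add and llama_kv_cache_seq_div forward to the internal
functions patched above, and ctx already holds a loaded Mamba-like model whose
KV cache is marked "unlimited"; the helper function name is hypothetical):

    #include "llama.h"

    // seq 1 is first copied from seq 0, then seq 2 from seq 1; with this patch
    // the second copy follows cell 1's delta and therefore resolves to cell 0,
    // so chained copies no longer lose track of the original source state
    static void kv_cache_example(struct llama_context * ctx) {
        llama_kv_cache_seq_cp (ctx, /*seq_id_src*/ 0, /*seq_id_dst*/ 1, -1, -1);
        llama_kv_cache_seq_cp (ctx, /*seq_id_src*/ 1, /*seq_id_dst*/ 2, -1, -1);

        // previously these hit GGML_ASSERT(false) for this cache type;
        // now they only shift/divide the pos stored in the sequence's own cell
        llama_kv_cache_seq_add(ctx, /*seq_id*/ 0, /*p0*/ 0, /*p1*/ -1, /*delta*/ 4);
        llama_kv_cache_seq_div(ctx, /*seq_id*/ 0, /*p0*/ 0, /*p1*/ -1, /*d*/ 2);
    }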