diff --git a/llama.cpp b/llama.cpp
index eb1f02e42..fe24f5096 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2212,22 +2212,19 @@ static bool llama_kv_cache_find_slot(
         // For recurrent state architectures (like Mamba),
         // each KV cache cell can store the state for a whole sequence.
 
-        // starting point to find the minimum seq_id used in the batch
-        cache.head = cache.size - 1;
-        // likewise, to find the max seq_id in the batch
-        cache.used = 0;
+        llama_seq_id min = cache.size - 1;
+        llama_seq_id max = 0;
+
         for (uint32_t i = 0; i < n_tokens; ++i) {
             for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
                 llama_seq_id seq_id = batch.seq_id[i][j];
                 // make sure it's a valid seq_id
-                if ((uint32_t)seq_id < cache.size) {
-                    // the number of "used" cells is simply the biggest seq_id
-                    if (cache.used < (uint32_t)seq_id) {
-                        cache.used = seq_id;
+                if ((uint32_t) seq_id < cache.size) {
+                    if (seq_id > max) {
+                        max = seq_id;
                     }
-                    // the "head" is the smallest seq_id
-                    if (cache.head > (uint32_t)seq_id) {
-                        cache.head = seq_id;
+                    if (seq_id < min) {
+                        min = seq_id;
                     }
                     // Assuming the tokens are in-order
                     if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
@@ -2236,6 +2233,9 @@ static bool llama_kv_cache_find_slot(
                         LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
                             __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
                     }
+                    if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
+                        cache.used += 1;
+                    }
                     cache.cells[seq_id].pos = batch.pos[i];
                     // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
                 } else {
@@ -2247,9 +2247,12 @@ static bool llama_kv_cache_find_slot(
             }
         }
 
-        cache.n = cache.used - cache.head + 1;
-        // sanity check (max >= min)
-        return cache.used >= cache.head;
+        // allow getting the range of used cells, from head to head + n
+        cache.head = min;
+        cache.n = max - min + 1;
+
+        // sanity check
+        return max >= min;
     }
 
     // otherwise, one cell per token.
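
Not part of the patch, but for reviewers: the semantic change here is that cache.used no longer doubles as the largest seq_id seen; it now counts cells that actually hold a state (a cell is counted the moment its pos goes from negative to non-negative), while cache.head and cache.n are derived from the smallest and largest seq_id in the batch, so the active cells form the contiguous range [head, head + n). Below is a minimal standalone sketch of that bookkeeping; kv_cell, kv_cache, and find_slot_recurrent are simplified stand-ins for llama.cpp's internal types, not the real API.

// Standalone sketch (not part of the patch); simplified stand-ins for
// llama.cpp's internal types.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

typedef int32_t llama_seq_id;
typedef int32_t llama_pos;

struct kv_cell  { llama_pos pos = -1; };   // pos < 0 means the cell is empty

struct kv_cache {
    uint32_t size = 0;   // total number of cells
    uint32_t head = 0;   // first cell of the active range
    uint32_t n    = 0;   // length of the active range
    uint32_t used = 0;   // cells currently holding a state
    std::vector<kv_cell> cells;
};

// recurrent case: one cell per sequence, so seq_id indexes directly into cells
static bool find_slot_recurrent(kv_cache & cache,
        const std::vector<std::pair<llama_seq_id, llama_pos>> & batch) {
    llama_seq_id min = cache.size - 1;
    llama_seq_id max = 0;

    for (const auto & [seq_id, pos] : batch) {
        if ((uint32_t) seq_id >= cache.size) {
            return false; // seq_id too big for this cache
        }
        if (seq_id > max) { max = seq_id; }
        if (seq_id < min) { min = seq_id; }
        // count the cell as used only when it goes from empty to occupied
        if (cache.cells[seq_id].pos < 0 && 0 <= pos) {
            cache.used += 1;
        }
        cache.cells[seq_id].pos = pos;
    }

    // the active cells are the contiguous range [head, head + n)
    cache.head = min;
    cache.n    = max - min + 1;
    return max >= min;
}

int main() {
    kv_cache cache;
    cache.size = 8;
    cache.cells.resize(cache.size);

    // two sequences (ids 2 and 3), one token each at pos 0
    find_slot_recurrent(cache, {{2, 0}, {3, 0}});
    printf("head=%u n=%u used=%u\n", cache.head, cache.n, cache.used);
}

Running the example prints head=2 n=2 used=2: the active range covers exactly the two cells whose states were written, and submitting further tokens for the same sequences would not inflate used, since those cells are already occupied.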