mamba : more correctly update the "used" field of the KV cache
This commit is contained in:
parent
206e8ee2b2
commit
1af1000f10
1 changed files with 17 additions and 14 deletions
29
llama.cpp
29
llama.cpp
|
@ -2212,22 +2212,19 @@ static bool llama_kv_cache_find_slot(
|
||||||
// For recurrent state architectures (like Mamba),
|
// For recurrent state architectures (like Mamba),
|
||||||
// each KV cache cell can store the state for a whole sequence.
|
// each KV cache cell can store the state for a whole sequence.
|
||||||
|
|
||||||
// starting point to find the minimum seq_id used in the batch
|
llama_seq_id min = cache.size - 1;
|
||||||
cache.head = cache.size - 1;
|
llama_seq_id max = 0;
|
||||||
// likewise, to find the max seq_id in the batch
|
|
||||||
cache.used = 0;
|
|
||||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||||
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
|
for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
|
||||||
llama_seq_id seq_id = batch.seq_id[i][j];
|
llama_seq_id seq_id = batch.seq_id[i][j];
|
||||||
// make sure it's a valid seq_id
|
// make sure it's a valid seq_id
|
||||||
if ((uint32_t) seq_id < cache.size) {
|
if ((uint32_t) seq_id < cache.size) {
|
||||||
// the number of "used" cells is simply the biggest seq_id
|
if (seq_id > max) {
|
||||||
if (cache.used < (uint32_t)seq_id) {
|
max = seq_id;
|
||||||
cache.used = seq_id;
|
|
||||||
}
|
}
|
||||||
// the "head" is the smallest seq_id
|
if (seq_id < min) {
|
||||||
if (cache.head > (uint32_t)seq_id) {
|
min = seq_id;
|
||||||
cache.head = seq_id;
|
|
||||||
}
|
}
|
||||||
// Assuming the tokens are in-order
|
// Assuming the tokens are in-order
|
||||||
if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
|
if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
|
||||||
|
@ -2236,6 +2233,9 @@ static bool llama_kv_cache_find_slot(
|
||||||
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
|
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
|
||||||
__func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
|
__func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
|
||||||
}
|
}
|
||||||
|
if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
|
||||||
|
cache.used += 1;
|
||||||
|
}
|
||||||
cache.cells[seq_id].pos = batch.pos[i];
|
cache.cells[seq_id].pos = batch.pos[i];
|
||||||
// NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
|
// NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
|
||||||
} else {
|
} else {
|
||||||
|
@ -2247,9 +2247,12 @@ static bool llama_kv_cache_find_slot(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cache.n = cache.used - cache.head + 1;
|
// allow getting the range of used cells, from head to head + n
|
||||||
// sanity check (max >= min)
|
cache.head = min;
|
||||||
return cache.used >= cache.head;
|
cache.n = max - min + 1;
|
||||||
|
|
||||||
|
// sanity check
|
||||||
|
return max >= min;
|
||||||
}
|
}
|
||||||
// otherwise, one cell per token.
|
// otherwise, one cell per token.
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue