mamba : support llama_kv_cache_seq_cp copy chains
* mamba : support shifting and dividing the kv cache pos
This commit is contained in:
parent 9473ec2147
commit 3dcf79824d

1 changed file with 34 additions and 35 deletions

llama.cpp | 67
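For orientation: in the "unlimited" KV cache variant (used for Mamba's recurrent state), each sequence owns exactly one cell, and the seq_id doubles as the cell index, which is why the hunks below index cache.cells[seq_id] directly. Here is a minimal standalone sketch of the cell model the diff assumes; the type and field names are simplified stand-ins, not llama.cpp's actual definitions.

    // kv_cell_sketch: simplified stand-in for llama.cpp's llama_kv_cell,
    // reduced to the fields this commit touches.
    #include <cstdint>
    #include <set>

    typedef int32_t llama_pos;
    typedef int32_t llama_seq_id;

    struct kv_cell_sketch {
        llama_pos pos   = -1;          // pos of the last token of this state; -1 means empty
        int32_t   delta = 0;           // repurposed by this cache type as "cell to copy the state from"
        std::set<llama_seq_id> seq_id; // sequences currently using this cell

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }
    };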
@@ -2184,18 +2184,17 @@ static bool llama_kv_cache_find_slot(
             }
             // Assuming the tokens are in-order
             if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
-                // What should happen when the pos backtracks?
+                // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
-                LLAMA_LOG_ERROR("%s: non-consecutive token position %d after %d for sequence %d\n",
+                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
                     __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
-                return false;
             }
             cache.cells[seq_id].pos = batch.pos[i];
-            // NOTE: seq_ids are not inserted here, because they are handled when the graph is built.
+            // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
         } else {
             // too big seq_id
             // TODO: would it be possible to resize the KV cache size instead?
-            LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d\n", __func__, seq_id, cache.size);
+            LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
             return false;
         }
     }
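Behavioral note on this hunk: a non-consecutive position used to fail slot-finding for the whole batch (return false), but it is now only a warning (LLAMA_LOG_ERROR downgraded to LLAMA_LOG_WARN), and the cell's pos is overwritten on the next line regardless; presumably, as the comment says, clearing the state mid-batch would require special-casing, so failing the whole batch was stricter than necessary.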
@@ -2326,24 +2325,26 @@ static void llama_kv_cache_seq_cp(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     if (cache.unlimited) {
-        if ((uint32_t)seq_id_dst < cache.size && (uint32_t)seq_id_src < cache.size) {
-            // intent to "copy from" (does not support copy chains)
+        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
+            seq_id_src = cache.cells[seq_id_src].delta;
+            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
+            // intent to "copy from"
+            // supports copy chains thanks to taking the source of the source
             cache.cells[seq_id_dst].delta = seq_id_src;
-            // NOTE: a sequence can't have multiple sources, but can have multiple destinations.
-            // For compatibility with the other KV cache API functions,
-            // the seq_id(s) of a cell suggests an intent to "copy to" those id(s),
-            // so that when a sequence is copied, it can initially be found from the source cell.
-            cache.cells[seq_id_src].seq_id.insert(seq_id_dst);
-            // prevent the destination from getting cleared
-            cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+
+            // prevent the destination from getting cleared if the source is not empty
+            if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
+                cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            }
+
             // repurposed as a "need copy" flag
             // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = seq_id_src != seq_id_dst;
-            // NOTE: this is not correct for sequence swaps (which aren't a thing in the KV cache API yet)
+            cache.has_shift = true;
+
             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
         return;
     }
+    // otherwise, this is the KV cache of a Transformer-like model

     cache.head = 0;
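The key change above is reading cache.cells[seq_id_src].delta before recording the copy intent. A standalone sketch of why that collapses copy chains; the types and names here are hypothetical, and it assumes each cell's delta initially points at itself (no pending copy):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    struct cell { int pos = -1; int delta = 0; };

    // mirrors only the chain-resolution part of llama_kv_cache_seq_cp
    static void seq_cp_sketch(std::vector<cell> & cells, int seq_id_src, int seq_id_dst) {
        seq_id_src = cells[seq_id_src].delta;      // take the source of the source
        cells[seq_id_dst].delta = seq_id_src;      // intent to "copy from", resolved later
        cells[seq_id_dst].pos   = cells[seq_id_src].pos;
    }

    int main() {
        std::vector<cell> cells(4);
        for (int i = 0; i < (int) cells.size(); ++i) { cells[i].delta = i; } // self-sourced
        cells[0].pos = 41;

        seq_cp_sketch(cells, /*src=*/0, /*dst=*/1); // 1 copies from 0
        seq_cp_sketch(cells, /*src=*/1, /*dst=*/2); // chain: 2 resolves directly to 0

        assert(cells[2].delta == 0 && cells[2].pos == 41);
        printf("cell 2 copies from cell %d (pos %d)\n", cells[2].delta, cells[2].pos);
    }

Because the actual copy is deferred until the input tensors are set, cache.has_shift is repurposed as a "need copy" flag and is now simply set to true on every copy, which also drops the old caveat about sequence swaps.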
@@ -2385,7 +2386,14 @@ static void llama_kv_cache_seq_add(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be shifted
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos += delta;
+            }
+        }
+        return;
     }

     for (uint32_t i = 0; i < cache.size; ++i) {
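For the recurrent cache, shifting is pure bookkeeping on the one cell a sequence owns. A standalone sketch with hypothetical types (the real code additionally checks cell.has_seq_id(seq_id) before touching the cell):

    #include <cassert>
    #include <climits>

    struct cell { int pos = -1; };

    // mirrors the range check and the shift from the unlimited branch above
    static void seq_add_sketch(cell & c, int p0, int p1, int delta) {
        if (p1 < 0) p1 = INT_MAX; // p1 < 0 means "no upper bound"
        if (p0 <= c.pos && c.pos < p1) {
            c.pos += delta;
        }
    }

    int main() {
        cell c; c.pos = 99;
        seq_add_sketch(c, /*p0=*/0, /*p1=*/-1, /*delta=*/-32); // e.g. a context shift
        assert(c.pos == 67);
    }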
@@ -2422,7 +2430,14 @@ static void llama_kv_cache_seq_div(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be changed
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos /= d;
+            }
+        }
+        return;
    }

     for (uint32_t i = 0; i < cache.size; ++i) {
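The divide case is symmetrical: since a sequence's whole context lives in a single recurrent state, only the stored position needs compressing, e.g. with p0 = 0, p1 = -1 and d = 4, a cell at pos 99 ends up at pos 24 (integer division). This is the entry point used by the group-attention/self-extend logic, which previously would have tripped the GGML_ASSERT(false) for Mamba models.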
@@ -8435,7 +8450,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }

     if (kv_self.unlimited) {
-        const int64_t kv_size = kv_self.size;
         const int64_t n_kv = kv_self.n;

         {
@@ -8451,7 +8465,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 data[i] = (float) has_self_seq;

                 // ensure current sequences will be kept
-                if (!has_self_seq) {
+                if (!has_self_seq && kv_cell.pos >= 0) {
                     kv_cell.seq_id.insert(seq_id);
                 }
             }
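The added kv_cell.pos >= 0 guard leans on the convention that an empty cell has pos == -1: without it, this loop, which re-inserts each sequence's own seq_id so that live states are kept, would also resurrect cells that were just cleared.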
@@ -8480,21 +8494,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 }
             }
         }
-        // remove extraneous seq_ids when state copies are made
-        {
-            for (int i = 0; i < kv_size; ++i) {
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
-                uint32_t n_seqs = kv_cell.seq_id.size();
-                bool has_self_seq = kv_cell.has_seq_id(i);
-
-                if (has_self_seq && n_seqs > 1) {
-                    kv_cell.seq_id.clear();
-                    kv_cell.seq_id.insert(i);
-                } else if (!has_self_seq && n_seqs > 0) {
-                    kv_cell.seq_id.clear();
-                }
-            }
-        }
     }
 }
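This scrubbing pass existed to undo the "copy to" seq_ids that the old llama_kv_cache_seq_cp inserted into the source cell. With the copy intent now carried entirely by delta, and the destination's own seq_id inserted up front only when the source is non-empty, there is nothing left to scrub each batch, which is also why the kv_size local became unused and was removed above.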