mamba : in comments, properly refer to KV cells instead of slots

This commit is contained in:
Francis Couture-Harpin 2024-02-14 13:43:14 -05:00
parent 8a43ffcfa1
commit e73eaa7b4f

View file

@ -1802,7 +1802,7 @@ struct llama_kv_cell {
struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
// with Mamba, a slot can hold the state for more than one past token
// with Mamba, a cell can hold the state for more than one past token
bool unlimited = false;
// Note: The value of head isn't only used to optimize searching
@ -2069,7 +2069,7 @@ static bool llama_kv_cache_init(
cache.has_shift = false;
// for now, only Mamba can hold state for more than one past token per slot
// for now, only Mamba can hold state for more than one past token per cell
cache.unlimited = model.arch == LLM_ARCH_MAMBA;
cache.head = 0;
@ -2330,7 +2330,7 @@ static void llama_kv_cache_seq_cp(
cache.cells[seq_id_dst].delta = seq_id_src;
// NOTE: a sequence can't have multiple sources, but can have multiple destinations.
// For compatibility with the other KV cache API functions,
// the seq_id(s) of a slot suggests an intent to "copy to" those id(s),
// the seq_id(s) of a cell suggests an intent to "copy to" those id(s),
// so that when a sequence is copied, it can initially be found from the source cell.
cache.cells[seq_id_src].seq_id.insert(seq_id_dst);
// prevent the destination from getting cleared
@ -12504,10 +12504,10 @@ struct llama_context * llama_new_context_with_model(
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
// Mamba only needs a constant number of KV cache slots per sequence
// Mamba only needs a constant number of KV cache cells per sequence
if (model->arch == LLM_ARCH_MAMBA) {
// Mamba needs as many slots as there are distinct sequences processed at the same time
// The extra slot allows dedicating a sequence id to the system prompt
// Mamba needs as many KV cells as there are sequences kept at any time
// The extra cell allows dedicating a sequence id to the system prompt
// TODO: find a better way to get the max number of parallel sequences
kv_size = params.n_parallel + 1;
// it's probably best to keep as much precision as possible for the states