mamba : stop abusing attention metadata

This breaks existing converted-to-GGUF Mamba models,
but will allow supporting mixed architectures like MambaFormer
without needing to break Mamba models.

This will also allow changing the size of Mamba's states
without having to reconvert models in the future.
(e.g. using something other than d_conv - 1 columns for the conv_states
 will not require breaking existing converted Mamba models again)

* gguf-py : add new KV metadata key-value pairs for Mamba

* llama : add new metadata key-value pairs for Mamba

* llama : guard against divisions by zero when n_head is 0

* mamba : rename "unlimited" KV cache property to "recurrent"
Francis Couture-Harpin 2024-02-28 10:58:17 -05:00
parent 8f605cfe0d
commit 206e8ee2b2
4 changed files with 128 additions and 49 deletions


@@ -1857,7 +1857,14 @@ class MambaModel(Model):
     def set_gguf_parameters(self):
         d_model = self.hparams["d_model"]
+        d_conv  = self.hparams.get("d_conv", 4)
         d_inner = self.hparams.get("d_inner", 2 * d_model)
+        d_state = self.hparams.get("d_state", 16)
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.hparams.get("dt_rank", -(d_model // -16))
+
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
@@ -1865,13 +1872,13 @@ class MambaModel(Model):
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_head_count(d_inner) # the number of rows in conv_state and ssm_state
+        self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_ssm_conv_kernel_size(d_conv)
+        self.gguf_writer.add_ssm_inner_length(d_inner)
+        self.gguf_writer.add_ssm_state_length(d_state)
+        self.gguf_writer.add_ssm_dt_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
-        # NOTE: (ab)using the KV cache metadata to store dimensions for conv_state and ssm_state
-        # Since the first column of the conv_state is shifted out each time, it's not actually needed
-        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4) - 1)
-        self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
         self.gguf_writer.add_file_type(self.ftype)

     def write_tensors(self):

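The dt_rank default above uses the upside-down floor division trick referenced in the comments: for positive integers, -(a // -b) equals ceil(a / b) without going through floats. A quick standalone sketch of the equivalence (the d_model values below are arbitrary examples, not taken from this commit):

import math

# -(a // -b) == ceil(a / b) for positive integers, with no float round-off
for d_model in (768, 1000, 1536, 2560):
    assert -(d_model // -16) == math.ceil(d_model / 16)

print(-(768 // -16))  # 48, the default dt_rank when d_model is 768
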

@@ -61,6 +61,12 @@ class Keys:
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"

+    class SSM:
+        CONV_KERNEL_SIZE = "{arch}.ssm.d_conv"
+        INNER_LENGTH     = "{arch}.ssm.d_inner"
+        STATE_LENGTH     = "{arch}.ssm.d_state"
+        DT_RANK          = "{arch}.ssm.dt_rank"
+
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
         LIST  = "tokenizer.ggml.tokens"

@@ -763,6 +769,12 @@ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
 KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
 KEY_ROPE_SCALING_FINETUNED    = Keys.Rope.SCALING_FINETUNED

+# SSM
+KEY_SSM_CONV_KERNEL_SIZE = Keys.SSM.CONV_KERNEL_SIZE
+KEY_SSM_INNER_LENGTH     = Keys.SSM.INNER_LENGTH
+KEY_SSM_STATE_LENGTH     = Keys.SSM.STATE_LENGTH
+KEY_SSM_DT_RANK          = Keys.SSM.DT_RANK
+
 # tokenization
 KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
 KEY_TOKENIZER_LIST  = Keys.Tokenizer.LIST


@@ -382,6 +382,18 @@ class GGUFWriter:
     def add_rope_scaling_finetuned(self, value: bool) -> None:
         self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)

+    def add_ssm_conv_kernel_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CONV_KERNEL_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_inner_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.INNER_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_state_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.STATE_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_dt_rank(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.DT_RANK.format(arch=self.arch), value)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)

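These writer helpers are what the converter above calls. A minimal usage sketch outside the converter follows; the output path, architecture string, and hyperparameter values are illustrative, and the header/KV write sequence assumes the gguf-py GGUFWriter API of this era:

from gguf import GGUFWriter

# illustrative Mamba hyperparameters (not read from any checkpoint)
d_model = 768
d_conv  = 4
d_state = 16
d_inner = 2 * d_model
dt_rank = -(d_model // -16)  # ceil(d_model / 16)

writer = GGUFWriter("mamba-metadata-only.gguf", "mamba")  # hypothetical output file
writer.add_ssm_conv_kernel_size(d_conv)
writer.add_ssm_inner_length(d_inner)
writer.add_ssm_state_length(d_state)
writer.add_ssm_dt_rank(dt_rank)

# write the (metadata-only) file
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
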
llama.cpp

@@ -286,6 +286,11 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,

+    LLM_KV_SSM_D_INNER,
+    LLM_KV_SSM_D_CONV,
+    LLM_KV_SSM_D_STATE,
+    LLM_KV_SSM_DT_RANK,
+
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_LIST,
     LLM_KV_TOKENIZER_TOKEN_TYPE,
@@ -344,6 +349,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
     { LLM_KV_ROPE_SCALING_FINETUNED,    "%s.rope.scaling.finetuned" },

+    { LLM_KV_SSM_D_CONV,  "%s.ssm.d_conv"  },
+    { LLM_KV_SSM_D_INNER, "%s.ssm.d_inner" },
+    { LLM_KV_SSM_D_STATE, "%s.ssm.d_state" },
+    { LLM_KV_SSM_DT_RANK, "%s.ssm.dt_rank" },
+
     { LLM_KV_TOKENIZER_MODEL,      "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_LIST,       "tokenizer.ggml.tokens" },
     { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
@@ -1638,6 +1648,12 @@ struct llama_hparams {
     float    rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;

+    // for State Space Models
+    uint32_t ssm_d_conv  = 0;
+    uint32_t ssm_d_inner = 0;
+    uint32_t ssm_d_state = 0;
+    uint32_t ssm_dt_rank = 0;
+
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
@@ -1666,6 +1682,11 @@ struct llama_hparams {
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;

+        if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
+        if (this->ssm_d_inner != other.ssm_d_inner) return true;
+        if (this->ssm_d_state != other.ssm_d_state) return true;
+        if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+
         const float EPSILON = 1e-9f;

         if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
@@ -1677,6 +1698,9 @@ struct llama_hparams {
     }

     uint32_t n_gqa() const {
+        if (n_head_kv == 0) {
+            return 0;
+        }
         return n_head/n_head_kv;
     }
@@ -1687,6 +1711,18 @@ struct llama_hparams {
     uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
         return n_embd_head_v * n_head_kv;
     }
+
+    uint32_t n_embd_k_s() const { // dimension of the recurrent convolution state embeddings
+        // corresponds to Mamba's conv_states size
+        // TODO: maybe support other convolution strides than 1
+        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+    }
+
+    uint32_t n_embd_v_s() const { // dimension of the ssm scan state embeddings
+        // corresponds to Mamba's ssm_states size
+        return ssm_d_state * ssm_d_inner;
+    }
 };

 struct llama_cparams {
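Put concretely, a small Python sketch of what these two methods compute, using illustrative hyperparameters for a d_model = 768 Mamba (roughly the 130M configuration); the numbers are examples, not values read from any model:

ssm_d_conv  = 4
ssm_d_state = 16
ssm_d_inner = 2 * 768  # 1536

# mirrors llama_hparams::n_embd_k_s(): the first conv column is shifted out,
# so only d_conv - 1 columns per inner dimension are kept in the cache
n_embd_k_s = (ssm_d_conv - 1) * ssm_d_inner if ssm_d_conv > 0 else 0  # 3 * 1536 = 4608

# mirrors llama_hparams::n_embd_v_s(): one d_state x d_inner ssm_state per layer
n_embd_v_s = ssm_d_state * ssm_d_inner  # 16 * 1536 = 24576

print(n_embd_k_s, n_embd_v_s)  # 4608 24576
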
@@ -1804,8 +1840,8 @@ struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
     bool do_copy   = false;
-    // with Mamba, a cell can hold the state for more than one past token
-    bool unlimited = false;
+    // with recurrent state models, a cell can hold the state for more than one past token
+    bool recurrent = false;

     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
@@ -2067,14 +2103,21 @@ static bool llama_kv_cache_init(
                          bool   offload) {
     const struct llama_hparams & hparams = model.hparams;

-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
     const int64_t  n_layer      = hparams.n_layer;

     cache.has_shift = false;

-    // for now, only Mamba can hold state for more than one past token per cell
-    cache.unlimited = model.arch == LLM_ARCH_MAMBA;
+    // TODO: find a nicer way to add other recurrent model architectures
+    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+
+    // TODO: support mixed recurrent Transformer architectures
+    // NOTE: (!a || b) is a logical implication (a -> b)
+    GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s());
+    GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s());
+    GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa());
+    GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa());

     cache.head = 0;
     cache.size = kv_size;
@@ -2086,7 +2129,8 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(kv_size);

-    if (cache.unlimited) {
+    if (cache.recurrent) {
+        // init state copy sources
         for (uint32_t i = 0; i < cache.size; ++i) {
             cache.cells[i].src = i;
         }
@@ -2164,8 +2208,8 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_ctx    = cache.size;
     const uint32_t n_tokens = batch.n_tokens;

-    if (cache.unlimited) {
-        // For unlimited context architectures (like Mamba),
+    if (cache.recurrent) {
+        // For recurrent state architectures (like Mamba),
         // each KV cache cell can store the state for a whole sequence.

         // starting point to find the minimum seq_id used in the batch
@@ -2289,7 +2333,7 @@ static bool llama_kv_cache_seq_rm(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     // models like Mamba can't have a state partially erased
-    if (cache.unlimited) {
+    if (cache.recurrent) {
         if (seq_id >= (int64_t) cache.size) {
             // could be fatal
             return false;
@@ -2341,7 +2385,7 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

-    if (cache.unlimited) {
+    if (cache.recurrent) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
             seq_id_src = cache.cells[seq_id_src].src;
             GGML_ASSERT((uint32_t) seq_id_src < cache.size);
@@ -2403,7 +2447,7 @@ static void llama_kv_cache_seq_add(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

-    if (cache.unlimited) {
+    if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             llama_kv_cell & cell = cache.cells[seq_id];
@@ -2447,7 +2491,7 @@ static void llama_kv_cache_seq_div(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

-    if (cache.unlimited) {
+    if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be changed
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             llama_kv_cell & cell = cache.cells[seq_id];
@@ -3277,7 +3321,7 @@ static void llm_load_hparams(
     // sanity check for n_rot (optional)
     {
-        hparams.n_rot = hparams.n_embd / hparams.n_head;
+        hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;

         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
@@ -3290,10 +3334,10 @@ static void llm_load_hparams(
         // gpt-j n_rot = rotary_dim
     }

-    hparams.n_embd_head_k = hparams.n_embd / hparams.n_head;
+    hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
     ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);

-    hparams.n_embd_head_v = hparams.n_embd / hparams.n_head;
+    hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
     ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);

     // arch-specific KVs
@@ -3545,7 +3589,13 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_MAMBA:
             {
+                ml.get_key(LLM_KV_SSM_D_CONV,  hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_D_INNER, hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_D_STATE, hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_DT_RANK, hparams.ssm_dt_rank);
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
                 switch (hparams.n_layer) {
                     case 24:
                         switch (hparams.n_embd) {
@@ -3886,6 +3936,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
     LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n", __func__, hparams.n_yarn_orig_ctx);
     LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
+    LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n", __func__, hparams.ssm_d_conv);
+    LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n", __func__, hparams.ssm_d_inner);
+    LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n", __func__, hparams.ssm_d_state);
+    LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n", __func__, hparams.ssm_dt_rank);
     LLAMA_LOG_INFO("%s: model type       = %s\n", __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype      = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
     if (ml.n_elements >= 1e12) {
@@ -4050,10 +4104,7 @@ static bool llm_load_tensors(
     const int64_t n_vocab_type = hparams.n_vocab_type;
     const int64_t n_ff         = hparams.n_ff;

-    // Mamba uses these in its own way
-    if (model.arch != LLM_ARCH_MAMBA) {
-        GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
-    }
+    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);

     ggml_context * ctx_input  = ctx_map.at(model.buft_input.buft);
     ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
@@ -4792,12 +4843,11 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    const int64_t d_conv  = hparams.n_embd_head_k + 1;
-                    const int64_t d_state = hparams.n_embd_head_v;
-                    const int64_t d_inner = hparams.n_head;
-                    // TODO: allow loading dt_rank from the model config
-                    // ceiling division
-                    const int64_t dt_rank = (n_embd / 16) + (n_embd % 16 > 0);
+                    const int64_t d_conv  = hparams.ssm_d_conv;
+                    const int64_t d_inner = hparams.ssm_d_inner;
+                    const int64_t d_state = hparams.ssm_d_state;
+                    const int64_t dt_rank = hparams.ssm_dt_rank;
+                    // only an expansion factor of 2 is supported for now
                     GGML_ASSERT(2 * n_embd == d_inner);

                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -5420,7 +5470,7 @@ struct llm_build_context {
         norm_rms_eps (hparams.f_norm_rms_eps),
         n_tokens     (batch.n_tokens),
         n_kv         (worst_case ? kv_self.size : kv_self.n),
-        kv_head      (worst_case ? (kv_self.unlimited ? 0 : kv_self.size - n_tokens) : kv_self.head),
+        kv_head      (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_orig_ctx   (cparams.n_yarn_orig_ctx),
         pooling_type (cparams.pooling_type),
         rope_type    (hparams.rope_type),
@@ -5473,8 +5523,8 @@ struct llm_build_context {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

         for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);

             conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
             ssm_states  = ggml_get_rows(ctx0, ssm_states,  lctx.inp_s_copy);
@@ -8048,12 +8098,11 @@ struct llm_build_context {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

         const int64_t d_model = n_embd;
-        const int64_t d_inner = n_head;
+        const int64_t d_conv  = hparams.ssm_d_conv;
+        const int64_t d_inner = hparams.ssm_d_inner;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv  = n_embd_head_k + 1;
-        const int64_t d_state = n_embd_head_v;
-        // ceiling division
-        const int64_t dt_rank = (d_model / 16) + (d_model % 16 > 0);
+        const int64_t d_state = hparams.ssm_d_state;
+        const int64_t dt_rank = hparams.ssm_dt_rank;

         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -8063,10 +8112,9 @@ struct llm_build_context {
         cb(inpL, "inp_embd", -1);

         for (int il = 0; il < n_layer; ++il) {
-            // (ab)using the kv cache to store the state
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size);
+            // (ab)using the KV cache to store the states
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);

             // clear states of sequences which are starting at the beginning of this batch
             {
@@ -8501,7 +8549,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }

-    if (kv_self.unlimited) {
+    if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;

         {
@@ -8667,7 +8715,7 @@ static int llama_decode_internal(
             return 1;
         }

-    if (!kv_self.unlimited) {
+    if (!kv_self.recurrent) {
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
@@ -9056,7 +9104,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
        }
    }

-    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+    if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
        llama_set_s_copy(lctx);

        {
@@ -12725,7 +12773,7 @@ struct llama_context * llama_new_context_with_model(
     // graph inputs
     {
         ggml_init_params init_params = {
-            /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
+            /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.recurrent)),
             /* .mem_buffer */ nullptr,
             /* .no_alloc   */ true,
         };
@@ -12739,7 +12787,7 @@ struct llama_context * llama_new_context_with_model(
         ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
         ctx->inp_mean    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
         ctx->inp_cls     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
-        if (ctx->kv_self.unlimited) {
+        if (ctx->kv_self.recurrent) {
             ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
             ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
             ctx->inp_s_seq  = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
@@ -12753,7 +12801,7 @@ struct llama_context * llama_new_context_with_model(
         ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
         ggml_set_name(ctx->inp_mean,    "inp_mean");
         ggml_set_name(ctx->inp_cls,     "inp_cls");
-        if (ctx->kv_self.unlimited) {
+        if (ctx->kv_self.recurrent) {
            ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
            ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
            ggml_set_name(ctx->inp_s_seq,  "inp_s_seq");