mamba : stop abusing attention metadata
This breaks existing converted-to-GGUF Mamba models, but it will allow supporting mixed architectures like MambaFormer without needing to break Mamba models again. It will also allow changing the size of Mamba's states in the future without reconverting models (e.g. using something other than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again).

* gguf-py : add new KV metadata key-value pairs for Mamba
* llama : add new metadata key-value pairs for Mamba
* llama : guard against divisions by zero when n_head is 0
* mamba : rename "unlimited" KV cache property to "recurrent"
Parent: 8f605cfe0d
Commit: 206e8ee2b2
4 changed files with 128 additions and 49 deletions
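For orientation before the diff, here is a minimal Python sketch (not part of the commit) of what the new {arch}.ssm.* metadata works out to for a hypothetical small Mamba configuration; all numbers are assumptions derived from the converter defaults shown below.

import math  # only for clarity; the converter uses integer arithmetic

# Assumed hyperparameters for a small Mamba model (illustrative only).
d_model = 768
d_conv  = 4
d_inner = 2 * d_model        # block expansion factor of 2
d_state = 16
dt_rank = -(d_model // -16)  # ceiling division, same as math.ceil(d_model / 16) == 48

# Resulting GGUF key-value pairs (arch = "mamba"):
#   mamba.ssm.d_conv  = 4
#   mamba.ssm.d_inner = 1536
#   mamba.ssm.d_state = 16
#   mamba.ssm.dt_rank = 48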
@@ -1857,7 +1857,14 @@ class MambaModel(Model):
     def set_gguf_parameters(self):
         d_model = self.hparams["d_model"]
+        d_conv  = self.hparams.get("d_conv", 4)
         d_inner = self.hparams.get("d_inner", 2 * d_model)
+        d_state = self.hparams.get("d_state", 16)
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.hparams.get("dt_rank", -(d_model // -16))
+
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model

@@ -1865,13 +1872,13 @@ class MambaModel(Model):
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_head_count(d_inner) # the number of rows in conv_state and ssm_state
+        self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_ssm_conv_kernel_size(d_conv)
+        self.gguf_writer.add_ssm_inner_length(d_inner)
+        self.gguf_writer.add_ssm_state_length(d_state)
+        self.gguf_writer.add_ssm_dt_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
-        # NOTE: (ab)using the KV cache metadata to store dimensions for conv_state and ssm_state
-        # Since the first column of the conv_state is shifted out each time, it's not actually needed
-        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4) - 1)
-        self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
         self.gguf_writer.add_file_type(self.ftype)

     def write_tensors(self):
@@ -61,6 +61,12 @@ class Keys:
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"

+    class SSM:
+        CONV_KERNEL_SIZE = "{arch}.ssm.d_conv"
+        INNER_LENGTH = "{arch}.ssm.d_inner"
+        STATE_LENGTH = "{arch}.ssm.d_state"
+        DT_RANK = "{arch}.ssm.dt_rank"
+
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
         LIST = "tokenizer.ggml.tokens"

@@ -763,6 +769,12 @@ KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
 KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
 KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED

+# SSM
+KEY_SSM_CONV_KERNEL_SIZE = Keys.SSM.CONV_KERNEL_SIZE
+KEY_SSM_INNER_LENGTH = Keys.SSM.INNER_LENGTH
+KEY_SSM_STATE_LENGTH = Keys.SSM.STATE_LENGTH
+KEY_SSM_DT_RANK = Keys.SSM.DT_RANK
+
 # tokenization
 KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
 KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST
@@ -382,6 +382,18 @@ class GGUFWriter:
     def add_rope_scaling_finetuned(self, value: bool) -> None:
         self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)

+    def add_ssm_conv_kernel_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CONV_KERNEL_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_inner_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.INNER_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_state_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.STATE_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_dt_rank(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.DT_RANK.format(arch=self.arch), value)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
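A rough usage sketch of the new writer methods follows; the file name and values are placeholders, and the rest of the gguf-py write flow is omitted. This is an assumption-laden illustration, not converter code from this commit.

from gguf import GGUFWriter

# Hypothetical file name and values; in the converter above they come from the model's hparams.
writer = GGUFWriter("mamba-test.gguf", "mamba")
writer.add_ssm_conv_kernel_size(4)
writer.add_ssm_inner_length(1536)
writer.add_ssm_state_length(16)
writer.add_ssm_dt_rank(48)
# ... remaining metadata, tensor data, and the actual write/close calls omitted.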
llama.cpp (134 changes)
@@ -286,6 +286,11 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,

+    LLM_KV_SSM_D_INNER,
+    LLM_KV_SSM_D_CONV,
+    LLM_KV_SSM_D_STATE,
+    LLM_KV_SSM_DT_RANK,
+
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_LIST,
     LLM_KV_TOKENIZER_TOKEN_TYPE,

@@ -344,6 +349,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
     { LLM_KV_ROPE_SCALING_FINETUNED,    "%s.rope.scaling.finetuned" },

+    { LLM_KV_SSM_D_CONV,  "%s.ssm.d_conv"  },
+    { LLM_KV_SSM_D_INNER, "%s.ssm.d_inner" },
+    { LLM_KV_SSM_D_STATE, "%s.ssm.d_state" },
+    { LLM_KV_SSM_DT_RANK, "%s.ssm.dt_rank" },
+
     { LLM_KV_TOKENIZER_MODEL,      "tokenizer.ggml.model"      },
     { LLM_KV_TOKENIZER_LIST,       "tokenizer.ggml.tokens"     },
     { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },

@@ -1638,6 +1648,12 @@ struct llama_hparams {
     float    rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;

+    // for State Space Models
+    uint32_t ssm_d_conv  = 0;
+    uint32_t ssm_d_inner = 0;
+    uint32_t ssm_d_state = 0;
+    uint32_t ssm_dt_rank = 0;
+
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;

@@ -1666,6 +1682,11 @@ struct llama_hparams {
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;

+        if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
+        if (this->ssm_d_inner != other.ssm_d_inner) return true;
+        if (this->ssm_d_state != other.ssm_d_state) return true;
+        if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+
         const float EPSILON = 1e-9f;

         if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;

@@ -1677,6 +1698,9 @@ struct llama_hparams {
     }

     uint32_t n_gqa() const {
+        if (n_head_kv == 0) {
+            return 0;
+        }
         return n_head/n_head_kv;
     }

@@ -1687,6 +1711,18 @@ struct llama_hparams {
     uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
         return n_embd_head_v * n_head_kv;
     }
+
+    uint32_t n_embd_k_s() const { // dimension of the recurrent convolution state embeddings
+        // corresponds to Mamba's conv_states size
+        // TODO: maybe support other convolution strides than 1
+        // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+    }
+
+    uint32_t n_embd_v_s() const { // dimension of the ssm scan state embeddings
+        // corresponds to Mamba's ssm_states size
+        return ssm_d_state * ssm_d_inner;
+    }
 };

 struct llama_cparams {
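For a feel of the per-cell state sizes these helpers produce, here is a small Python sketch (not part of the diff) that mirrors n_embd_k_s() and n_embd_v_s() for an assumed small Mamba-like configuration; the numbers are illustrative only.

# Mirrors llama_hparams::n_embd_k_s() / n_embd_v_s() from the hunk above.
# The hyperparameters are assumptions, not read from any model file.
d_model = 768
d_inner = 2 * d_model      # block expansion factor of 2 -> 1536
d_conv  = 4
d_state = 16

n_embd_k_s = (d_conv - 1) * d_inner if d_conv > 0 else 0  # conv_state: 3 * 1536 = 4608 values per layer per cell
n_embd_v_s = d_state * d_inner                            # ssm_state: 16 * 1536 = 24576 values per layer per cell

print(n_embd_k_s, n_embd_v_s)  # 4608 24576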
@@ -1804,8 +1840,8 @@ struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
     bool do_copy   = false;
-    // with Mamba, a cell can hold the state for more than one past token
-    bool unlimited = false;
+    // with recurrent state models, a cell can hold the state for more than one past token
+    bool recurrent = false;

     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
@@ -2067,14 +2103,21 @@ static bool llama_kv_cache_init(
                          bool   offload) {
     const struct llama_hparams & hparams = model.hparams;

-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
     const int64_t  n_layer      = hparams.n_layer;

     cache.has_shift = false;

-    // for now, only Mamba can hold state for more than one past token per cell
-    cache.unlimited = model.arch == LLM_ARCH_MAMBA;
+    // TODO: find a nicer way to add other recurrent model architectures
+    cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+
+    // TODO: support mixed recurrent Transformer architectures
+    // NOTE: (!a || b) is a logical implication (a -> b)
+    GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s());
+    GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s());
+    GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa());
+    GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa());

     cache.head = 0;
     cache.size = kv_size;

@@ -2086,7 +2129,8 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(kv_size);

-    if (cache.unlimited) {
+    if (cache.recurrent) {
+        // init state copy sources
         for (uint32_t i = 0; i < cache.size; ++i) {
             cache.cells[i].src = i;
         }
@@ -2164,8 +2208,8 @@ static bool llama_kv_cache_find_slot(
     const uint32_t n_ctx    = cache.size;
     const uint32_t n_tokens = batch.n_tokens;

-    if (cache.unlimited) {
-        // For unlimited context architectures (like Mamba),
+    if (cache.recurrent) {
+        // For recurrent state architectures (like Mamba),
         // each KV cache cell can store the state for a whole sequence.

         // starting point to find the minimum seq_id used in the batch

@@ -2289,7 +2333,7 @@ static bool llama_kv_cache_seq_rm(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     // models like Mamba can't have a state partially erased
-    if (cache.unlimited) {
+    if (cache.recurrent) {
         if (seq_id >= (int64_t) cache.size) {
             // could be fatal
             return false;

@@ -2341,7 +2385,7 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

-    if (cache.unlimited) {
+    if (cache.recurrent) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
             seq_id_src = cache.cells[seq_id_src].src;
             GGML_ASSERT((uint32_t) seq_id_src < cache.size);

@@ -2403,7 +2447,7 @@ static void llama_kv_cache_seq_add(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

-    if (cache.unlimited) {
+    if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be shifted
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             llama_kv_cell & cell = cache.cells[seq_id];

@@ -2447,7 +2491,7 @@ static void llama_kv_cache_seq_div(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

-    if (cache.unlimited) {
+    if (cache.recurrent) {
         // for Mamba-like models, only the pos needs to be changed
         if (0 <= seq_id && seq_id < (int64_t) cache.size) {
             llama_kv_cell & cell = cache.cells[seq_id];
@@ -3277,7 +3321,7 @@ static void llm_load_hparams(
     // sanity check for n_rot (optional)
     {
-        hparams.n_rot = hparams.n_embd / hparams.n_head;
+        hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;

         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);

@@ -3290,10 +3334,10 @@ static void llm_load_hparams(
         // gpt-j n_rot = rotary_dim
     }

-    hparams.n_embd_head_k = hparams.n_embd / hparams.n_head;
+    hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
     ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);

-    hparams.n_embd_head_v = hparams.n_embd / hparams.n_head;
+    hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
     ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);

     // arch-specific KVs

@@ -3545,7 +3589,13 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_MAMBA:
             {
+                ml.get_key(LLM_KV_SSM_D_CONV,  hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_D_INNER, hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_D_STATE, hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_DT_RANK, hparams.ssm_dt_rank);
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                 switch (hparams.n_layer) {
                     case 24:
                         switch (hparams.n_embd) {
@@ -3886,6 +3936,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
     LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n",     __func__, hparams.n_yarn_orig_ctx);
     LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
+    LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
+    LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
+    LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
+    LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
     if (ml.n_elements >= 1e12) {
@@ -4050,10 +4104,7 @@ static bool llm_load_tensors(
     const int64_t n_vocab_type = hparams.n_vocab_type;
     const int64_t n_ff         = hparams.n_ff;

-    // Mamba uses these in its own way
-    if (model.arch != LLM_ARCH_MAMBA) {
-        GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
-    }
+    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);

     ggml_context * ctx_input  = ctx_map.at(model.buft_input.buft);
     ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);

@@ -4792,12 +4843,11 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    const int64_t d_conv  = hparams.n_embd_head_k + 1;
-                    const int64_t d_state = hparams.n_embd_head_v;
-                    const int64_t d_inner = hparams.n_head;
-                    // TODO: allow loading dt_rank from the model config
-                    // ceiling division
-                    const int64_t dt_rank = (n_embd / 16) + (n_embd % 16 > 0);
+                    const int64_t d_conv  = hparams.ssm_d_conv;
+                    const int64_t d_inner = hparams.ssm_d_inner;
+                    const int64_t d_state = hparams.ssm_d_state;
+                    const int64_t dt_rank = hparams.ssm_dt_rank;
+                    // only an expansion factor of 2 is supported for now
                     GGML_ASSERT(2 * n_embd == d_inner);

                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -5420,7 +5470,7 @@ struct llm_build_context {
         norm_rms_eps     (hparams.f_norm_rms_eps),
         n_tokens         (batch.n_tokens),
         n_kv             (worst_case ? kv_self.size : kv_self.n),
-        kv_head          (worst_case ? (kv_self.unlimited ? 0 : kv_self.size - n_tokens) : kv_self.head),
+        kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),

@@ -5473,8 +5523,8 @@ struct llm_build_context {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

         for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);

             conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
             ssm_states  = ggml_get_rows(ctx0, ssm_states,  lctx.inp_s_copy);
@@ -8048,12 +8098,11 @@ struct llm_build_context {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

         const int64_t d_model = n_embd;
-        const int64_t d_inner = n_head;
+        const int64_t d_conv  = hparams.ssm_d_conv;
+        const int64_t d_inner = hparams.ssm_d_inner;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv  = n_embd_head_k + 1;
-        const int64_t d_state = n_embd_head_v;
-        // ceiling division
-        const int64_t dt_rank = (d_model / 16) + (d_model % 16 > 0);
+        const int64_t d_state = hparams.ssm_d_state;
+        const int64_t dt_rank = hparams.ssm_dt_rank;

         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -8063,10 +8112,9 @@ struct llm_build_context {
         cb(inpL, "inp_embd", -1);

         for (int il = 0; il < n_layer; ++il) {
-            // (ab)using the kv cache to store the state
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
-            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size);
+            // (ab)using the KV cache to store the states
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);

             // clear states of sequences which are starting at the beginning of this batch
             {
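An aside, not part of the diff: the reason a conv_state needs only d_conv - 1 columns is that the convolution window of width d_conv always includes the current token's column, which is available at compute time, so only the trailing columns need to live in the cache. A rough Python sketch of that rolling buffer, with assumed shapes (the real code operates on ggml tensors, not numpy arrays):

import numpy as np

d_conv, d_inner = 4, 1536  # assumed values
conv_state = np.zeros((d_inner, d_conv - 1), dtype=np.float32)  # what a cache cell stores per layer

def conv_step(conv_state, x_col):
    # full convolution window: stored columns plus the current token's column
    window = np.concatenate([conv_state, x_col[:, None]], axis=1)  # (d_inner, d_conv)
    # keep only the most recent d_conv - 1 columns for the next step
    return window, window[:, 1:]

x = np.random.rand(d_inner).astype(np.float32)
window, conv_state = conv_step(conv_state, x)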
@@ -8501,7 +8549,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }

-    if (kv_self.unlimited) {
+    if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;

         {

@@ -8667,7 +8715,7 @@ static int llama_decode_internal(
         return 1;
     }

-    if (!kv_self.unlimited) {
+    if (!kv_self.recurrent) {
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important

@@ -9056,7 +9104,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
     }

-    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+    if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
        llama_set_s_copy(lctx);

        {

@@ -12725,7 +12773,7 @@ struct llama_context * llama_new_context_with_model(
        // graph inputs
        {
            ggml_init_params init_params = {
-                /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
+                /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.recurrent)),
                /* .mem_buffer */ nullptr,
                /* .no_alloc   */ true,
            };

@@ -12739,7 +12787,7 @@ struct llama_context * llama_new_context_with_model(
            ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
            ctx->inp_mean    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
            ctx->inp_cls     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
-            if (ctx->kv_self.unlimited) {
+            if (ctx->kv_self.recurrent) {
                ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
                ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
                ctx->inp_s_seq  = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);

@@ -12753,7 +12801,7 @@ struct llama_context * llama_new_context_with_model(
            ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
            ggml_set_name(ctx->inp_mean,    "inp_mean");
            ggml_set_name(ctx->inp_cls,     "inp_cls");
-            if (ctx->kv_self.unlimited) {
+            if (ctx->kv_self.recurrent) {
                ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
                ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
                ggml_set_name(ctx->inp_s_seq,  "inp_s_seq");