feat(bamba): Full tensor parsing for bamba
Branch: BambaArchitecture
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

parent fd3bb30118
commit 3ee0ae3b90

1 changed file with 91 additions and 13 deletions

src/llama.cpp: 104 changed lines (91 additions, 13 deletions)
@@ -336,8 +336,6 @@ enum llm_kv {
     LLM_KV_SSM_HEAD_COUNT,
     LLM_KV_SSM_HEAD_DIM,
     LLM_KV_SSM_CHUNK_SIZE,
-    LLM_KV_SSM_CONV_BIAS,
-    LLM_KV_SSM_PROJ_BIAS,
     LLM_KV_WKV_HEAD_SIZE,
@@ -459,8 +457,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_HEAD_COUNT, "%s.ssm.head_count" },
     { LLM_KV_SSM_HEAD_DIM,   "%s.ssm.head_dim" },
     { LLM_KV_SSM_CHUNK_SIZE, "%s.ssm.chunk_size" },
-    { LLM_KV_SSM_CONV_BIAS,  "%s.ssm.conv_bias" },
-    { LLM_KV_SSM_PROJ_BIAS,  "%s.ssm.proj_bias" },
     { LLM_KV_WKV_HEAD_SIZE,  "%s.wkv.head_size" },
@@ -2493,8 +2489,6 @@ struct llama_hparams {
     uint32_t ssm_head_count = 0;
     uint32_t ssm_head_dim   = 0;
     uint32_t ssm_chunk_size = 0;
-    bool ssm_conv_bias = false;
-    bool ssm_proj_bias = false;

     // for hybrid state space models
     std::array<bool, LLAMA_MAX_LAYERS> ssm_layer_arr;
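The ssm_layer_arr field shown in context here carries the per-layer SSM-vs-attention split for the hybrid architecture. The hparams.ssm_layer(i) accessor used in the BAMBA tensor-loading hunk further down is not part of this diff; the snippet below is only a minimal sketch, assuming that accessor simply indexes this array (MAX_LAYERS_SKETCH and the struct name are stand-ins invented for the example).

// Sketch only (not part of this commit): how hparams.ssm_layer(i) is assumed
// to resolve the per-layer hybrid flag filled in during llm_load_hparams.
#include <array>
#include <cstdint>

constexpr uint32_t MAX_LAYERS_SKETCH = 512; // stand-in for LLAMA_MAX_LAYERS

struct hparams_sketch {
    // true = mamba2 (SSM) mixer layer, false = attention layer
    std::array<bool, MAX_LAYERS_SKETCH> ssm_layer_arr{};

    bool ssm_layer(uint32_t il) const {
        return ssm_layer_arr[il];
    }
};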
@@ -2557,8 +2551,6 @@ struct llama_hparams {
         if (this->ssm_head_count != other.ssm_head_count) return true;
         if (this->ssm_head_dim   != other.ssm_head_dim)   return true;
         if (this->ssm_chunk_size != other.ssm_chunk_size) return true;
-        if (this->ssm_conv_bias  != other.ssm_conv_bias)  return true;
-        if (this->ssm_proj_bias  != other.ssm_proj_bias)  return true;

         if (this->ssm_layer_arr != other.ssm_layer_arr) return true;

@@ -2800,6 +2792,7 @@ struct llama_layer {
     // mamba bias
     struct ggml_tensor * ssm_conv1d_b;
     struct ggml_tensor * ssm_dt_b;
+    struct ggml_tensor * ssm_in_b;

     // rwkv
     struct ggml_tensor * time_mix_w1;
@@ -6004,8 +5997,6 @@ static void llm_load_hparams(
             ml.get_key(LLM_KV_SSM_HEAD_COUNT, hparams.ssm_head_count);
             ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim);
             ml.get_key(LLM_KV_SSM_CHUNK_SIZE, hparams.ssm_chunk_size);
-            ml.get_key(LLM_KV_SSM_CONV_BIAS, hparams.ssm_conv_bias);
-            ml.get_key(LLM_KV_SSM_PROJ_BIAS, hparams.ssm_proj_bias);

             // Attention params
             std::fill(hparams.ssm_layer_arr.begin(), hparams.ssm_layer_arr.end(), true);
@@ -7100,8 +7091,6 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: ssm_head_count = %d\n", __func__, hparams.ssm_head_count);
         LLAMA_LOG_INFO("%s: ssm_head_dim   = %d\n", __func__, hparams.ssm_head_dim);
         LLAMA_LOG_INFO("%s: ssm_chunk_size = %d\n", __func__, hparams.ssm_chunk_size);
-        LLAMA_LOG_INFO("%s: ssm_conv_bias  = %d\n", __func__, hparams.ssm_conv_bias);
-        LLAMA_LOG_INFO("%s: ssm_proj_bias  = %d\n", __func__, hparams.ssm_proj_bias);
     }

     LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
@@ -7761,6 +7750,12 @@ static bool llm_load_tensors(

     model.layers.resize(n_layer);

+    // Log out tensor names for verbose debugging
+    LLAMA_LOG_DEBUG("%s: TENSORS\n", __func__);
+    for (const auto& entry : ml.weights_map) {
+        LLAMA_LOG_DEBUG("%s: %s\n", __func__, entry.first.c_str());
+    }
+
     // TODO: move to a separate function
     const auto tn = LLM_TN(model.arch);
     switch (model.arch) {
@@ -8740,6 +8735,83 @@ static bool llm_load_tensors(
                     layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
                 }
             } break;
+        case LLM_ARCH_BAMBA:
+            {
+                // mamba2 Mixer SSM params
+                // TODO: Why are these int64_t and not uint32_t?
+                const int64_t d_conv = hparams.ssm_d_conv;
+                const int64_t d_inner = hparams.ssm_d_inner;
+                const int64_t d_state = hparams.ssm_d_state;
+                const int64_t n_group = hparams.ssm_n_group;
+                const int64_t head_count = hparams.ssm_head_count;
+                const int64_t head_dim = hparams.ssm_head_dim;
+                const int64_t chunk_size = hparams.ssm_chunk_size;
+                const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + head_count;
+
+                // only an expansion factor of 2 is supported for now
+                GGML_ASSERT(2 * n_embd == d_inner);
+
+                // embeddings
+                model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                {
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed, duplicated to allow offloading
+                    if (model.output == NULL) {
+                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                    }
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = model.layers[i];
+
+                    // norm
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    if (hparams.ssm_layer(i)) {
+                        // ssm layers
+                        layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
+                        layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, d_in_proj}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
+                        layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {head_count}, 0);
+
+                        // no "weight" suffix for these
+                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, head_count}, 0);
+                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, head_count}, 0);
+
+                        layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
+
+                        // out_proj
+                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+                    } else {
+                        // attention layers (with optional bias)
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                    }
+
+                    // feed forward (w/ optional biases)
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                    layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                }
+            } break;
         case LLM_ARCH_XVERSE:
             {
                 model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
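As a side note on the shape arithmetic in the BAMBA case above: the input projection follows the mamba2 layout (z, x, B, C and dt concatenated), and the conv1d runs over x together with B and C. The standalone sketch below reproduces that arithmetic with illustrative hyperparameter values; the numbers are assumptions chosen to satisfy the expansion-factor-2 assert, not values read from a real Bamba GGUF.

// Sketch only: d_in_proj / conv channel arithmetic with assumed example values.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_embd     = 4096;        // assumed embedding width
    const int64_t d_inner    = 2 * n_embd;  // expansion factor 2, per the GGML_ASSERT above
    const int64_t d_state    = 128;         // assumed SSM state size
    const int64_t n_group    = 1;           // assumed number of B/C groups
    const int64_t head_count = 128;         // assumed number of SSM heads

    // in_proj output: z + x (2*d_inner), B + C (2*n_group*d_state), dt (head_count)
    const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + head_count;

    // conv1d channels: x concatenated with B and C
    const int64_t conv_channels = d_inner + 2*n_group*d_state;

    assert(2 * n_embd == d_inner);
    printf("d_in_proj     = %lld\n", (long long) d_in_proj);     // 16384 + 256 + 128 = 16768
    printf("conv_channels = %lld\n", (long long) conv_channels); // 8192 + 256        = 8448
    return 0;
}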
@@ -9546,7 +9618,12 @@ static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {

         if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
             model.hparams.n_vocab != model.vocab.id_to_token.size()) {
-            throw std::runtime_error("vocab size mismatch");
+            std::stringstream ss;
+            ss << "vocab size mismatch. "
+               << model.hparams.n_vocab
+               << " != "
+               << model.vocab.id_to_token.size();
+            throw std::runtime_error(ss.str());
         }

         if (params.vocab_only) {
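For illustration, the reworked error path above yields a message that includes both counts. A minimal self-contained demo with made-up example values (not taken from any real model):

// Sketch only: the shape of the new vocab-size-mismatch message for assumed counts.
#include <cstdio>
#include <sstream>
#include <stdexcept>

int main() {
    const size_t n_vocab_hparams = 128256; // assumed value from the GGUF metadata
    const size_t n_vocab_tokens  = 128000; // assumed size of id_to_token

    try {
        std::stringstream ss;
        ss << "vocab size mismatch. " << n_vocab_hparams << " != " << n_vocab_tokens;
        throw std::runtime_error(ss.str());
    } catch (const std::runtime_error & err) {
        printf("%s\n", err.what()); // vocab size mismatch. 128256 != 128000
    }
    return 0;
}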
@@ -20407,6 +20484,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
+        case LLM_ARCH_BAMBA:
             return LLAMA_ROPE_TYPE_NORM;

         // the pairs of head values are offset by n_rot/2