mamba : recurrent inference almost works, but incoherent

Francis Couture-Harpin 2024-01-28 15:36:42 -05:00
parent 5a69a262a1
commit f680364bd8
3 changed files with 128 additions and 41 deletions


@@ -1851,13 +1851,57 @@ class MambaModel(Model):
     def set_gguf_parameters(self):
         d_model = self.hparams["d_model"]
         self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(128) # arbitrary value; it shouldn't be important for Mamba
         self.gguf_writer.add_embedding_length(d_model)
-        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(2 * d_model) # d_inner
+        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_layer_norm_rms_eps(1e-5)
         self.gguf_writer.add_key_length(4) # d_conv
         self.gguf_writer.add_value_length(16) # d_state
         self.gguf_writer.add_file_type(self.ftype)
+
+    def write_tensors(self):
+        block_count = self.hparams["n_layer"]
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        for name, data_torch in self.get_tensors():
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            if name.endswith(".A_log"):
+                print("A_log --> A ==> " + new_name)
+                data_torch = -torch.exp(data_torch)
+
+            data = data_torch.squeeze().numpy()
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert big float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######
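For reference, the converter above maps the checkpoint's A_log directly to the A used at inference time (A = -exp(A_log)), and it smuggles Mamba's dimensions through GGUF's attention-oriented keys: head_count carries d_inner = 2*d_model, key_length carries d_conv = 4, and value_length carries d_state = 16. A minimal NumPy sketch of that mapping, with a random stand-in for the checkpoint tensor (the hparams values are illustrative, mamba-130m-like):

import numpy as np

hparams = {"d_model": 768, "n_layer": 24}  # illustrative, mamba-130m-like

d_model = hparams["d_model"]
d_inner = 2 * d_model  # written to GGUF as head_count
d_conv  = 4            # written to GGUF as key_length
d_state = 16           # written to GGUF as value_length

# stand-in for the checkpoint's A_log tensor (rows: d_inner channels, cols: d_state)
A_log = np.random.randn(d_inner, d_state).astype(np.float32)

# matches `data_torch = -torch.exp(data_torch)` above: A is stored, not A_log
A = -np.exp(A_log)
assert A.shape == (d_inner, d_state) and (A < 0).all()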

ggml.c

@@ -5331,7 +5331,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(

 // ggml_soft_plus

-struct ggml_tensor * ggml_soft_plus_impl(
+static struct ggml_tensor * ggml_soft_plus_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         bool inplace) {

@@ -15737,7 +15737,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
         case GGML_OP_SOFT_PLUS:
             {
                 ggml_compute_forward_soft_plus(params, tensor->src[0], tensor);
-            }
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor);
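The two ggml.c changes are small fixes: ggml_soft_plus_impl becomes static, and the added break keeps the GGML_OP_SOFT_PLUS case from falling through into GGML_OP_ROPE. Softplus is what later keeps the SSM step size dt strictly positive; a NumPy sketch of the elementwise function (assuming ggml_soft_plus computes the usual log(1 + exp(x)), which is how it is used in build_mamba below):

import numpy as np

def soft_plus(x: np.ndarray) -> np.ndarray:
    # softplus(x) = log(1 + exp(x)); log1p keeps small values accurate
    return np.log1p(np.exp(x))

x = np.array([-4.0, 0.0, 4.0], dtype=np.float32)
print(soft_plus(x))  # ~[0.018, 0.693, 4.018]: always positive, close to x for large x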

llama.cpp

@@ -1765,7 +1765,6 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_b; // b3
     struct ggml_tensor * ffn_act;

     // mamba proj
     struct ggml_tensor * ssm_in;
     struct ggml_tensor * ssm_x;

@@ -3435,6 +3434,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_MAMBA:
             {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 24:
                         switch (hparams.n_embd) {

@@ -3455,7 +3455,7 @@ static void llm_load_hparams(
                         } break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
-            }
+            } break;
         default: (void)0;
     }
@@ -3939,7 +3939,10 @@ static bool llm_load_tensors(
     const int64_t n_vocab_type = hparams.n_vocab_type;
     const int64_t n_ff = hparams.n_ff;

-    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
+    // Mamba uses these in its own way
+    if (model.arch != LLM_ARCH_MAMBA) {
+        GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
+    }

     ggml_context * ctx_input  = ctx_map.at(model.buft_input.buft);
     ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
@@ -4678,19 +4681,21 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
                     const int64_t d_conv  = hparams.n_embd_head_k;
                     const int64_t d_state = hparams.n_embd_head_v;
                     const int64_t d_inner = hparams.n_head;
                     // FIXME: ceiling instead of floor
                     const int64_t dt_rank = n_embd / 16;
                     GGML_ASSERT(2 * n_embd == d_inner);
+                    // round up the vocab size to the next multiple of 8
+                    const int64_t rounded_vocab = (n_vocab + 7) & -8;
+
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, rounded_vocab});

                     // output
                     {
                         model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, rounded_vocab});
                     }

                     for (int i = 0; i < n_layer; ++i) {
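The (n_vocab + 7) & -8 expression rounds the vocabulary size up to the next multiple of 8, which is why tok_embd and output are now created with rounded_vocab rows: the released Mamba checkpoints pad their 50,277-token GPT-NeoX vocabulary to 50,280 embedding rows. A quick check of the bit trick (illustrative Python, not part of the loader):

def round_up_to_multiple_of_8(n: int) -> int:
    # adding 7 and then clearing the low 3 bits rounds n up to a multiple of 8
    return (n + 7) & -8

assert round_up_to_multiple_of_8(50277) == 50280
assert round_up_to_multiple_of_8(50280) == 50280
assert round_up_to_multiple_of_8(1) == 8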
@@ -4707,7 +4712,7 @@ static bool llm_load_tensors(
                         layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});

-                        layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, 1, d_inner});
+                        layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
                         layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});

                         layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});

@@ -4715,9 +4720,9 @@ static bool llm_load_tensors(
                         layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
                         layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});

-                        // FIXME: maybe no suffix for these
-                        layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, "weight", i), {d_state, d_inner});
-                        layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, "weight", i), {d_inner});
+                        // no "weight" suffix for these
+                        layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
+                        layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});

                         // out_proj
                         layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
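To make the shapes above concrete, here is what the per-layer Mamba tensors work out to for a mamba-130m-sized model, assuming d_model = 768 (so d_inner = 1536, d_conv = 4, d_state = 16, dt_rank = 48); the tuples follow ggml's {ne0, ne1} order used in the create_tensor calls:

d_model, d_inner, d_conv, d_state = 768, 1536, 4, 16
dt_rank = d_model // 16  # 48

per_layer_shapes = {
    "ssm_in":       (d_model, 2 * d_inner),            # {768, 3072}
    "ssm_conv1d":   (d_conv, d_inner),                 # {4, 1536}
    "ssm_conv1d_b": (d_inner,),                        # {1536}
    "ssm_x":        (d_inner, dt_rank + 2 * d_state),  # {1536, 80}
    "ssm_dt":       (dt_rank, d_inner),                # {48, 1536}
    "ssm_dt_b":     (d_inner,),                        # {1536}
    "ssm_a":        (d_state, d_inner),                # {16, 1536}
    "ssm_d":        (d_inner,),                        # {1536}
    "ssm_out":      (d_inner, d_model),                # {1536, 768}
}
for tensor_name, shape in per_layer_shapes.items():
    print(tensor_name, shape)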
@@ -7907,16 +7912,18 @@ struct llm_build_context {
         return gf;
     }

-    struct ggml_cgraph * build_mamba(bool use_conv) {
+    struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

+        const bool use_conv = batch.n_tokens > 1;
         GGML_ASSERT(use_conv == false); // TODO: implement

-        const int64_t d_model = hparams.n_embd;
-        const int64_t d_inner = hparams.n_head;
+        // hopefully the compiler does constant folding
+        const int64_t d_model = n_embd;
+        const int64_t d_inner = n_head;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv  = hparams.n_embd_head_k;
-        const int64_t d_state = hparams.n_embd_head_v;
+        const int64_t d_conv  = n_embd_head_k;
+        const int64_t d_state = n_embd_head_v;
         const int64_t dt_rank = d_model / 16;

         struct ggml_tensor * cur;
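build_mamba() recovers its dimensions from the attention-oriented hparams written by the converter: n_head holds d_inner, n_embd_head_k holds d_conv, and n_embd_head_v holds d_state. One consequence is that the per-layer recurrent state that replaces the usual KV cache has a small, constant size, independent of context length. A rough sketch of that bookkeeping for a mamba-130m-sized model (assumed dimensions, and assuming f32 state elements):

n_embd        = 768          # d_model
n_head        = 2 * n_embd   # repurposed: d_inner = 1536
n_embd_head_k = 4            # repurposed: d_conv
n_embd_head_v = 16           # repurposed: d_state
n_layer       = 24

d_inner, d_conv, d_state = n_head, n_embd_head_k, n_embd_head_v

conv_state_elems = d_conv * d_inner   # kept in kv_self.k_l[il]
ssm_state_elems  = d_state * d_inner  # kept in kv_self.v_l[il]

bytes_per_layer = 4 * (conv_state_elems + ssm_state_elems)  # f32
print(f"state per layer: {bytes_per_layer / 1024:.1f} KiB")            # 120.0 KiB
print(f"total state:     {n_layer * bytes_per_layer / 1024:.1f} KiB")  # 2880.0 KiB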
@@ -7928,8 +7935,10 @@ struct llm_build_context {
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_conv, d_inner}
-            ggml_tensor * ssm_state  = kv_self.v_l[il]; // {d_state, d_inner}
+            // NOTE: the conv_state is transposed to ease shifting it.
+            // If you figure out a way to shift it without transposing it like this, go ahead and fix this.
+            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_inner, d_conv}
+            ggml_tensor * ssm_state  = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -7941,36 +7950,73 @@ struct llm_build_context {
             // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
             struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
             // split the above in two
+            // assuming it's contiguous
+            // FIXME: handle batches of more than 1 token
             struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
-            struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, d_inner);
-
-            cur = x;
-
-            // FIXME: figure out when to transpose
+            struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, ggml_element_size(xz)*d_inner);

             // conv
             {
-                // TODO: figure out how to do a row-wise dot product
-                // TODO: use the kv-cache to store the state
-                kv_self.k_l[il];
-
-                // FIXME: this is wrong
-                cur = ggml_conv_1d(ctx0, cur, model.layers[il].ssm_conv1d, 1, d_conv - 1, 1);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].ssm_conv1d_b);
-
-                // TODO: there's some SiLU in there (but no ffn? or is the conv an ffn?)
-                cur = ggml_silu(ctx0, cur);
+                // shift conv state left
+                conv_state = ggml_set_1d(ctx0, conv_state, ggml_view_1d(ctx0, conv_state, (d_conv - 1)*d_inner, ggml_element_size(conv_state)*d_inner), 0);
+
+                // update last column
+                conv_state = ggml_set_1d(ctx0, conv_state, x, ggml_element_size(conv_state)*(d_conv - 1)*d_inner);
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));
+
+                // rearrange and sum
+                conv_state = ggml_reshape_2d(ctx0, conv_state, d_inner, d_conv);
+                // TODO: find a way to directly shift a 2d conv_state, avoiding the need to transpose here.
+                conv_state = ggml_cont(ctx0, ggml_transpose(ctx0, conv_state));
+
+                // --> {1, d_inner}
+                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
+                x = ggml_transpose(ctx0, x);
+
+                // bias
+                x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
+
+                x = ggml_silu(ctx0, x);
             }

             // ssm
             {
-                // TODO: use ggml_soft_plus here
-            }
-
-            // TODO: there's some SiLU again towards the end. Can the `llm_build_ffn` helper be used?
-            // Maybe the best way is to implement it, _then_ check if that helper would do the same thing.
-            {
+                // {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
+                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
+                // FIXME: handle batches of more than 1 token
+                struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
+                struct ggml_tensor * B  = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
+                struct ggml_tensor * C  = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));
+
+                // {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
+                dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
+                dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
+                dt = ggml_soft_plus(ctx0, dt);
+
+                // => {d_state, d_inner}
+                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));
+
+                // => {d_state, d_inner}
+                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
+
+                // discretize
+                // => {d_state, d_inner}
+                cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
+                ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_state, ggml_view_tensor(ctx0, kv_self.v_l[il])));
+
+                // row-wise dot product ("dn,n->d")
+                // {d_state, d_inner} * {d_state} => {d_inner, 1}
+                struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, C);
+                y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
+                y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
+
+                // {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
+                cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
             }

             // residual
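For readers following the graph construction above, the same single-token update can be written out in NumPy. This is a sketch of the reference Mamba recurrence using the dimension names from this file, not the ggml graph itself; all weights are random stand-ins for the per-layer tensors loaded earlier:

import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def softplus(v):
    return np.log1p(np.exp(v))

d_model, d_conv, d_state = 768, 4, 16
d_inner = 2 * d_model
dt_rank = d_model // 16

rng = np.random.default_rng(0)
W_in   = rng.standard_normal((2 * d_inner, d_model)) * 0.02           # ssm_in
conv_w = rng.standard_normal((d_inner, d_conv)) * 0.02                # ssm_conv1d
conv_b = np.zeros(d_inner)                                            # ssm_conv1d_b
W_x    = rng.standard_normal((dt_rank + 2 * d_state, d_inner)) * 0.02 # ssm_x
W_dt   = rng.standard_normal((d_inner, dt_rank)) * 0.02               # ssm_dt
b_dt   = np.zeros(d_inner)                                            # ssm_dt_b
A      = -np.exp(rng.standard_normal((d_inner, d_state)))             # ssm_a (already -exp(A_log))
D      = np.ones(d_inner)                                             # ssm_d
W_out  = rng.standard_normal((d_model, d_inner)) * 0.02               # ssm_out

def mamba_step(hidden, conv_state, ssm_state):
    # hidden: {d_model}; conv_state: {d_inner, d_conv}; ssm_state: {d_inner, d_state}
    x, z = np.split(W_in @ hidden, 2)  # each {d_inner}

    # conv: shift the rolling window left, append x, then a per-channel dot product
    conv_state = np.concatenate([conv_state[:, 1:], x[:, None]], axis=1)
    x = silu((conv_state * conv_w).sum(axis=1) + conv_b)

    # ssm: input-dependent dt, B, C; discretize; update the state; read it out
    dt, B, C = np.split(W_x @ x, [dt_rank, dt_rank + d_state])
    dt = softplus(W_dt @ dt + b_dt)               # {d_inner}, positive step sizes
    dA = np.exp(dt[:, None] * A)                  # {d_inner, d_state}
    dB = dt[:, None] * B[None, :]                 # {d_inner, d_state}
    ssm_state = ssm_state * dA + dB * x[:, None]
    y = ssm_state @ C + D * x                     # {d_inner}
    y = y * silu(z)
    return W_out @ y, conv_state, ssm_state       # {d_model}, goes to the residual add

conv_state = np.zeros((d_inner, d_conv))   # what the graph keeps in kv_self.k_l[il]
ssm_state  = np.zeros((d_inner, d_state))  # what the graph keeps in kv_self.v_l[il]
out, conv_state, ssm_state = mamba_step(rng.standard_normal(d_model), conv_state, ssm_state)
print(out.shape)  # (768,)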
@@ -7981,11 +8027,8 @@ struct llm_build_context {
             inpL = cur;
         }

-        // the last step of each layer already makes these equivalent
-        // cur = inpL;
-
         // final rmsnorm
-        cur = llm_build_norm(ctx0, cur, hparams,
+        cur = llm_build_norm(ctx0, inpL, hparams,
                 model.output_norm, NULL,
                 LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
@@ -8150,7 +8193,7 @@ static struct ggml_cgraph * llama_build_graph(
             } break;
         case LLM_ARCH_MAMBA:
             {
-                result = llm.build_mamba(/* use_conv =*/ batch.n_tokens > 1);
+                result = llm.build_mamba();
             } break;
         default:
             GGML_ASSERT(false);