mamba : simplify the conv step with a self-overlapping view

Turns out the conv_state can be made smaller by one column.
Note that this breaks existing GGUFs of Mamba,
because the key_length metadata field is tied to the conv_state size.
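
(Concretely: the Mamba checkpoints I've looked at use d_conv = 4, so each channel's
conv_state goes from 4 to 3 columns, and the key_length metadata now stores
d_conv - 1 = 3 instead of 4, which is why older GGUFs stop loading.)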

Convolution with a self-overlapping view is cool!
And it's much simpler than what I initially thought would be necessary
to make the convolution step work with more than 1 token at a time.
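
Here's the idea for a single channel in plain C++ (just a sketch with made-up buffers, not the ggml graph):

    #include <cstdio>
    #include <vector>

    // conv step for ONE channel over a whole batch of tokens
    // conv_row = [s_0 .. s_{d_conv-2}, x_0 .. x_{n_tok-1}]  (old state columns, then the new token values)
    // w        = that channel's d_conv convolution weights
    static std::vector<float> conv_channel(const std::vector<float> & conv_row,
                                           const std::vector<float> & w,
                                           int d_conv, int n_tok) {
        std::vector<float> y(n_tok);
        for (int t = 0; t < n_tok; ++t) {
            float sum = 0.0f;
            // token t reads the window [t, t + d_conv); consecutive windows overlap by d_conv - 1 elements
            for (int i = 0; i < d_conv; ++i) {
                sum += w[i] * conv_row[t + i];
            }
            y[t] = sum;
        }
        return y;
    }

    int main() {
        const int d_conv = 4, n_tok = 3;
        const std::vector<float> conv_row = {0.1f, 0.2f, 0.3f, 1.0f, 2.0f, 3.0f}; // 3 state columns + 3 tokens
        const std::vector<float> w        = {0.25f, 0.25f, 0.25f, 0.25f};
        for (float v : conv_channel(conv_row, w, d_conv, n_tok)) {
            printf("%f\n", v);
        }
        // the last d_conv - 1 columns of conv_row are what gets stored back as the next conv_state
        return 0;
    }

The view in the graph below does the same thing for all d_inner channels and all tokens at once,
by striding the token dimension of the {d_conv, d_inner, n_tok} view by a single element.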

The next step is to make the SSM step work on batches of tokens too,
which means figuring out a parallel selective scan
that keeps the ssm_state small and doesn't blow it up
by a factor of (n_layer * batch_size).
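
For reference, the per-token recurrence the parallel scan has to reproduce looks like this
(plain C++ sketch with made-up names, ignoring the D skip connection and the gating):

    #include <cmath>
    #include <vector>

    // reference selective scan for one layer over n_tok tokens;
    // h is the ssm_state, {d_state, d_inner}, updated in place: it never grows with n_tok
    void selective_scan_ref(int n_tok, int d_inner, int d_state,
                            const std::vector<float> & A,   // {d_state, d_inner}
                            const std::vector<float> & x,   // {d_inner, n_tok}
                            const std::vector<float> & dt,  // {d_inner, n_tok}, already softplus'd
                            const std::vector<float> & B,   // {d_state, n_tok}
                            const std::vector<float> & C,   // {d_state, n_tok}
                            std::vector<float> & h,         // {d_state, d_inner}
                            std::vector<float> & y) {       // {d_inner, n_tok}
        for (int t = 0; t < n_tok; ++t) {
            for (int i = 0; i < d_inner; ++i) {
                float acc = 0.0f;
                for (int s = 0; s < d_state; ++s) {
                    const float dA = std::exp(dt[t*d_inner + i] * A[i*d_state + s]);
                    const float dB = dt[t*d_inner + i] * B[t*d_state + s];
                    h[i*d_state + s] = dA * h[i*d_state + s] + dB * x[t*d_inner + i];
                    acc += C[t*d_state + s] * h[i*d_state + s];
                }
                y[t*d_inner + i] = acc;
            }
        }
    }

A parallel formulation has to produce the same y and the same final h
without materializing one copy of h per token in every layer.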

* llama : fix Mamba KV self size wrongly displaying as f16 instead of f32

Relatedly, I also tried to see whether types other than f32 work for the states,
but they don't, because of the operators used (ggml_set and ggml_mul).
It's probably better to keep full precision there anyway,
since the states are small.
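
(For scale: with mamba-2.8b's dimensions, if I have them right
(d_inner = 5120, d_conv = 4, d_state = 16, n_layer = 64),
the whole per-sequence state is 64 * (3 + 16) * 5120 f32 values, roughly 25 MB,
which is negligible next to the model weights.)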

Author: Francis Couture-Harpin
Date:   2024-01-30 21:48:04 -05:00
Commit: e9cc45ecae
Parent: 3f7233b62e

2 changed files with 59 additions and 43 deletions

--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py

@@ -1858,10 +1858,12 @@ class MambaModel(Model):
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_head_count(d_inner)
+        self.gguf_writer.add_head_count(d_inner) # the number of rows in conv_state and ssm_state
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
-        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4))
+        # NOTE: (ab)using the KV cache metadata to store dimensions for conv_state and ssm_state
+        # Since the first column of the conv_state is shifted out each time, it's not actually needed
+        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4) - 1)
         self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
         self.gguf_writer.add_file_type(self.ftype)

--- a/llama.cpp
+++ b/llama.cpp

@@ -2069,9 +2069,6 @@ static bool llama_kv_cache_init(
     if (model.arch == LLM_ARCH_MAMBA) {
         // only one slot is needed for Mamba
         n_ctx = 1;
-        // it's probably best to keep as much precision as possible for the states
-        ktype = GGML_TYPE_F32;
-        vtype = GGML_TYPE_F32;
     }

     cache.has_shift = false;
@@ -4681,7 +4678,7 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    const int64_t d_conv = hparams.n_embd_head_k;
+                    const int64_t d_conv = hparams.n_embd_head_k + 1;
                     const int64_t d_state = hparams.n_embd_head_v;
                     const int64_t d_inner = hparams.n_head;
                     // FIXME: ceiling instead of floor
@@ -7915,28 +7912,27 @@ struct llm_build_context {
     struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

-        const bool use_conv = batch.n_tokens > 1;
-        GGML_ASSERT(use_conv == false); // TODO: implement
+        const int32_t n_tok = batch.n_tokens;

         // hopefully the compiler does constant folding
         const int64_t d_model = n_embd;
         const int64_t d_inner = n_head;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv = n_embd_head_k;
+        const int64_t d_conv = n_embd_head_k + 1;
         const int64_t d_state = n_embd_head_v;
         const int64_t dt_rank = d_model / 16;

         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;

-        // NOTE: not sure what's the difference between the sequence length and the batch size in the paper.
-        // {n_embd, batch}
+        // {n_embd, n_tok}
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);

         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner);
+            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
             ggml_tensor * ssm_state  = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);

             // norm
@@ -7945,33 +7941,43 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);

-            // {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch}
+            // {n_embd, 2*d_inner} * {n_embd, n_tok} => {2*d_inner, n_tok}
             struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
-            // assuming it's contiguous
-            // {d_inner, batch}
+            // => {d_inner, n_tok}
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
             struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);

-            cur = x;
-
             // conv
             {
-                // shift conv state left
-                conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0);
-                // update last column
-                // x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
-                conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));
-
-                // rearrange and sum
-                // no need to rearrange the conv_state, since it's already in the right shape
-                // => {1, d_inner}
-                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
-                // => {d_inner, 1}
-                x = ggml_transpose(ctx0, x);
+                // concat last (d_conv - 1) columns of conv_state, and x
+                // The following tensor is too big in order to avoid an assertion error when making an overlapping view.
+                // TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation
+                // This could then be a tensor with ne[] = {(d_conv-1)+n_tok, d_inner}
+                // which is around (d_conv-1) times as small as its current size.
+                struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tok);
+                const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x);
+                conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
+                // unfortunately, making x contiguous is necessary because ggml_set expects nb0 == sizeof(float)
+                conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x));
+
+                // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
+                ggml_build_forward_expand(gf,
+                    ggml_cpy(ctx0,
+                        ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)),
+                        ggml_view_tensor(ctx0, kv_self.k_l[il])));
+
+                // prepare convolution for all tokens in the batch with a self-overlapping view
+                // {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
+                conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, 1*ggml_element_size(conv_x), 0);
+
+                // perform convolution
+                // => {1, d_inner, n_tok}
+                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
+                // => {d_inner, n_tok, 1}
+                x = ggml_permute(ctx0, x, 2, 0, 1, 3);

                 // bias
                 x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
@@ -7981,23 +7987,24 @@ struct llm_build_context {
             // ssm
             {
-                // {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
-                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
-                // FIXME: handle batches of more than 1 token
-                struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
-                struct ggml_tensor * B = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
-                struct ggml_tensor * C = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));
+                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
+                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+                // split
+                struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, x_db->ne[1], x_db->nb[1], 0);
+                struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+                struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));

-                // {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
-                dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
-                dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
+                // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
+                dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+                dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
                 dt = ggml_soft_plus(ctx0, dt);

+                // FIXME: support batches with more than 1 token
                 // => {d_state, d_inner}
-                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));
+                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
                 // => {d_state, d_inner}
-                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
+                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, dt);
                 // => {d_state, d_inner}
                 cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
@@ -8012,7 +8019,7 @@ struct llm_build_context {
                 y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));

-                // {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
+                // {d_inner, n_embd} * {d_inner, 1} => {n_embd, 1}
                 cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
             }
@@ -12327,8 +12334,15 @@ struct llama_context * llama_new_context_with_model(
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;

-    const ggml_type type_k = params.type_k;
-    const ggml_type type_v = params.type_v;
+    ggml_type type_k = params.type_k;
+    ggml_type type_v = params.type_v;
+
+    // Mamba (mis)uses the KV cache to store its states
+    if (model->arch == LLM_ARCH_MAMBA) {
+        // it's probably best to keep as much precision as possible for the states
+        type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
+        type_v = GGML_TYPE_F32; // required by ggml_mul for Mamba's ssm_state
+    }

     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);