mamba : recurrent inference WORKS!!!

This commit is contained in:
Francis Couture-Harpin 2024-01-28 16:20:03 -05:00
parent f680364bd8
commit 54d3e48601
2 changed files with 1 additions and 4 deletions

4
ggml.c
View file

@ -5336,8 +5336,6 @@ static struct ggml_tensor * ggml_soft_plus_impl(
struct ggml_tensor * a, struct ggml_tensor * a,
bool inplace) { bool inplace) {
// TODO: does `a` need to be contiguous?
bool is_node = false; bool is_node = false;
if (a->grad) { if (a->grad) {
@ -12190,7 +12188,7 @@ static void ggml_compute_forward_soft_plus_f32(
float * x = (float *) ((char *) dst->data + i*( dst->nb[1])); float * x = (float *) ((char *) dst->data + i*( dst->nb[1]));
float * y = (float *) ((char *) src0->data + i*(src0->nb[1])); float * y = (float *) ((char *) src0->data + i*(src0->nb[1]));
for (int j = 0; j < nc; ++j) { for (int j = 0; j < nc; ++j) {
x[j] = logf(1.0f + expf(y[i])); x[j] = logf(1.0f + expf(y[j]));
} }
} }
} }

View file

@ -7944,7 +7944,6 @@ struct llm_build_context {
cur = llm_build_norm(ctx0, inpL, hparams, cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL, model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il); LLM_NORM_RMS, cb, il);
// TODO: that's probably the wrong name.
cb(cur, "attn_norm", il); cb(cur, "attn_norm", il);
// {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner} // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}