From 0784a0cf2634b673bc0a47c1b08f3e8e9f630825 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Fri, 2 Aug 2024 13:58:34 +0800 Subject: [PATCH] RWKV v6 graph building Signed-off-by: Molly Sophia --- ggml/include/ggml.h | 10 ++ ggml/src/ggml.c | 149 ++++++++++++++++++++++- src/llama.cpp | 289 ++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 423 insertions(+), 25 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 8ea652dc8..39aff9e39 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -512,6 +512,7 @@ extern "C" { GGML_OP_WIN_UNPART, GGML_OP_GET_REL_POS, GGML_OP_ADD_REL_POS, + GGML_OP_RWKV_WKV, GGML_OP_UNARY, @@ -1896,6 +1897,15 @@ extern "C" { struct ggml_tensor * pw, struct ggml_tensor * ph); + GGML_API struct ggml_tensor * ggml_rwkv_wkv( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * r, + struct ggml_tensor * tf, + struct ggml_tensor * td, + struct ggml_tensor * state); + // custom operators typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f7d016dad..78f1aa3b2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -2835,6 +2835,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "WIN_UNPART", "GET_REL_POS", "ADD_REL_POS", + "RWKV_WKV", "UNARY", @@ -2853,7 +2854,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); +static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2927,6 +2928,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "win_unpart(x)", "get_rel_pos(x)", "add_rel_pos(x)", + "rwkv_wkv(x, k, v, r, tf, td, s)", "unary(x)", @@ -2945,7 +2947,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); +static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -7637,6 +7639,57 @@ struct ggml_tensor * ggml_add_rel_pos_inplace( return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); } +// ggml_rwkv_wkv + +struct ggml_tensor * ggml_rwkv_wkv( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * r, + struct ggml_tensor * tf, + struct ggml_tensor * td, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(r)); + GGML_ASSERT(ggml_is_contiguous(tf)); + GGML_ASSERT(ggml_is_contiguous(td)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S = k->ne[0]; + const int64_t H = k->ne[2]; + const int64_t n_tokens = k->ne[3]; + { + GGML_ASSERT(k->ne[1] == 1); + GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens); + GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens); + // TODO: RWKV v4 and v5 + GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens); + GGML_ASSERT(ggml_nelements(state) == S * S * H); + } + + bool is_node = false; + + if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) { + GGML_ABORT("fatal error"); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { S * H, n_tokens, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_RWKV_WKV; + 
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = k; + result->src[1] = v; + result->src[2] = r; + result->src[3] = tf; + result->src[4] = td; + result->src[5] = state; + + return result; +} + // ggml_unary static struct ggml_tensor * ggml_unary_impl( @@ -16795,6 +16848,92 @@ static void ggml_compute_forward_add_rel_pos( } } +// ggml_compute_forward_rwkv_wkv + +static void ggml_compute_forward_rwkv_wkv_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const size_t T = dst->ne[1]; + const size_t C = dst->ne[0]; + const size_t H = dst->src[1]->ne[2]; + + float * dst_data = (float *) dst->data; + + if (params->ith != 0) { + return; + } + + memset(dst_data, 0, T * C * sizeof(float)); + + float * k = (float *) dst->src[0]->data; + float * v = (float *) dst->src[1]->data; + float * r = (float *) dst->src[2]->data; + float * time_faaaa = (float *) dst->src[3]->data; + float * time_decay = (float *) dst->src[4]->data; + float * state = (float *) dst->src[5]->data; + + size_t t_stride = H * (C / H); + + size_t h_stride = C / H; + size_t h_stride_2d = (C / H) * (C / H); + + // basically fused operations: + // dst = r @ (time_faaaa * (k @ v) + state), + // state = time_decay * state + (k @ v), + // recursive through each token + for (size_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + + for (size_t h = 0; h < H; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (size_t i = 0; i < C / H; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float r_val = r[t_h_i_offset]; + float time_faaaa_val = time_faaaa[h_i_offset]; + // RWKV v6: different time_decay for each token. 
+ float time_decay_val = time_decay[t_h_i_offset]; + + for (size_t j = 0; j < C / H; j ++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state[h_2d_i_j_offset]; + float temp_val = kv_val * time_faaaa_val + prev_state_val; + dst_data[t_h_j_offset] += temp_val * r_val; + state[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + } + } + } + } +} + +static void ggml_compute_forward_rwkv_wkv( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rwkv_wkv_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_map_unary static void ggml_compute_forward_map_unary_f32( @@ -17446,6 +17585,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_add_rel_pos(params, tensor); } break; + case GGML_OP_RWKV_WKV: + { + ggml_compute_forward_rwkv_wkv(params, tensor); + } break; case GGML_OP_MAP_UNARY: { ggml_unary_op_f32_t fun; @@ -18569,6 +18712,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_GET_REL_POS: case GGML_OP_ADD_REL_POS: + case GGML_OP_RWKV_WKV: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: @@ -19143,6 +19287,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_GET_REL_POS: + case GGML_OP_RWKV_WKV: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: diff --git a/src/llama.cpp b/src/llama.cpp index 818e34776..c43776acd 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3512,7 +3512,7 @@ static bool llama_kv_cache_find_slot( const uint32_t n_seq_tokens = batch.n_seq_tokens; if (cache.recurrent) { - // For recurrent state architectures (like Mamba), + // For recurrent state architectures (like Mamba or RWKV), // each cache cell can store the state for a whole sequence. // A slot should be always be contiguous. 
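For readability, here is a minimal single-head, single-token reference of the recurrence that ggml_compute_forward_rwkv_wkv_f32 fuses earlier in this patch. It is only a reading aid under the layout used by the CPU kernel; the function and parameter names are illustrative and not part of the patch:

/*
 * Reference (single head, single token) form of the recurrence fused by
 * GGML_OP_RWKV_WKV:
 *
 *   out[j]  += r[i] * (u[i] * k[i] * v[j] + S[i][j])
 *   S[i][j]  = S[i][j] * w[i] + k[i] * v[j]
 *
 * out must be zero-initialized by the caller (the kernel memsets dst), and
 * S is the head_size x head_size wkv state carried across tokens.
 */
static void rwkv6_wkv_head_ref(int head_size,
                               const float * r,  // receptance for this token
                               const float * k,  // key for this token
                               const float * v,  // value for this token
                               const float * u,  // time_first, per channel
                               const float * w,  // time_decay, per channel (and per token in RWKV v6)
                               float       * S,  // wkv state, head_size * head_size
                               float       * out) { // output, head_size
    for (int i = 0; i < head_size; i++) {
        for (int j = 0; j < head_size; j++) {
            const float kv = k[i] * v[j];
            out[j] += r[i] * (u[i] * kv + S[i * head_size + j]);
            S[i * head_size + j] = S[i * head_size + j] * w[i] + kv;
        }
    }
}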
@@ -3761,7 +3761,7 @@ static bool llama_kv_cache_seq_rm( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - // models like Mamba can't have a state partially erased + // models like Mamba or RWKV can't have a state partially erased if (cache.recurrent) { if (seq_id >= (int64_t) cache.size) { // could be fatal @@ -3897,7 +3897,7 @@ static void llama_kv_cache_seq_add( if (p0 == p1) return; if (cache.recurrent) { - // for Mamba-like models, only the pos needs to be shifted + // for Mamba-like or RWKV models, only the pos needs to be shifted if (0 <= seq_id && seq_id < (int64_t) cache.size) { const int32_t tail_id = cache.cells[seq_id].tail; if (tail_id >= 0) { @@ -3946,7 +3946,7 @@ static void llama_kv_cache_seq_div( if (p0 == p1) return; if (cache.recurrent) { - // for Mamba-like models, only the pos needs to be changed + // for Mamba-like or RWKV models, only the pos needs to be changed if (0 <= seq_id && seq_id < (int64_t) cache.size) { const int32_t tail_id = cache.cells[seq_id].tail; if (tail_id >= 0) { @@ -5885,8 +5885,9 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); // TODO: Re-using mamba keys right now, but RWKV isn't state-space - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); } break; default: (void)0; } @@ -8323,7 +8324,7 @@ static bool llm_load_tensors( // Block 0, LN0 model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); // output model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); @@ -8348,8 +8349,8 @@ static bool llm_load_tensors( layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}); layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}); - layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {time_mix_extra_dim * 5, n_embd}); - layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_embd, time_mix_extra_dim, 5}); + layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}); + layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}); layer.time_mix_lerp_x = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}); layer.time_mix_lerp_w = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}); @@ -8361,8 +8362,8 @@ static bool llm_load_tensors( // TODO: Parametrize hardcoded dimensions for first & decay layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}); layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}); - layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {time_decay_extra_dim, n_embd}); - layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, 
tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {attn_hidden_size, time_decay_extra_dim}); + layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}); + layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}); layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}); layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}); layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}); @@ -9344,9 +9345,198 @@ static struct ggml_tensor * llm_build_time_mix( struct ggml_context * ctx, const struct llama_layer * layer, struct ggml_tensor * current, - int layer_i) { + struct ggml_tensor * x_prev, + struct ggml_tensor * wkv_state) { + size_t n_embed = current->ne[0]; + size_t n_tokens = current->ne[1]; + size_t head_size = layer->time_mix_first->ne[0]; + size_t head_count = layer->time_mix_first->ne[1]; - return current; + struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current); + struct ggml_tensor * xxx = ggml_add_inplace( + ctx, + ggml_mul(ctx, sx, layer->time_mix_lerp_x), + current + ); + + xxx = ggml_reshape_4d( + ctx, + ggml_tanh_inplace( + ctx, + ggml_mul_mat(ctx, layer->time_mix_w1, xxx) + ), + layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens + ); + + xxx = ggml_cont( + ctx, + ggml_permute(ctx, xxx, 0, 1, 3, 2) + ); + + xxx = ggml_mul_mat( + ctx, + ggml_reshape_4d( + ctx, + layer->time_mix_w2, + layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5 + ), + xxx + ); + + struct ggml_tensor *mw = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); + mw = ggml_reshape_2d( + ctx, + ggml_set_1d(ctx, mw, ggml_view_1d(ctx, xxx, n_embed * n_tokens, 0), 0), + n_embed, n_tokens + ); + + struct ggml_tensor *mk = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); + mk = ggml_reshape_2d( + ctx, + ggml_set_1d_inplace(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0), + n_embed, n_tokens + ); + + struct ggml_tensor *mv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); + mv = ggml_reshape_2d( + ctx, + ggml_set_1d_inplace(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0), + n_embed, n_tokens + ); + + struct ggml_tensor *mr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); + mr = ggml_reshape_2d( + ctx, + ggml_set_1d_inplace(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0), + n_embed, n_tokens + ); + + struct ggml_tensor *mg = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); + mg = ggml_reshape_2d( + ctx, + ggml_set_1d_inplace(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0), + n_embed, n_tokens + ); + + struct ggml_tensor * xw = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + ggml_add(ctx, mw, layer->time_mix_lerp_w), + sx + ), + current + ); + + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + ggml_add(ctx, mk, layer->time_mix_lerp_k), + sx + ), + current + ); + + struct ggml_tensor * xv = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + ggml_add(ctx, mv, layer->time_mix_lerp_v), + sx + ), + current + ); + + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + 
ggml_mul_inplace( + ctx, + ggml_add(ctx, mr, layer->time_mix_lerp_r), + sx + ), + current + ); + + struct ggml_tensor * xg = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + ggml_add(ctx, mg, layer->time_mix_lerp_g), + sx + ), + current + ); + + struct ggml_tensor * r = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens); + struct ggml_tensor * k = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens); + struct ggml_tensor * v = ggml_reshape_4d(ctx, ggml_mul_mat(ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens); + struct ggml_tensor * g = ggml_silu_inplace( + ctx, + ggml_mul_mat(ctx, layer->time_mix_gate, xg) + ); + + struct ggml_tensor * w = ggml_mul_mat( + ctx, + layer->time_mix_decay_w2, + ggml_tanh_inplace( + ctx, + ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw) + ) + ); + w = ggml_add_inplace( + ctx, + w, + ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed) + ); + w = ggml_exp(ctx, ggml_neg_inplace(ctx, ggml_exp(ctx, w))); + w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens); + + k = ggml_transpose(ctx, k); + v = ggml_transpose(ctx, v); + r = ggml_transpose(ctx, r); + current = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, wkv_state); + + // ggml_group_norm considers groups in the third dimension. + current = ggml_reshape_4d(ctx, current, 1, 1, n_embed, n_tokens); + current = ggml_group_norm(ctx, current, head_count, 64e-5f); + // Convert back to a regular vector. + current = ggml_reshape_2d(ctx, current, n_embed, n_tokens); + current = ggml_add_inplace( + ctx, + ggml_mul_inplace( + ctx, + current, + layer->time_mix_ln + ), + layer->time_mix_ln_b + ); + + current = ggml_mul(ctx, current, g); + + return ggml_mul_mat(ctx, layer->time_mix_output, current); +} + +static struct ggml_tensor * llm_build_channel_mix( + struct ggml_context * ctx, + const struct llama_layer * layer, + struct ggml_tensor * current, + struct ggml_tensor * x_prev) { + + struct ggml_tensor * sx = ggml_sub(ctx, x_prev, current); + struct ggml_tensor * xk = ggml_add_inplace( + ctx, + ggml_mul(ctx, sx, layer->channel_mix_lerp_k), + current + ); + struct ggml_tensor * xr = ggml_add_inplace( + ctx, + ggml_mul(ctx, sx, layer->channel_mix_lerp_r), + current + ); + struct ggml_tensor * r = ggml_sigmoid_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_receptance, xr)); + struct ggml_tensor * k = ggml_sqr_inplace(ctx, ggml_relu_inplace(ctx, ggml_mul_mat(ctx, layer->channel_mix_key, xk))); + return ggml_mul_inplace(ctx, r, ggml_mul_mat(ctx, layer->channel_mix_value, k)); } struct llm_build_context { @@ -14874,32 +15064,85 @@ struct llm_build_context { ggml_cgraph * build_rwkv() { ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + // Token shift state dimensions should be 2 * n_emb + GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2); + // Input embeddings, start of the model after tokenizing ({n_embd, n_tokens}) ggml_tensor * input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); - // x = self.layer_norm(x, self.w.blocks[0].ln0) - ggml_tensor * current = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); + ggml_tensor * x = llm_build_norm(ctx0, input_embeddings, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); for (int layer_i = 0; layer_i < n_layer; ++layer_i) { const llama_layer * layer = &model.layers[layer_i]; - current = llm_build_norm(ctx0, 
current, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i); - current = llm_build_time_mix(ctx0, layer, current, layer_i); + // TODO: handle multiple kv cache cells + struct ggml_tensor * wkv_state = ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type)); + struct ggml_tensor * att_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, (kv_self.size - 1) * 2 * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)); + struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, ((kv_self.size - 1) * 2 + 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type)); - current = llm_build_norm(ctx0, current, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i); + struct ggml_tensor * x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, layer_i); + struct ggml_tensor * x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + x_prev = ggml_set_1d(ctx0, x_prev, att_shift, 0); + x_prev = ggml_set_1d( + ctx0, + x_prev, + ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0), + n_embd * ggml_type_size(x_prev->type) + ); + + x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, wkv_state)); + ggml_build_forward_expand(gf, x); + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + ggml_view_1d( + ctx0, + x_norm, + n_embd, + (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type) + ), + att_shift + ) + ); + + x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i); + x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + x_prev = ggml_set_1d(ctx0, x_prev, ffn_shift, 0); + x_prev = ggml_set_1d( + ctx0, + x_prev, + ggml_view_1d(ctx0, x_norm, (n_tokens - 1) * n_embd, 0), + n_embd * ggml_type_size(x_prev->type) + ); + x = ggml_add(ctx0, x, llm_build_channel_mix(ctx0, layer, x_norm, x_prev)); + ggml_build_forward_expand(gf, x); + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + ggml_view_1d( + ctx0, + x_norm, + n_embd, + (n_tokens - 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type) + ), + ffn_shift + ) + ); } // Something related to skipping tokens, specifics unclear ggml_tensor * inp_out_ids = build_inp_out_ids(); - current = ggml_get_rows(ctx0, current, inp_out_ids); + x = ggml_get_rows(ctx0, x, inp_out_ids); // Output head, convert result vector to logits - current = llm_build_norm(ctx0, current, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1); - current = ggml_mul_mat(ctx0, model.output, current); + x = llm_build_norm(ctx0, x, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1); + x = ggml_mul_mat(ctx0, model.output, x); // Mark the output as being the result - cb(current, "result_output", -1); - ggml_build_forward_expand(gf, current); + cb(x, "result_output", -1); + ggml_build_forward_expand(gf, x); return gf; }
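As a shape summary of what the new operator expects (taken from the asserts in ggml_rwkv_wkv and from the result shape it allocates), a hedged usage sketch follows. S is the head size, H the number of heads, T the number of tokens; build_wkv_example is a hypothetical helper used only for illustration, and the 2-D shape of tf is inferred from the time_mix_first tensor created in llm_load_tensors rather than from an assert:

#include "ggml.h"

// Hypothetical helper: builds a ggml_rwkv_wkv node with the tensor shapes the
// operator asserts. S = head size, H = number of heads, T = number of tokens.
static struct ggml_tensor * build_wkv_example(struct ggml_context * ctx,
                                              int64_t S, int64_t H, int64_t T) {
    struct ggml_tensor * k  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, T);
    struct ggml_tensor * v  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
    struct ggml_tensor * r  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T);
    struct ggml_tensor * tf = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S, H);       // time_first
    struct ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, T); // time_decay, per token in v6
    struct ggml_tensor * st = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, S * S * H);  // carried wkv state

    // The result is F32 with shape {S * H, T}; the CPU kernel also updates the
    // state tensor in place while the graph is evaluated.
    return ggml_rwkv_wkv(ctx, k, v, r, tf, td, st);
}

These are the same shapes (in terms of ne) that llm_build_time_mix() arranges via reshape and transpose before calling the operator in the graph above.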