diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 1cd168db2..3bc0513ca 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1970,6 +1970,228 @@ ggml_tensor * llama_context::build_mamba_layer(
 }
 
+// load the token shift state of layer il for the current ubatch from the KV cache
+// (kv_self.k_l[il]) as a (n_embd, token_shift_count, n_seqs) tensor
+ggml_tensor * llama_context::build_rwkv_token_shift_load(
+        ggml_context * ctx0,
+        ggml_cgraph * graph,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    const auto & hparams = model.hparams;
+
+    const auto token_shift_count = hparams.token_shift_count;
+
+    const auto & n_tokens = ubatch.n_tokens;
+    const int64_t n_seqs = ubatch.n_seqs;
+
+    struct ggml_tensor * token_shift_all = kv_self.k_l[il];
+
+    struct ggml_tensor * token_shift = build_copy_mask_state(
+        ctx0, graph, token_shift_all, state_copy, state_mask,
+        n_tokens, hparams.n_embd_k_s(), n_seqs, worst_case);
+
+    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
+
+    return token_shift;
+}
+
+
+// write the updated token shift state back to the KV cache slots starting at kv_head
+ggml_tensor * llama_context::build_rwkv_token_shift_store(
+        ggml_context * ctx0,
+        ggml_tensor * token_shift,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    const auto & hparams = model.hparams;
+
+    const auto token_shift_count = hparams.token_shift_count;
+    const auto n_embd = hparams.n_embd;
+
+    const auto & n_tokens = ubatch.n_tokens;
+    const int64_t n_seqs = ubatch.n_seqs;
+
+    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+
+    return ggml_cpy(
+        ctx0,
+        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
+        ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
+    );
+}
+
+
+ggml_tensor * llama_context::build_rwkv6_time_mix(
+        ggml_context * ctx0,
+        ggml_cgraph * graph,
+        ggml_tensor * cur,
+        ggml_tensor * x_prev,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        const llama_ubatch & ubatch,
+        int il,
+        bool worst_case) {
+    const auto & hparams = model.hparams;
+
+    const auto n_tokens = ubatch.n_tokens;
+    const auto n_seq_tokens = ubatch.n_seq_tokens;
+    const auto n_seqs = ubatch.n_seqs;
+    const auto n_embd = hparams.n_embd;
+    const auto head_size = hparams.wkv_head_size;
+    const auto n_head = n_embd / head_size;
+    const auto n_head_kv = hparams.n_head_kv(il);
+
+    const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+
+    const auto layer = &model.layers[il];
+
+    bool is_qrwkv = layer->time_mix_first == nullptr;
+
+    struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
+    struct ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->time_mix_lerp_x), cur);
+
+    // tanh(xxx @ w1) @ w2 produces the five data-dependent lerp offsets (for w, k, v, r, g)
+    xxx = ggml_reshape_4d(
+        ctx0,
+        ggml_tanh(
+            ctx0,
+            ggml_mul_mat(ctx0, layer->time_mix_w1, xxx)
+        ),
+        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+    );
+
+    xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));
+
+    xxx = ggml_mul_mat(
+        ctx0,
+        ggml_reshape_4d(
+            ctx0,
+            layer->time_mix_w2,
+            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
+        ),
+        xxx
+    );
+
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights yields a modest performance improvement
+        sx  = ggml_reshape_3d(ctx0, sx,  n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer->time_mix_lerp_fused), sx), cur);
+        xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    } else {
+        // for backward compatibility with models that store the five lerp biases separately
+        xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+
+        xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer->time_mix_lerp_g), sx), cur);
+    }
+
+    struct ggml_tensor * r = build_lora_mm(ctx0, layer->time_mix_receptance, xr);
+    struct ggml_tensor * k = build_lora_mm(ctx0, layer->time_mix_key, xk);
+    struct ggml_tensor * v = build_lora_mm(ctx0, layer->time_mix_value, xv);
+    if (layer->time_mix_receptance_b) {
+        r = ggml_add(ctx0, r, layer->time_mix_receptance_b);
+    }
+    if (layer->time_mix_key_b) {
+        k = ggml_add(ctx0, k, layer->time_mix_key_b);
+    }
+    if (layer->time_mix_value_b) {
+        v = ggml_add(ctx0, v, layer->time_mix_value_b);
+    }
+
+    struct ggml_tensor * g = build_lora_mm(ctx0, layer->time_mix_gate, xg);
+    if (is_qrwkv) {
+        g = ggml_sigmoid(ctx0, g);
+    } else {
+        g = ggml_silu(ctx0, g);
+    }
+
+    if (n_head_kv != 0 && n_head_kv != n_head) {
+        GGML_ASSERT(n_head % n_head_kv == 0);
+        k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
+        v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
+        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
+        k = ggml_repeat(ctx0, k, tmp);
+        v = ggml_repeat(ctx0, v, tmp);
+    }
+
+    k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
+    v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
+    r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);
+
+    struct ggml_tensor * w = ggml_mul_mat(
+        ctx0,
+        layer->time_mix_decay_w2,
+        ggml_tanh(
+            ctx0,
+            ggml_mul_mat(ctx0, layer->time_mix_decay_w1, xw)
+        )
+    );
+
+    w = ggml_add(ctx0, w, layer->time_mix_decay);
+    // map the decay logits into (0, 1): w = exp(-exp(w))
+    w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
+    w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);
+
+    if (is_qrwkv) {
+        // k = k * (1 - w)
+        k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
+    }
+
+    struct ggml_tensor * wkv_state = build_copy_mask_state(
+        ctx0, graph, kv_self.v_l[il], state_copy, state_mask,
+        n_tokens, hparams.n_embd_v_s(), n_seqs, worst_case);
+
+    struct ggml_tensor * wkv_output;
+    if (is_qrwkv) {
+        wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
+    } else {
+        wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer->time_mix_first, w, wkv_state);
+    }
+    cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
+    wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
+
+    ggml_build_forward_expand(
+        graph,
+        ggml_cpy(
+            ctx0,
+            wkv_state,
+            ggml_view_1d(
+                ctx0,
+                kv_self.v_l[il],
+                hparams.n_embd_v_s() * n_seqs,
+                hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+            )
+        )
+    );
+
+    if (!is_qrwkv) {
+        // group norm with head_count groups
+        cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
+        cur = ggml_norm(ctx0, cur, 64e-5f);
+
+        // convert back to regular vectors
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+    } else {
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+    }
+
+    cur = ggml_mul(ctx0, cur, g);
+    cur = build_lora_mm(ctx0, layer->time_mix_output, cur);
+
+    // restore the (n_embd, n_seq_tokens, n_seqs) layout the graph builders operate on,
+    // as the llm_build_context implementation this replaces did
+    return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs);
+}
+
 // llama output
 
 size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
diff --git a/src/llama-context.h b/src/llama-context.h
index 5958deaef..4cf4a6312 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -248,6 +248,33 @@ struct llama_context {
             int il,
             bool worst_case);
 
+    ggml_tensor * build_rwkv_token_shift_load(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case);
+
+    ggml_tensor * build_rwkv_token_shift_store(
+            ggml_context * ctx0,
+            ggml_tensor * token_shift,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case);
+
+    ggml_tensor * build_rwkv6_time_mix(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case);
+
     struct ggml_tensor * inp_s_copy;  // I32 [kv_size]
     struct ggml_tensor * inp_s_mask;  // F32 [1, n_kv]
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 64a5efd2d..171ea2017 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -574,175 +574,34 @@ struct llm_build_context {
         return cur;
     }
 
-    //struct ggml_tensor * build_rwkv6_time_mix(
-    //    const struct llama_layer * layer,
-    //    struct ggml_tensor * cur,
-    //    struct ggml_tensor * x_prev,
-    //    struct ggml_tensor ** wkv_state,
-    //    size_t wkv_head_size,
-    //    size_t head_count_kv) {
-    //    size_t n_embd = cur->ne[0];
-    //    size_t n_seq_tokens = cur->ne[1];
-    //    size_t n_seqs = cur->ne[2];
+    struct ggml_tensor * build_rwkv_channel_mix( 
const struct llama_layer * layer, + struct ggml_tensor * cur, + struct ggml_tensor * x_prev, + const llm_arch arch) { + struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + switch (arch) { + case LLM_ARCH_RWKV6: + { + struct ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); + struct ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur); - // size_t head_size = wkv_head_size; - // size_t head_count = n_embd / head_size; + struct ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr)); + struct ggml_tensor * k = ggml_sqr( + ctx0, + ggml_relu( + ctx0, + build_lora_mm(layer->channel_mix_key, xk) + ) + ); + cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k)); + } break; + default: + GGML_ABORT("fatal error"); + } - // size_t n_tokens = n_seqs * n_seq_tokens; - - // bool is_qrwkv = layer->time_mix_first == nullptr; - - // struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); - - // sx = ggml_reshape_2d(ctx0, sx, n_embd, n_tokens); - // cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); - - // struct ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->time_mix_lerp_x), cur); - - // xxx = ggml_reshape_4d( - // ctx0, - // ggml_tanh( - // ctx0, - // ggml_mul_mat(ctx0, layer->time_mix_w1, xxx) - // ), - // layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens - // ); - - // xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2)); - - // xxx = ggml_mul_mat( - // ctx0, - // ggml_reshape_4d( - // ctx0, - // layer->time_mix_w2, - // layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5 - // ), - // xxx - // ); - - // struct ggml_tensor *xw, *xk, *xv, *xr, *xg; - // if (layer->time_mix_lerp_fused) { - // // fusing these weights makes some performance improvement - // sx = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens); - // cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); - // xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer->time_mix_lerp_fused), sx), cur); - // xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); - // xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); - // xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); - // xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); - // xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); - // } else { - // // for backward compatibility - // xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0); - // xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float)); - // xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float)); - // xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float)); - // xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float)); - - // xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer->time_mix_lerp_w), sx), cur); - // xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer->time_mix_lerp_k), sx), cur); - // xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer->time_mix_lerp_v), sx), cur); - // xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer->time_mix_lerp_r), sx), cur); - // xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer->time_mix_lerp_g), sx), cur); - // } - - // struct ggml_tensor * r = 
build_lora_mm(layer->time_mix_receptance, xr); - // struct ggml_tensor * k = build_lora_mm(layer->time_mix_key, xk); - // struct ggml_tensor * v = build_lora_mm(layer->time_mix_value, xv); - // if (layer->time_mix_receptance_b) { - // r = ggml_add(ctx0, r, layer->time_mix_receptance_b); - // } - // if (layer->time_mix_key_b) { - // k = ggml_add(ctx0, k, layer->time_mix_key_b); - // } - // if (layer->time_mix_value_b) { - // v = ggml_add(ctx0, v, layer->time_mix_value_b); - // } - - // struct ggml_tensor * g = build_lora_mm(layer->time_mix_gate, xg); - // if (is_qrwkv) { - // g = ggml_sigmoid(ctx0, g); - // } else { - // g = ggml_silu(ctx0, g); - // } - - // if (head_count_kv != head_count) { - // GGML_ASSERT(head_count % head_count_kv == 0); - // k = ggml_reshape_4d(ctx0, k, head_size, 1, head_count_kv, n_tokens); - // v = ggml_reshape_4d(ctx0, v, head_size, 1, head_count_kv, n_tokens); - // struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens); - // k = ggml_repeat(ctx0, k, tmp); - // v = ggml_repeat(ctx0, v, tmp); - // } - - // k = ggml_reshape_3d(ctx0, k, head_size, head_count, n_tokens); - // v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); - // r = ggml_reshape_3d(ctx0, r, head_size, head_count, n_tokens); - - // struct ggml_tensor * w = ggml_mul_mat( - // ctx0, - // layer->time_mix_decay_w2, - // ggml_tanh( - // ctx0, - // ggml_mul_mat(ctx0, layer->time_mix_decay_w1, xw) - // ) - // ); - - // w = ggml_add(ctx0, w, layer->time_mix_decay); - // w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w))); - // w = ggml_reshape_3d(ctx0, w, head_size, head_count, n_tokens); - - // if (is_qrwkv) { - // // k = k * (1 - w) - // k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); - // } - - // struct ggml_tensor * wkv_output; - // if (!layer->time_mix_first) { - // wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, *wkv_state, pow(head_size, -0.5f)); - // } else { - // wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer->time_mix_first, w, *wkv_state); - // } - // cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0); - // *wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float)); - - // if (!is_qrwkv) { - // // group norm with head_count groups - // cur = ggml_reshape_3d(ctx0, cur, n_embd / head_count, head_count, n_tokens); - // cur = ggml_norm(ctx0, cur, 64e-5f); - - // // Convert back to regular vectors. 
- // cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); - // cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer->time_mix_ln), layer->time_mix_ln_b); - // } else { - // cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); - // } - - // cur = ggml_mul(ctx0, cur, g); - // cur = build_lora_mm(layer->time_mix_output, cur); - - // return ggml_reshape_3d(ctx0, cur, n_embd, n_seq_tokens, n_seqs); - //} - - //struct ggml_tensor * build_rwkv6_channel_mix( - // const struct llama_layer * layer, - // struct ggml_tensor * cur, - // struct ggml_tensor * x_prev) { - // struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); - // struct ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); - // struct ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur); - - // struct ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr)); - // struct ggml_tensor * k = ggml_sqr( - // ctx0, - // ggml_relu( - // ctx0, - // build_lora_mm(layer->channel_mix_key, xk) - // ) - // ); - - // return ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k)); - //} + return cur; + } struct ggml_cgraph * build_k_shift() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); @@ -6935,226 +6794,178 @@ struct llm_build_context { return gf; } - //ggml_cgraph * build_rwkv6() { - // struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + ggml_cgraph * build_rwkv6() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // // Token shift state dimensions should be 2 * n_emb - // GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2); + GGML_ASSERT(hparams.token_shift_count == 2); - // const int64_t n_seqs = ubatch.n_seqs; - // const int64_t n_seq_tokens = ubatch.n_seq_tokens; - // const int64_t n_tokens = ubatch.n_tokens; - // GGML_ASSERT(n_seqs != 0); - // GGML_ASSERT(ubatch.equal_seqs); - // GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs); + struct ggml_tensor * cur; + struct ggml_tensor * inpL; - // struct ggml_tensor * cur; - // struct ggml_tensor * inpL; - // struct ggml_tensor * state_copy = build_inp_s_copy(); - // struct ggml_tensor * state_mask = build_inp_s_mask(); + inpL = build_inp_embd(model.tok_embd); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - // inpL = build_inp_embd(model.tok_embd); - // inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + struct ggml_tensor * state_copy = lctx.build_inp_s_copy(ctx0, worst_case); + struct ggml_tensor * state_mask = lctx.build_inp_s_mask(ctx0, worst_case); - // for (int il = 0; il < n_layer; ++il) { - // const llama_layer * layer = &model.layers[il]; + const auto n_embd = hparams.n_embd; + const auto n_seq_tokens = ubatch.n_seq_tokens; + const auto n_seqs = ubatch.n_seqs; - // // (ab)using the KV cache to store the states - // struct ggml_tensor * token_shift = build_copy_mask_state( - // gf, kv_self.k_l[il], state_copy, state_mask, - // hparams.n_embd_k_s(), n_seqs); + for (int il = 0; il < n_layer; ++il) { + const llama_layer * layer = &model.layers[il]; - // struct ggml_tensor * wkv_states = build_copy_mask_state( - // gf, kv_self.v_l[il], state_copy, state_mask, - // hparams.n_embd_v_s(), n_seqs); + struct ggml_tensor * token_shift = lctx.build_rwkv_token_shift_load( + ctx0, gf, state_copy, state_mask, ubatch, il, worst_case + ); - // cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - // token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 
2, n_seqs);
+            struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+            struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));

-            // struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
-            // struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));

+            // keep the (n_embd, n_seq_tokens, n_seqs) layout so the per-sequence
+            // last-token views below are correct, as in the builder this replaces
+            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+
+            struct ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
+            cb(att_norm, "attn_norm", il);

-            // struct ggml_tensor * x_norm_att = build_norm(cur, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
-            // struct ggml_tensor * x_prev = ggml_concat(
-            //     ctx0,
-            //     att_shift,
-            //     ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
-            //     1
-            // );
+            struct ggml_tensor * x_prev = ggml_concat(
+                ctx0,
+                att_shift,
+                ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
+                1
+            );

-            // cur = ggml_add(ctx0, cur, build_rwkv6_time_mix(layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
-            // ggml_build_forward_expand(gf, cur);
-            // ggml_build_forward_expand(
-            //     gf,
-            //     ggml_cpy(
-            //         ctx0,
-            //         wkv_states,
-            //         ggml_view_1d(
-            //             ctx0,
-            //             kv_self.v_l[il],
-            //             hparams.n_embd_v_s() * n_seqs,
-            //             hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
-            //         )
-            //     )
-            // );
+            cur = lctx.build_rwkv6_time_mix(ctx0, gf, att_norm, x_prev, state_copy, state_mask, ubatch, il, worst_case);

-            // struct ggml_tensor * x_norm_ffn = build_norm(cur, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
-            // x_prev = ggml_concat(
-            //     ctx0,
-            //     ffn_shift,
-            //     ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
-            //     1
-            // );
-            // cur = ggml_add(ctx0, cur, build_rwkv6_channel_mix(layer, x_norm_ffn, x_prev));
-            // ggml_build_forward_expand(gf, cur);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);

-            // struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
-            // struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
+            struct ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
+            cb(ffn_norm, "ffn_norm", il);

-            // token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
+            x_prev = ggml_concat(
+                ctx0,
+                ffn_shift,
+                ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0),
+                1
+            );

-            // ggml_build_forward_expand(
-            //     gf,
-            //     ggml_cpy(
-            //         ctx0,
-            //         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
-            //         ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
-            //     )
-            // );
+            cur = build_rwkv_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
+            cur = ggml_add(ctx0, cur, ffn_inp);

-            // if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
-            //     cur = ggml_scale(ctx0, cur, 0.5F);
-            // }
+            token_shift = ggml_concat(ctx0,
+                ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
+                ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
+                1
+            );
+            ggml_build_forward_expand(gf, lctx.build_rwkv_token_shift_store(ctx0, token_shift, ubatch, il, worst_case));

-            // cur = lctx.cvec.apply_to(ctx0, cur, il);
-            // cb(cur, "l_out", il);
+            if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
+                cur = ggml_scale(ctx0, cur, 0.5F);
+            }

-            // // input for next layer
-            // inpL = cur;
-        // }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);

-        // cur = inpL;
-        // struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        // cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        // cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            // input for next layer
+            inpL = cur;
+        }

-        // cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
-        // cb(cur, "result_norm", -1);
+        cur = inpL;
+        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);

-        // cur = build_lora_mm(model.output, cur);
-        // cb(cur, "result_output", -1);
+        cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
+        cb(cur, "result_norm", -1);

-        // ggml_build_forward_expand(gf, cur);
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);

-        // return gf;
-    //}
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }

     // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
-    //ggml_cgraph * build_rwkv6qwen2() {
-    //    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+    ggml_cgraph * build_rwkv6qwen2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);

-    //    GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s());

-    //    const int64_t n_seqs = ubatch.n_seqs;
-    //    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-    //    const int64_t n_tokens = ubatch.n_tokens;
-    //    GGML_ASSERT(n_seqs != 0);
-    //    GGML_ASSERT(ubatch.equal_seqs);
-    //    GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;

-    //    struct ggml_tensor * cur;
-    //    struct ggml_tensor * inpL;
-    //    struct ggml_tensor * state_copy = build_inp_s_copy();
-    //    struct ggml_tensor * state_mask = build_inp_s_mask();
+        inpL = build_inp_embd(model.tok_embd);

-    //    inpL = build_inp_embd(model.tok_embd);
+        struct ggml_tensor * state_copy = lctx.build_inp_s_copy(ctx0, worst_case);
+        struct ggml_tensor * state_mask = lctx.build_inp_s_mask(ctx0, worst_case);

-    //    for (int il = 0; il < n_layer; ++il) {
-    //        const llama_layer * layer = &model.layers[il];
+        const auto n_embd = hparams.n_embd;
+        const auto n_seq_tokens = ubatch.n_seq_tokens;
+        const auto n_seqs = ubatch.n_seqs;

-    //        // (ab)using the KV cache to store the states
-    //        struct ggml_tensor * token_shift = build_copy_mask_state(
-    //            gf, kv_self.k_l[il], state_copy, state_mask,
-    //            hparams.n_embd_k_s(), n_seqs);

-    //        struct ggml_tensor * wkv_states = build_copy_mask_state(
-    //            gf, kv_self.v_l[il], state_copy, state_mask,
-    //            hparams.n_embd_v_s(), n_seqs);
+        for (int il = 0; il < n_layer; ++il) {
+            // keep the (n_embd, n_seq_tokens, n_seqs) layout here as well
+            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            const 
llama_layer * layer = &model.layers[il]; - // cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - // token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs); + struct ggml_tensor * token_shift = lctx.build_rwkv_token_shift_load( + ctx0, gf, state_copy, state_mask, ubatch, il, worst_case + ); - // struct ggml_tensor * x_norm_att = build_norm(cur, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); - // struct ggml_tensor * x_prev = ggml_concat( - // ctx0, - // token_shift, - // ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0), - // 1 - // ); + struct ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); + cb(att_norm, "attn_norm", il); - // struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att)); - // ggml_build_forward_expand( - // gf, - // ggml_cpy( - // ctx0, - // ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0), - // ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il])) + struct ggml_tensor * x_prev = ggml_concat( + ctx0, + token_shift, + ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0), + 1 + ); - // struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, build_rwkv6_time_mix(layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv())); - // ggml_build_forward_expand(gf, ffn_inp); - // ggml_build_forward_expand( - // gf, - // ggml_cpy( - // ctx0, - // wkv_states, - // ggml_view_1d( - // ctx0, - // kv_self.v_l[il], - // hparams.n_embd_v_s() * n_seqs, - // hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il]) - // ) - // ) - // ); + cur = lctx.build_rwkv6_time_mix(ctx0, gf, att_norm, x_prev, state_copy, state_mask, ubatch, il, worst_case); - // cb(ffn_inp, "ffn_inp", il); + token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); + ggml_build_forward_expand(gf, lctx.build_rwkv_token_shift_store(ctx0, token_shift, ubatch, il, worst_case)); - // // feed-forward network - // cur = build_norm(ffn_inp, - // model.layers[il].ffn_norm, NULL, - // LLM_NORM_RMS, il); - // cb(cur, "ffn_norm", il); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); - // cur = build_ffn(cur, - // model.layers[il].ffn_up, NULL, NULL, - // model.layers[il].ffn_gate, NULL, NULL, - // model.layers[il].ffn_down, NULL, NULL, - // NULL, - // LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - // cb(cur, "ffn_out", il); + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); - // cur = ggml_add(ctx0, cur, ffn_inp); - // cur = lctx.cvec.apply_to(ctx0, cur, il); - // cb(cur, "l_out", il); + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); - // // input for next layer - // inpL = cur; - // } + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); - // cur = inpL; - // struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - // cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); - // cur = 
ggml_get_rows(ctx0, cur, inp_out_ids); + // input for next layer + inpL = cur; + } - // cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1); - // cb(cur, "result_norm", -1); + cur = inpL; + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); - // cur = build_lora_mm(model.output, cur); - // cb(cur, "result_output", -1); + cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); - // ggml_build_forward_expand(gf, cur); + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); - // return gf; - //} + ggml_build_forward_expand(gf, cur); + + return gf; + } // ref: https://github.com/facebookresearch/chameleon // based on the original build_llama() function, changes: @@ -7726,14 +7537,14 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_exaone(); } break; - //case LLM_ARCH_RWKV6: - // { - // result = llm.build_rwkv6(); - // } break; - //case LLM_ARCH_RWKV6QWEN2: - // { - // result = llm.build_rwkv6qwen2(); - // } break; + case LLM_ARCH_RWKV6: + { + result = llm.build_rwkv6(); + } break; + case LLM_ARCH_RWKV6QWEN2: + { + result = llm.build_rwkv6qwen2(); + } break; case LLM_ARCH_CHAMELEON: { result = llm.build_chameleon();
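
Reviewer note, outside the patch itself: every helper moved into llama_context above assumes one fixed-size state slot per sequence in the recurrent KV cache. The sketch below restates the offset arithmetic behind the ggml_view_1d calls; the struct and function names are hypothetical, and only the size formulas come from the patch (n_embd_k_s() == token_shift_count * n_embd, n_embd_v_s() == wkv_head_size * n_embd):

    #include <cstddef>
    #include <cstdint>

    // hypothetical illustration of the per-sequence slot sizes used by
    // build_rwkv_token_shift_load/store (kv_self.k_l[il]) and by the WKV
    // state handling in build_rwkv6_time_mix (kv_self.v_l[il])
    struct rwkv_slot_layout {
        size_t shift_bytes; // token shift state per sequence in k_l[il]
        size_t wkv_bytes;   // WKV matrix state per sequence in v_l[il]
    };

    static rwkv_slot_layout rwkv_slot_layout_for(
            int64_t n_embd,            // hparams.n_embd
            int64_t token_shift_count, // 2 for RWKV6 (att + ffn shift), 1 for RWKV6Qwen2
            int64_t head_size,         // hparams.wkv_head_size
            size_t  elt_size) {        // ggml_element_size(...) of the cache tensor
        return {
            (size_t)(token_shift_count * n_embd) * elt_size,
            // one head_size x head_size matrix per head, n_head == n_embd / head_size
            (size_t)(head_size * n_embd) * elt_size,
        };
    }

    // byte offset of the first slot written for this ubatch, matching the
    // ggml_view_1d destinations above (both are scaled by kv_head)
    static size_t rwkv_shift_offset(const rwkv_slot_layout & l, int64_t kv_head) {
        return l.shift_bytes * (size_t) kv_head;
    }
    static size_t rwkv_wkv_offset(const rwkv_slot_layout & l, int64_t kv_head) {
        return l.wkv_bytes * (size_t) kv_head;
    }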