diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 748e18d17..0b8cbaeb4 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -513,7 +513,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
-        GGML_OP_GATED_LINEAR_ATTENTION,
+        GGML_OP_GATED_LINEAR_ATTN,
 
         GGML_OP_UNARY,
 
@@ -1876,11 +1876,12 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_gated_linear_attn(
             struct ggml_context * ctx,
-            struct ggml_tensor  * q,
             struct ggml_tensor  * k,
             struct ggml_tensor  * v,
+            struct ggml_tensor  * q,
             struct ggml_tensor  * g,
-            struct ggml_tensor  * state);
+            struct ggml_tensor  * state,
+            float scale);
 
     // custom operators
 
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 30a6d333b..ada6d37d1 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same as C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load GLA_VECTOR_SIZE elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements (head_size is a multiple of GLA_VECTOR_SIZE, so this loop is not expected to run)
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c180adc84..e82b535ae 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -37,6 +37,7 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cuda_op_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -3010,6 +3014,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
diff --git a/ggml/src/ggml-cuda/gla.cu b/ggml/src/ggml-cuda/gla.cu
new file mode 100644
index 000000000..c18d62858
--- /dev/null
+++ b/ggml/src/ggml-cuda/gla.cu
@@ -0,0 +1,92 @@
+#include "common.cuh"
+#include "gla.cuh"
+
+template<int HEAD_SIZE>
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+                                             const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = HEAD_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+
+            y += r.x * s.x;
+            y += r.y * s.y;
+            y += r.z * s.z;
+            y += r.w * s.w;
+        }
+        dst[t] = y * scale;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * td_d = (const float *)dst->src[3]->data;
+    const float * s_d  = (const float *)dst->src[4]->data;
+
+    const int64_t B = dst->src[4]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    const float scale = ((const float*)(dst->op_params))[0];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == 64 || C / H == 128);
+
+
+    if (C / H == 64) {
+        gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    } else {
+        gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    }
+}
diff --git a/ggml/src/ggml-cuda/gla.cuh b/ggml/src/ggml-cuda/gla.cuh
new file mode 100644
index 000000000..2c82ad7dd
--- /dev/null
+++ b/ggml/src/ggml-cuda/gla.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 739c94166..2485203ef 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1065,7 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "get_rel_pos(x)",
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
-    "gated_linear_attn(k, v, q, decay, s)",
+    "gated_linear_attn(k, v, q, gate, s)",
 
     "unary(x)",
 
@@ -4667,6 +4667,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     return result;
 }
 
+// ggml_gated_linear_attn
+
+struct ggml_tensor * ggml_gated_linear_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * state,
+        float scale) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_f32(result, 0, scale);
+
+    result->op     = GGML_OP_GATED_LINEAR_ATTN;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = q;
+    result->src[3] = g;
+    result->src[4] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ee6d24391..85c93e5ac 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2021,6 +2021,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
         case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
         case LLM_ARCH_WAVTOKENIZER_DEC:
             return LLAMA_ROPE_TYPE_NONE;
 
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 42974f8f1..502499152 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -622,6 +622,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
     // sanity checks
+    if (!llama_model_is_recurrent(&model)) {
         const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
         // attention layers have a non-zero number of kv heads
diff --git a/src/llama.cpp b/src/llama.cpp
index 35230d1ce..924208930 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3296,13 +3296,14 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         struct ggml_tensor * cur,
         struct ggml_tensor * x_prev,
         struct ggml_tensor ** wkv_state,
+        size_t wkv_head_size,
         size_t head_count_kv) {
     size_t n_embd       = cur->ne[0];
     size_t n_seq_tokens = cur->ne[1];
     size_t n_seqs       = cur->ne[2];
-    size_t head_size  = layer->time_mix_first->ne[0];
-    size_t head_count = layer->time_mix_first->ne[1];
+    size_t head_size  = wkv_head_size;
+    size_t head_count = n_embd / head_size;
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
@@ -3336,11 +3337,21 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         xxx
     );
 
-    struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-    struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-    struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-    struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-    struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    struct ggml_tensor *mw, *mk, *mv, *mr, *mg;
+    if (is_qrwkv) {
+        // QRWKV6 orders the mixing vectors r, k, v, w, g instead of w, k, v, r, g
+        mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    } else {
+        mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    }
 
     struct ggml_tensor * xw = ggml_add(
         ctx,
@@ -3404,20 +3415,27 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
     if (layer->time_mix_value_b) {
         v = ggml_add(ctx, v, layer->time_mix_value_b);
     }
-    r = ggml_reshape_3d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, head_count, n_tokens);
-    k = ggml_reshape_3d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), head_size, head_count_kv, n_tokens);
-    v = ggml_reshape_3d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, head_count_kv, n_tokens);
-    struct ggml_tensor * g = ggml_silu(
-        ctx,
-        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
-    );
+
+    struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
+    if (is_qrwkv) {
+        g = ggml_sigmoid(ctx, g);
+    } else {
+        g = ggml_silu(ctx, g);
+    }
 
     if (head_count_kv != head_count) {
         GGML_ASSERT(head_count % head_count_kv == 0);
-        k = ggml_repeat(ctx, k, r);
-        v = ggml_repeat(ctx, v, r);
+        k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
+        v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
+        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
+        k = ggml_repeat(ctx, k, tmp);
+        v = ggml_repeat(ctx, v, tmp);
     }
+
+    k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
+    v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
+    r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
+
     struct ggml_tensor * w = ggml_mul_mat(
         ctx,
         layer->time_mix_decay_w2,
@@ -3438,7 +3456,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
 
     struct ggml_tensor * wkv_output;
     if (!layer->time_mix_first) {
-        // TODO: GatedLinearAttention
+        wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
     } else {
         wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
     }
@@ -9873,7 +9891,7 @@ struct llm_build_context {
                 1
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, layer->time_mix_first->ne[1]));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -9996,7 +10014,22 @@ struct llm_build_context {
                 )
             );
 
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.n_head_kv()));
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
+            ggml_build_forward_expand(gf, ffn_inp);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
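
Reference sketch (not part of the patch above): the recurrence that both the new CPU kernel (ggml_compute_forward_gla_f32) and the CUDA kernel (gated_linear_attn_f32) implement can be written as a plain scalar loop over one head of one sequence. The function below is illustrative only; its name and flat [T][S] layout are made up for this example, while the real op additionally packs all heads and sequences into one dst tensor and carries the recurrent state through src[4] and the tail of dst.

#include <stddef.h>

// state[i][j] = state[i][j] * g_t[i] + k_t[i] * v_t[j]
// out_t[j]   += scale * q_t[i] * state[i][j]
static void gla_ref_one_head(size_t S, size_t T, float scale,
                             const float * k,     // [T][S] keys
                             const float * v,     // [T][S] values
                             const float * q,     // [T][S] queries (receptance)
                             const float * g,     // [T][S] per-channel decay gates
                             float       * state, // [S][S] recurrent state, carried across tokens
                             float       * out) { // [T][S] output, zero-initialized by the caller
    for (size_t t = 0; t < T; t++) {
        for (size_t i = 0; i < S; i++) {
            const float k_i = k[t*S + i];
            const float q_i = q[t*S + i] * scale;
            const float g_i = g[t*S + i];
            for (size_t j = 0; j < S; j++) {
                // decay the previous state, add the outer-product update, accumulate the output
                const float s_new = state[i*S + j] * g_i + k_i * v[t*S + j];
                state[i*S + j] = s_new;
                out[t*S + j]  += s_new * q_i;
            }
        }
    }
}

This also makes the result layout chosen in ggml_gated_linear_attn visible: the op concatenates the per-token output with the updated states, which is why the result is allocated with ne[1] = n_tokens + S * n_seqs and why both kernels write the final state past the first T * C floats of dst.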