diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 78f1aa3b2..17e92eff2 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -7675,7 +7675,8 @@ struct ggml_tensor * ggml_rwkv_wkv(
         is_node = true;
     }
 
-    const int64_t ne[4] = { S * H, n_tokens, 1, 1 };
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     result->op = GGML_OP_RWKV_WKV;
@@ -16853,11 +16854,12 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const size_t T = dst->ne[1];
+    const size_t T = dst->src[1]->ne[3];
     const size_t C = dst->ne[0];
     const size_t H = dst->src[1]->ne[2];
 
     float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
     if (params->ith != 0) {
         return;
     }
@@ -16870,7 +16872,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * r = (float *) dst->src[2]->data;
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
-    float * state = (float *) dst->src[5]->data;
+    memcpy(state, dst->src[5]->data, (C / H) * C * sizeof(float));
 
     size_t t_stride = H * (C / H);
 
diff --git a/src/llama.cpp b/src/llama.cpp
index c755e728f..5e474d61d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9366,7 +9366,7 @@ static struct ggml_tensor * llm_build_time_mix(
         const struct llama_layer * layer,
         struct ggml_tensor * current,
         struct ggml_tensor * x_prev,
-        struct ggml_tensor * wkv_state) {
+        struct ggml_tensor ** wkv_state) {
     size_t n_embed = current->ne[0];
     size_t n_tokens = current->ne[1];
     size_t head_size = layer->time_mix_first->ne[0];
@@ -9509,13 +9509,15 @@ static struct ggml_tensor * llm_build_time_mix(
         w,
         ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed)
     );
-    w = ggml_exp(ctx, ggml_neg_inplace(ctx, ggml_exp(ctx, w)));
+    w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
     w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
 
     k = ggml_transpose(ctx, k);
     v = ggml_transpose(ctx, v);
     r = ggml_transpose(ctx, r);
-    current = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, wkv_state);
+    struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    current = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
+    *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size, n_embed * n_tokens * sizeof(float));
 
     // ggml_group_norm considers groups in the third dimension.
     current = ggml_reshape_4d(ctx, current, 1, 1, n_embed, n_tokens);
@@ -15096,7 +15098,7 @@ struct llm_build_context {
             const llama_layer * layer = &model.layers[layer_i];
 
             // TODO: handle multiple kv cache cells
-            struct ggml_tensor * wkv_state = ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type));
+            struct ggml_tensor * wkv_state = ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type));
             struct ggml_tensor * att_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, (kv_self.size - 1) * 2 * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type));
             struct ggml_tensor * ffn_shift = ggml_view_1d(ctx0, kv_self.k_l[layer_i], n_embd, ((kv_self.size - 1) * 2 + 1) * n_embd * ggml_type_size(kv_self.k_l[layer_i]->type));
 
@@ -15110,7 +15112,7 @@ struct llm_build_context {
                 n_embd * ggml_type_size(x_prev->type)
             );
 
-            x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, wkv_state));
+            x = ggml_add(ctx0, x, llm_build_time_mix(ctx0, layer, x_norm, x_prev, &wkv_state));
             ggml_build_forward_expand(gf, x);
             ggml_build_forward_expand(
                 gf,
@@ -15125,6 +15127,14 @@ struct llm_build_context {
                     att_shift
                 )
             );
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_state,
+                    ggml_view_1d(ctx0, kv_self.v_l[layer_i], hparams.n_embd_v_s(), (kv_self.size - 1) * hparams.n_embd_v_s() * ggml_type_size(kv_self.k_l[layer_i]->type))
+                )
+            );
 
             x_norm = llm_build_norm(ctx0, x, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, layer_i);
             x_prev = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
@@ -15151,7 +15161,7 @@ struct llm_build_context {
                 )
             );
             if ((layer_i + 1) % hparams.rescale_every_n_layers == 0) {
-                x = ggml_scale_inplace(ctx0, x, 0.5F);
+                x = ggml_scale(ctx0, x, 0.5F);
            }
        }
 
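For readers following the layout change: ggml_rwkv_wkv now returns one F32 tensor whose first n_embd * n_tokens elements are the per-token output and whose trailing n_embd * head_size elements are the updated wkv state, which the caller splits with the two ggml_view_1d calls in the diff. The plain-C sketch below only illustrates that offset arithmetic; it is not part of the patch, and the sizes used are made up for demonstration.

// Illustrative sketch, not part of the patch: how the combined
// [output | new state] buffer produced by ggml_rwkv_wkv is split into the
// two views used in llm_build_time_mix. All sizes here are hypothetical.
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    const size_t head_size  = 64;                      // S
    const size_t head_count = 32;                      // H
    const size_t n_embd     = head_size * head_count;  // C = S * H
    const size_t n_tokens   = 5;                       // T

    // The kernel writes C * T output floats followed by C * S state floats
    // into one contiguous buffer (ne[0] = C, ne[1] = T + S).
    const size_t n_output = n_embd * n_tokens;
    const size_t n_state  = n_embd * head_size;
    float * wkv_output = calloc(n_output + n_state, sizeof(float));
    if (!wkv_output) {
        return 1;
    }

    // Equivalent of the two ggml_view_1d calls: same backing memory, different offsets.
    float * current   = wkv_output;            // n_output elements at byte offset 0
    float * new_state = wkv_output + n_output; // n_state elements at byte offset n_output * sizeof(float)

    printf("output view: %zu floats at byte offset 0\n", n_output);
    printf("state  view: %zu floats at byte offset %zu\n", n_state, n_output * sizeof(float));

    (void) current;
    (void) new_state;
    free(wkv_output);
    return 0;
}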