diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 93f3933e7..faf15170c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -16874,7 +16874,6 @@ static void ggml_compute_forward_rwkv_wkv_f32( float * r = (float *) dst->src[2]->data; float * time_faaaa = (float *) dst->src[3]->data; float * time_decay = (float *) dst->src[4]->data; - memcpy(state, dst->src[5]->data, (C / H) * C * n_seqs * sizeof(float)); size_t t_stride = H * (C / H); @@ -16887,7 +16886,9 @@ static void ggml_compute_forward_rwkv_wkv_f32( // recursive through each token for (size_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; - float * state_cur = state + (C / H) * C * (t / (T / n_seqs)); + size_t state_offset = (C / H) * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; for (size_t h = 0; h < H; h++) { size_t h_offset = h * h_stride; @@ -16911,7 +16912,7 @@ static void ggml_compute_forward_rwkv_wkv_f32( float v_val = v[t_h_j_offset]; float kv_val = v_val * k_val; - float prev_state_val = state_cur[h_2d_i_j_offset]; + float prev_state_val = state_prev[h_2d_i_j_offset]; float temp_val = kv_val * time_faaaa_val + prev_state_val; dst_data[t_h_j_offset] += temp_val * r_val; state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;