ggml: rwkv_wkv: Avoid copying the state

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2024-08-31 12:17:08 +08:00
parent 5175375715
commit 846358d358

View file

@ -16874,7 +16874,6 @@ static void ggml_compute_forward_rwkv_wkv_f32(
float * r = (float *) dst->src[2]->data;
float * time_faaaa = (float *) dst->src[3]->data;
float * time_decay = (float *) dst->src[4]->data;
memcpy(state, dst->src[5]->data, (C / H) * C * n_seqs * sizeof(float));
size_t t_stride = H * (C / H);
@ -16887,7 +16886,9 @@ static void ggml_compute_forward_rwkv_wkv_f32(
// recursive through each token
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
float * state_cur = state + (C / H) * C * (t / (T / n_seqs));
size_t state_offset = (C / H) * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = 0; h < H; h++) {
size_t h_offset = h * h_stride;
@ -16911,7 +16912,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_cur[h_2d_i_j_offset];
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;