ggml: rwkv_wkv: Avoid copying the state
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
5175375715
commit
846358d358
1 changed file with 4 additions and 3 deletions
|
@@ -16874,7 +16874,6 @@ static void ggml_compute_forward_rwkv_wkv_f32(
|
|||
float * r = (float *) dst->src[2]->data;
|
||||
float * time_faaaa = (float *) dst->src[3]->data;
|
||||
float * time_decay = (float *) dst->src[4]->data;
|
||||
memcpy(state, dst->src[5]->data, (C / H) * C * n_seqs * sizeof(float));
|
||||
|
||||
size_t t_stride = H * (C / H);
|
||||
|
||||
|
@@ -16887,7 +16886,9 @@ static void ggml_compute_forward_rwkv_wkv_f32(
|
|||
// recursive through each token
|
||||
for (size_t t = 0; t < T; t++) {
|
||||
size_t t_offset = t * t_stride;
|
||||
float * state_cur = state + (C / H) * C * (t / (T / n_seqs));
|
||||
size_t state_offset = (C / H) * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
|
@@ -16911,7 +16912,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
|
|||
|
||||
float v_val = v[t_h_j_offset];
|
||||
float kv_val = v_val * k_val;
|
||||
float prev_state_val = state_cur[h_2d_i_j_offset];
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
float temp_val = kv_val * time_faaaa_val + prev_state_val;
|
||||
dst_data[t_h_j_offset] += temp_val * r_val;
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue