From 3f75f121142dc46b3184efd1e6a0d66e05997a72 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Sat, 2 Nov 2024 00:28:58 +1100 Subject: [PATCH] rwkv6: rename params --- ggml/src/ggml.c | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 16ecf8bf4..84f0e8201 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -16611,20 +16611,20 @@ static void ggml_compute_forward_rwkv_wkv6_f32( struct ggml_tensor * dst) { const size_t T = dst->src[1]->ne[3]; const size_t C = dst->ne[0]; - const size_t H = dst->src[1]->ne[2]; + const size_t HEADS = dst->src[1]->ne[2]; const size_t n_seqs = dst->src[5]->ne[1]; - const size_t head_size = C / H; + const size_t head_size = C / HEADS; float * dst_data = (float *) dst->data; float * state = ((float *) dst->data) + C * T; - if ((size_t)params->ith >= H) { + if ((size_t)params->ith >= HEADS) { return; } - size_t h_start = (H * params->ith) / params->nth; - size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ? - (H * (size_t)(params->ith + 1)) / (size_t)params->nth : H; + size_t h_start = (HEADS * params->ith) / params->nth; + size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ? + (HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -16632,9 +16632,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32( float * time_faaaa = (float *) dst->src[3]->data; float * time_decay = (float *) dst->src[4]->data; - size_t t_stride = H * head_size; + size_t t_stride = HEADS * head_size; - size_t h_stride = C / H; + size_t h_stride = C / HEADS; size_t h_stride_2d = head_size * head_size; if (params->ith == 0) {