diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c index 3b630b7b4..88eb5ffc0 100644 --- a/ggml/src/ggml-cpu.c +++ b/ggml/src/ggml-cpu.c @@ -11655,13 +11655,16 @@ static void ggml_compute_forward_rwkv_wkv6_f32( float * dst_data = (float *) dst->data; float * state = ((float *) dst->data) + C * T; - if ((int64_t)params->ith >= HEADS) { + const int ith = params->ith; + const int nth = params->nth; + + if (ith >= HEADS) { return; } - int64_t h_start = (HEADS * params->ith) / params->nth; - int64_t h_end = ((HEADS * (params->ith + 1)) / params->nth < HEADS) ? - (HEADS * (params->ith + 1)) / params->nth : HEADS; + const int h_start = (HEADS * ith) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -11675,7 +11678,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS size_t h_stride_2d = head_size * head_size; - if (params->ith == 0) { + if (ith == 0) { memset(dst_data, 0, T * C * sizeof(float)); } ggml_barrier(params->threadpool);