rewrite to be more in line with the common pattern for distributing threads

This commit is contained in:
Zhiyuan Li 2024-11-05 02:49:22 +11:00
parent a749ba7701
commit 4693b4611f

View file

@ -11655,13 +11655,16 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
float * dst_data = (float *) dst->data; float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T; float * state = ((float *) dst->data) + C * T;
if ((int64_t)params->ith >= HEADS) { const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return; return;
} }
int64_t h_start = (HEADS * params->ith) / params->nth; const int h_start = (HEADS * ith) / nth;
int64_t h_end = ((HEADS * (params->ith + 1)) / params->nth < HEADS) ? const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (params->ith + 1)) / params->nth : HEADS; (HEADS * (ith + 1)) / nth : HEADS;
float * k = (float *) dst->src[0]->data; float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data; float * v = (float *) dst->src[1]->data;
@ -11675,7 +11678,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
size_t h_stride_2d = head_size * head_size; size_t h_stride_2d = head_size * head_size;
if (params->ith == 0) { if (ith == 0) {
memset(dst_data, 0, T * C * sizeof(float)); memset(dst_data, 0, T * C * sizeof(float));
} }
ggml_barrier(params->threadpool); ggml_barrier(params->threadpool);