rewrite to be more in line with the common pattern for distributing threads
This commit is contained in:
parent
a749ba7701
commit
4693b4611f
1 changed file with 8 additions and 5 deletions
|
@ -11655,13 +11655,16 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
float * dst_data = (float *) dst->data;
|
||||
float * state = ((float *) dst->data) + C * T;
|
||||
|
||||
if ((int64_t)params->ith >= HEADS) {
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t h_start = (HEADS * params->ith) / params->nth;
|
||||
int64_t h_end = ((HEADS * (params->ith + 1)) / params->nth < HEADS) ?
|
||||
(HEADS * (params->ith + 1)) / params->nth : HEADS;
|
||||
const int h_start = (HEADS * ith) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
|
||||
float * k = (float *) dst->src[0]->data;
|
||||
float * v = (float *) dst->src[1]->data;
|
||||
|
@ -11675,7 +11678,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
|
||||
size_t h_stride_2d = head_size * head_size;
|
||||
|
||||
if (params->ith == 0) {
|
||||
if (ith == 0) {
|
||||
memset(dst_data, 0, T * C * sizeof(float));
|
||||
}
|
||||
ggml_barrier(params->threadpool);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue