rwkv6: rename params

This commit is contained in:
Zhiyuan Li 2024-11-02 00:28:58 +11:00
parent e198f7b9df
commit 3f75f12114

View file

@ -16611,20 +16611,20 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
struct ggml_tensor * dst) {
const size_t T = dst->src[1]->ne[3];
const size_t C = dst->ne[0];
const size_t H = dst->src[1]->ne[2];
const size_t HEADS = dst->src[1]->ne[2];
const size_t n_seqs = dst->src[5]->ne[1];
const size_t head_size = C / H;
const size_t head_size = C / HEADS;
float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T;
if ((size_t)params->ith >= H) {
if ((size_t)params->ith >= HEADS) {
return;
}
size_t h_start = (H * params->ith) / params->nth;
size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ?
(H * (size_t)(params->ith + 1)) / (size_t)params->nth : H;
size_t h_start = (HEADS * params->ith) / params->nth;
size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ?
(HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS;
float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
@ -16632,9 +16632,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
float * time_faaaa = (float *) dst->src[3]->data;
float * time_decay = (float *) dst->src[4]->data;
size_t t_stride = H * head_size;
size_t t_stride = HEADS * head_size;
size_t h_stride = C / H;
size_t h_stride = C / HEADS;
size_t h_stride_2d = head_size * head_size;
if (params->ith == 0) {