rwkv6: rename params
This commit is contained in:
parent
e198f7b9df
commit
3f75f12114
1 changed files with 8 additions and 8 deletions
|
@ -16611,20 +16611,20 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
struct ggml_tensor * dst) {
|
||||
const size_t T = dst->src[1]->ne[3];
|
||||
const size_t C = dst->ne[0];
|
||||
const size_t H = dst->src[1]->ne[2];
|
||||
const size_t HEADS = dst->src[1]->ne[2];
|
||||
const size_t n_seqs = dst->src[5]->ne[1];
|
||||
const size_t head_size = C / H;
|
||||
const size_t head_size = C / HEADS;
|
||||
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * state = ((float *) dst->data) + C * T;
|
||||
|
||||
if ((size_t)params->ith >= H) {
|
||||
if ((size_t)params->ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t h_start = (H * params->ith) / params->nth;
|
||||
size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ?
|
||||
(H * (size_t)(params->ith + 1)) / (size_t)params->nth : H;
|
||||
size_t h_start = (HEADS * params->ith) / params->nth;
|
||||
size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ?
|
||||
(HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS;
|
||||
|
||||
float * k = (float *) dst->src[0]->data;
|
||||
float * v = (float *) dst->src[1]->data;
|
||||
|
@ -16632,9 +16632,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
float * time_faaaa = (float *) dst->src[3]->data;
|
||||
float * time_decay = (float *) dst->src[4]->data;
|
||||
|
||||
size_t t_stride = H * head_size;
|
||||
size_t t_stride = HEADS * head_size;
|
||||
|
||||
size_t h_stride = C / H;
|
||||
size_t h_stride = C / HEADS;
|
||||
size_t h_stride_2d = head_size * head_size;
|
||||
|
||||
if (params->ith == 0) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue