rwkv6: rename params
This commit is contained in:
parent
e198f7b9df
commit
3f75f12114
1 changed files with 8 additions and 8 deletions
|
@ -16611,20 +16611,20 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
const size_t T = dst->src[1]->ne[3];
|
const size_t T = dst->src[1]->ne[3];
|
||||||
const size_t C = dst->ne[0];
|
const size_t C = dst->ne[0];
|
||||||
const size_t H = dst->src[1]->ne[2];
|
const size_t HEADS = dst->src[1]->ne[2];
|
||||||
const size_t n_seqs = dst->src[5]->ne[1];
|
const size_t n_seqs = dst->src[5]->ne[1];
|
||||||
const size_t head_size = C / H;
|
const size_t head_size = C / HEADS;
|
||||||
|
|
||||||
float * dst_data = (float *) dst->data;
|
float * dst_data = (float *) dst->data;
|
||||||
float * state = ((float *) dst->data) + C * T;
|
float * state = ((float *) dst->data) + C * T;
|
||||||
|
|
||||||
if ((size_t)params->ith >= H) {
|
if ((size_t)params->ith >= HEADS) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t h_start = (H * params->ith) / params->nth;
|
size_t h_start = (HEADS * params->ith) / params->nth;
|
||||||
size_t h_end = ((H * (size_t)(params->ith + 1)) / (size_t)params->nth < H) ?
|
size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ?
|
||||||
(H * (size_t)(params->ith + 1)) / (size_t)params->nth : H;
|
(HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS;
|
||||||
|
|
||||||
float * k = (float *) dst->src[0]->data;
|
float * k = (float *) dst->src[0]->data;
|
||||||
float * v = (float *) dst->src[1]->data;
|
float * v = (float *) dst->src[1]->data;
|
||||||
|
@ -16632,9 +16632,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||||
float * time_faaaa = (float *) dst->src[3]->data;
|
float * time_faaaa = (float *) dst->src[3]->data;
|
||||||
float * time_decay = (float *) dst->src[4]->data;
|
float * time_decay = (float *) dst->src[4]->data;
|
||||||
|
|
||||||
size_t t_stride = H * head_size;
|
size_t t_stride = HEADS * head_size;
|
||||||
|
|
||||||
size_t h_stride = C / H;
|
size_t h_stride = C / HEADS;
|
||||||
size_t h_stride_2d = head_size * head_size;
|
size_t h_stride_2d = head_size * head_size;
|
||||||
|
|
||||||
if (params->ith == 0) {
|
if (params->ith == 0) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue