diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c index 8fa3f582e..c66e90a9f 100644 --- a/ggml/src/ggml-cpu.c +++ b/ggml/src/ggml-cpu.c @@ -11646,22 +11646,22 @@ static void ggml_compute_forward_add_rel_pos( static void ggml_compute_forward_rwkv_wkv6_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const size_t T = dst->src[1]->ne[3]; - const size_t C = dst->ne[0]; - const size_t HEADS = dst->src[1]->ne[2]; - const size_t n_seqs = dst->src[5]->ne[1]; - const size_t head_size = C / HEADS; + const int64_t T = dst->src[1]->ne[3]; + const int64_t C = dst->ne[0]; + const int64_t HEADS = dst->src[1]->ne[2]; + const int64_t n_seqs = dst->src[5]->ne[1]; + const int64_t head_size = C / HEADS; float * dst_data = (float *) dst->data; float * state = ((float *) dst->data) + C * T; - if ((size_t)params->ith >= HEADS) { + if ((int64_t)params->ith >= HEADS) { return; } - size_t h_start = (HEADS * params->ith) / params->nth; - size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ? - (HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS; + int64_t h_start = (HEADS * params->ith) / params->nth; + int64_t h_end = ((HEADS * (params->ith + 1)) / params->nth < HEADS) ? + (HEADS * (params->ith + 1)) / params->nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -11708,20 +11708,20 @@ static void ggml_compute_forward_rwkv_wkv6_f32( #endif #ifdef WKV_VECTOR_SIZE - const size_t vec_count = head_size / WKV_VECTOR_SIZE; + const int64_t vec_count = head_size / WKV_VECTOR_SIZE; - for (size_t t = 0; t < T; t++) { + for (int64_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; size_t state_offset = head_size * C * (t / (T / n_seqs)); float * state_cur = state + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - for (size_t h = h_start; h < h_end; h++) { + for (int64_t h = h_start; h < h_end; h++) { size_t h_offset = h * h_stride; size_t t_h_offset = t_offset + h_offset; size_t h_2d_offset = h * h_stride_2d; - for (size_t i = 0; i < head_size; i++) { + for (int64_t i = 0; i < head_size; i++) { size_t t_h_i_offset = t_h_offset + i; size_t h_i_offset = h_offset + i; size_t h_2d_i_offset = h_2d_offset + i * h_stride; @@ -11737,7 +11737,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val); GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val); - for (size_t j = 0; j < vec_count; j++) { + for (int64_t j = 0; j < vec_count; j++) { size_t base_j = j * WKV_VECTOR_SIZE; size_t t_h_j_offset = t_h_offset + base_j; size_t h_2d_i_j_offset = h_2d_i_offset + base_j; @@ -11763,7 +11763,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( } // Handle remaining elements, this will not be used. - for (size_t j = vec_count * VECTOR_SIZE; j < head_size; j++) { + for (int64_t j = vec_count * VECTOR_SIZE; j < head_size; j++) { size_t t_h_j_offset = t_h_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j; float v_val = v[t_h_j_offset]; @@ -11782,18 +11782,18 @@ static void ggml_compute_forward_rwkv_wkv6_f32( // dst = r @ (time_faaaa * (k @ v) + state), // state = time_decay * state + (k @ v), // recursive through each token - for (size_t t = 0; t < T; t++) { + for (int64_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; size_t state_offset = head_size * C * (t / (T / n_seqs)); float * state_cur = state + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - for (size_t h = h_start; h < h_end; h++) { + for (int64_t h = h_start; h < h_end; h++) { size_t h_offset = h * h_stride; size_t t_h_offset = t_offset + h_offset; size_t h_2d_offset = h * h_stride_2d; - for (size_t i = 0; i < head_size; i++) { + for (int64_t i = 0; i < head_size; i++) { size_t t_h_i_offset = t_h_offset + i; size_t h_i_offset = h_offset + i; size_t h_2d_i_offset = h_2d_offset + i * h_stride; @@ -11804,7 +11804,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( // RWKV v6: different time_decay for each token. float time_decay_val = time_decay[t_h_i_offset]; - for (size_t j = 0; j < head_size; j ++) { + for (int64_t j = 0; j < head_size; j ++) { size_t t_h_j_offset = t_h_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j;