update the function to use appropriate types

This commit is contained in:
Zhiyuan Li 2024-11-05 00:55:59 +11:00
parent bb0685fad5
commit 81cb301224

View file

@ -11646,22 +11646,22 @@ static void ggml_compute_forward_add_rel_pos(
static void ggml_compute_forward_rwkv_wkv6_f32( static void ggml_compute_forward_rwkv_wkv6_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const size_t T = dst->src[1]->ne[3]; const int64_t T = dst->src[1]->ne[3];
const size_t C = dst->ne[0]; const int64_t C = dst->ne[0];
const size_t HEADS = dst->src[1]->ne[2]; const int64_t HEADS = dst->src[1]->ne[2];
const size_t n_seqs = dst->src[5]->ne[1]; const int64_t n_seqs = dst->src[5]->ne[1];
const size_t head_size = C / HEADS; const int64_t head_size = C / HEADS;
float * dst_data = (float *) dst->data; float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T; float * state = ((float *) dst->data) + C * T;
if ((size_t)params->ith >= HEADS) { if ((int64_t)params->ith >= HEADS) {
return; return;
} }
size_t h_start = (HEADS * params->ith) / params->nth; int64_t h_start = (HEADS * params->ith) / params->nth;
size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ? int64_t h_end = ((HEADS * (params->ith + 1)) / params->nth < HEADS) ?
(HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS; (HEADS * (params->ith + 1)) / params->nth : HEADS;
float * k = (float *) dst->src[0]->data; float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data; float * v = (float *) dst->src[1]->data;
@ -11708,20 +11708,20 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
#endif #endif
#ifdef WKV_VECTOR_SIZE #ifdef WKV_VECTOR_SIZE
const size_t vec_count = head_size / WKV_VECTOR_SIZE; const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
for (size_t t = 0; t < T; t++) { for (int64_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride; size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs)); size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset; float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = h_start; h < h_end; h++) { for (int64_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride; size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset; size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d; size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < head_size; i++) { for (int64_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i; size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i; size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride; size_t h_2d_i_offset = h_2d_offset + i * h_stride;
@ -11737,7 +11737,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val); GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val); GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
for (size_t j = 0; j < vec_count; j++) { for (int64_t j = 0; j < vec_count; j++) {
size_t base_j = j * WKV_VECTOR_SIZE; size_t base_j = j * WKV_VECTOR_SIZE;
size_t t_h_j_offset = t_h_offset + base_j; size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j; size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
@ -11763,7 +11763,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
} }
// Handle remaining elements, this will not be used. // Handle remaining elements, this will not be used.
for (size_t j = vec_count * VECTOR_SIZE; j < head_size; j++) { for (int64_t j = vec_count * VECTOR_SIZE; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j; size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset]; float v_val = v[t_h_j_offset];
@ -11782,18 +11782,18 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
// dst = r @ (time_faaaa * (k @ v) + state), // dst = r @ (time_faaaa * (k @ v) + state),
// state = time_decay * state + (k @ v), // state = time_decay * state + (k @ v),
// recursive through each token // recursive through each token
for (size_t t = 0; t < T; t++) { for (int64_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride; size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs)); size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset; float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = h_start; h < h_end; h++) { for (int64_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride; size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset; size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d; size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < head_size; i++) { for (int64_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i; size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i; size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride; size_t h_2d_i_offset = h_2d_offset + i * h_stride;
@ -11804,7 +11804,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
// RWKV v6: different time_decay for each token. // RWKV v6: different time_decay for each token.
float time_decay_val = time_decay[t_h_i_offset]; float time_decay_val = time_decay[t_h_i_offset];
for (size_t j = 0; j < head_size; j ++) { for (int64_t j = 0; j < head_size; j ++) {
size_t t_h_j_offset = t_h_offset + j; size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j;