update the function to use appropriate types
This commit is contained in:
parent
bb0685fad5
commit
81cb301224
1 changed files with 19 additions and 19 deletions
|
@ -11646,22 +11646,22 @@ static void ggml_compute_forward_add_rel_pos(
|
||||||
static void ggml_compute_forward_rwkv_wkv6_f32(
|
static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
const size_t T = dst->src[1]->ne[3];
|
const int64_t T = dst->src[1]->ne[3];
|
||||||
const size_t C = dst->ne[0];
|
const int64_t C = dst->ne[0];
|
||||||
const size_t HEADS = dst->src[1]->ne[2];
|
const int64_t HEADS = dst->src[1]->ne[2];
|
||||||
const size_t n_seqs = dst->src[5]->ne[1];
|
const int64_t n_seqs = dst->src[5]->ne[1];
|
||||||
const size_t head_size = C / HEADS;
|
const int64_t head_size = C / HEADS;
|
||||||
|
|
||||||
float * dst_data = (float *) dst->data;
|
float * dst_data = (float *) dst->data;
|
||||||
float * state = ((float *) dst->data) + C * T;
|
float * state = ((float *) dst->data) + C * T;
|
||||||
|
|
||||||
if ((size_t)params->ith >= HEADS) {
|
if ((int64_t)params->ith >= HEADS) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t h_start = (HEADS * params->ith) / params->nth;
|
int64_t h_start = (HEADS * params->ith) / params->nth;
|
||||||
size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ?
|
int64_t h_end = ((HEADS * (params->ith + 1)) / params->nth < HEADS) ?
|
||||||
(HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS;
|
(HEADS * (params->ith + 1)) / params->nth : HEADS;
|
||||||
|
|
||||||
float * k = (float *) dst->src[0]->data;
|
float * k = (float *) dst->src[0]->data;
|
||||||
float * v = (float *) dst->src[1]->data;
|
float * v = (float *) dst->src[1]->data;
|
||||||
|
@ -11708,20 +11708,20 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef WKV_VECTOR_SIZE
|
#ifdef WKV_VECTOR_SIZE
|
||||||
const size_t vec_count = head_size / WKV_VECTOR_SIZE;
|
const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
|
||||||
|
|
||||||
for (size_t t = 0; t < T; t++) {
|
for (int64_t t = 0; t < T; t++) {
|
||||||
size_t t_offset = t * t_stride;
|
size_t t_offset = t * t_stride;
|
||||||
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||||
float * state_cur = state + state_offset;
|
float * state_cur = state + state_offset;
|
||||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||||
|
|
||||||
for (size_t h = h_start; h < h_end; h++) {
|
for (int64_t h = h_start; h < h_end; h++) {
|
||||||
size_t h_offset = h * h_stride;
|
size_t h_offset = h * h_stride;
|
||||||
size_t t_h_offset = t_offset + h_offset;
|
size_t t_h_offset = t_offset + h_offset;
|
||||||
size_t h_2d_offset = h * h_stride_2d;
|
size_t h_2d_offset = h * h_stride_2d;
|
||||||
|
|
||||||
for (size_t i = 0; i < head_size; i++) {
|
for (int64_t i = 0; i < head_size; i++) {
|
||||||
size_t t_h_i_offset = t_h_offset + i;
|
size_t t_h_i_offset = t_h_offset + i;
|
||||||
size_t h_i_offset = h_offset + i;
|
size_t h_i_offset = h_offset + i;
|
||||||
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||||
|
@ -11737,7 +11737,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||||
GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
|
GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
|
||||||
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
|
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
|
||||||
|
|
||||||
for (size_t j = 0; j < vec_count; j++) {
|
for (int64_t j = 0; j < vec_count; j++) {
|
||||||
size_t base_j = j * WKV_VECTOR_SIZE;
|
size_t base_j = j * WKV_VECTOR_SIZE;
|
||||||
size_t t_h_j_offset = t_h_offset + base_j;
|
size_t t_h_j_offset = t_h_offset + base_j;
|
||||||
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
|
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
|
||||||
|
@ -11763,7 +11763,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle remaining elements, this will not be used.
|
// Handle remaining elements, this will not be used.
|
||||||
for (size_t j = vec_count * VECTOR_SIZE; j < head_size; j++) {
|
for (int64_t j = vec_count * VECTOR_SIZE; j < head_size; j++) {
|
||||||
size_t t_h_j_offset = t_h_offset + j;
|
size_t t_h_j_offset = t_h_offset + j;
|
||||||
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||||
float v_val = v[t_h_j_offset];
|
float v_val = v[t_h_j_offset];
|
||||||
|
@ -11782,18 +11782,18 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||||
// dst = r @ (time_faaaa * (k @ v) + state),
|
// dst = r @ (time_faaaa * (k @ v) + state),
|
||||||
// state = time_decay * state + (k @ v),
|
// state = time_decay * state + (k @ v),
|
||||||
// recursive through each token
|
// recursive through each token
|
||||||
for (size_t t = 0; t < T; t++) {
|
for (int64_t t = 0; t < T; t++) {
|
||||||
size_t t_offset = t * t_stride;
|
size_t t_offset = t * t_stride;
|
||||||
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||||
float * state_cur = state + state_offset;
|
float * state_cur = state + state_offset;
|
||||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||||
|
|
||||||
for (size_t h = h_start; h < h_end; h++) {
|
for (int64_t h = h_start; h < h_end; h++) {
|
||||||
size_t h_offset = h * h_stride;
|
size_t h_offset = h * h_stride;
|
||||||
size_t t_h_offset = t_offset + h_offset;
|
size_t t_h_offset = t_offset + h_offset;
|
||||||
size_t h_2d_offset = h * h_stride_2d;
|
size_t h_2d_offset = h * h_stride_2d;
|
||||||
|
|
||||||
for (size_t i = 0; i < head_size; i++) {
|
for (int64_t i = 0; i < head_size; i++) {
|
||||||
size_t t_h_i_offset = t_h_offset + i;
|
size_t t_h_i_offset = t_h_offset + i;
|
||||||
size_t h_i_offset = h_offset + i;
|
size_t h_i_offset = h_offset + i;
|
||||||
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||||
|
@ -11804,7 +11804,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||||
// RWKV v6: different time_decay for each token.
|
// RWKV v6: different time_decay for each token.
|
||||||
float time_decay_val = time_decay[t_h_i_offset];
|
float time_decay_val = time_decay[t_h_i_offset];
|
||||||
|
|
||||||
for (size_t j = 0; j < head_size; j ++) {
|
for (int64_t j = 0; j < head_size; j ++) {
|
||||||
size_t t_h_j_offset = t_h_offset + j;
|
size_t t_h_j_offset = t_h_offset + j;
|
||||||
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue