fix: update changes to upstream

This commit is contained in:
Zhiyuan Li 2024-11-04 22:17:12 +11:00
parent 5f792141c5
commit 61c665b7f1

View file

@ -11641,24 +11641,27 @@ static void ggml_compute_forward_add_rel_pos(
} }
} }
// ggml_compute_forward_rwkv_wkv // ggml_compute_forward_rwkv_wkv6
static void ggml_compute_forward_rwkv_wkv_f32( static void ggml_compute_forward_rwkv_wkv6_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
const size_t T = dst->src[1]->ne[3]; const size_t T = dst->src[1]->ne[3];
const size_t C = dst->ne[0]; const size_t C = dst->ne[0];
const size_t H = dst->src[1]->ne[2]; const size_t HEADS = dst->src[1]->ne[2];
const size_t n_seqs = dst->src[5]->ne[1]; const size_t n_seqs = dst->src[5]->ne[1];
const size_t head_size = C / HEADS;
float * dst_data = (float *) dst->data; float * dst_data = (float *) dst->data;
float * state = ((float *) dst->data) + C * T; float * state = ((float *) dst->data) + C * T;
if (params->ith != 0) { if ((size_t)params->ith >= HEADS) {
return; return;
} }
memset(dst_data, 0, T * C * sizeof(float)); size_t h_start = (HEADS * params->ith) / params->nth;
size_t h_end = ((HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth < HEADS) ?
(HEADS * (size_t)(params->ith + 1)) / (size_t)params->nth : HEADS;
float * k = (float *) dst->src[0]->data; float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data; float * v = (float *) dst->src[1]->data;
@ -11666,54 +11669,161 @@ static void ggml_compute_forward_rwkv_wkv_f32(
float * time_faaaa = (float *) dst->src[3]->data; float * time_faaaa = (float *) dst->src[3]->data;
float * time_decay = (float *) dst->src[4]->data; float * time_decay = (float *) dst->src[4]->data;
size_t t_stride = H * (C / H); size_t t_stride = HEADS * head_size;
size_t h_stride = C / H; size_t h_stride = C / HEADS;
size_t h_stride_2d = (C / H) * (C / H); size_t h_stride_2d = head_size * head_size;
// basically fused operations: if (params->ith == 0) {
// dst = r @ (time_faaaa * (k @ v) + state), memset(dst_data, 0, T * C * sizeof(float));
// state = time_decay * state + (k @ v), }
// recursive through each token ggml_barrier(params->threadpool);
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = (C / H) * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = 0; h < H; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < C / H; i++) { #ifdef __AVX2__
size_t t_h_i_offset = t_h_offset + i; #define GGML_F32X GGML_F32x8
size_t h_i_offset = h_offset + i; #define GGML_F32X_SET1 GGML_F32x8_SET1
size_t h_2d_i_offset = h_2d_offset + i * h_stride; #define GGML_F32X_LOAD GGML_F32x8_LOAD
#define GGML_F32X_STORE GGML_F32x8_STORE
#define GGML_F32X_MUL GGML_F32x8_MUL
#define GGML_F32X_FMA GGML_F32x8_FMA
#define VECTOR_SIZE 8
#elif __AVX512F__
#define GGML_F32X GGML_F32x16
#define GGML_F32X_SET1 GGML_F32x16_SET1
#define GGML_F32X_LOAD GGML_F32x16_LOAD
#define GGML_F32X_STORE GGML_F32x16_STORE
#define GGML_F32X_MUL GGML_F32x16_MUL
#define GGML_F32X_FMA GGML_F32x16_FMA
#define WKV_VECTOR_SIZE 16
#elif defined(__ARM_NEON) && defined(__aarch64__)
#define GGML_F32X GGML_F32x4
#define GGML_F32X_SET1 GGML_F32x4_SET1
#define GGML_F32X_LOAD GGML_F32x4_LOAD
#define GGML_F32X_STORE GGML_F32x4_STORE
#define GGML_F32X_MUL GGML_F32x4_MUL
#define GGML_F32X_FMA GGML_F32x4_FMA
#define WKV_VECTOR_SIZE 4
#endif
float k_val = k[t_h_i_offset]; #ifdef WKV_VECTOR_SIZE
float r_val = r[t_h_i_offset]; const size_t vec_count = head_size / WKV_VECTOR_SIZE;
float time_faaaa_val = time_faaaa[h_i_offset];
// RWKV v6: different time_decay for each token.
float time_decay_val = time_decay[t_h_i_offset];
for (size_t j = 0; j < C / H; j ++) { for (size_t t = 0; t < T; t++) {
size_t t_h_j_offset = t_h_offset + j; size_t t_offset = t * t_stride;
size_t h_2d_i_j_offset = h_2d_i_offset + j; size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
float v_val = v[t_h_j_offset]; for (size_t h = h_start; h < h_end; h++) {
float kv_val = v_val * k_val; size_t h_offset = h * h_stride;
float prev_state_val = state_prev[h_2d_i_j_offset]; size_t t_h_offset = t_offset + h_offset;
float temp_val = kv_val * time_faaaa_val + prev_state_val; size_t h_2d_offset = h * h_stride_2d;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; for (size_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
float time_decay_val = time_decay[t_h_i_offset];
// Broadcast scalar values to vectors
GGML_F32X k_vec = GGML_F32X_SET1(k_val);
GGML_F32X r_vec = GGML_F32X_SET1(r_val);
GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
for (size_t j = 0; j < vec_count; j++) {
size_t base_j = j * WKV_VECTOR_SIZE;
size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
// Load x elements at once
GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
// Compute kv = v * k
GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
// Compute temp = kv * time_faaaa + prev_state
GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
// Update dst: dst += temp * r
dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
// Update state: state = prev_state * time_decay + kv
GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
}
// Handle remaining elements, this will not be used.
for (size_t j = vec_count * VECTOR_SIZE; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
}
} }
} }
} }
}
#else
// basically fused operations:
// dst = r @ (time_faaaa * (k @ v) + state),
// state = time_decay * state + (k @ v),
// recursive through each token
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
// RWKV v6: different time_decay for each token.
float time_decay_val = time_decay[t_h_i_offset];
for (size_t j = 0; j < head_size; j ++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
}
}
}
}
#endif
} }
static void ggml_compute_forward_rwkv_wkv(
static void ggml_compute_forward_rwkv_wkv6(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
@ -11722,7 +11832,7 @@ static void ggml_compute_forward_rwkv_wkv(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
ggml_compute_forward_rwkv_wkv_f32(params, dst); ggml_compute_forward_rwkv_wkv6_f32(params, dst);
} break; } break;
default: default:
{ {
@ -12474,9 +12584,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_add_rel_pos(params, tensor); ggml_compute_forward_add_rel_pos(params, tensor);
} break; } break;
case GGML_OP_RWKV_WKV: case GGML_OP_RWKV_WKV6:
{ {
ggml_compute_forward_rwkv_wkv(params, tensor); ggml_compute_forward_rwkv_wkv6(params, tensor);
} break; } break;
case GGML_OP_MAP_UNARY: case GGML_OP_MAP_UNARY:
{ {
@ -12774,7 +12884,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_WIN_PART: case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART: case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS: case GGML_OP_GET_REL_POS:
case GGML_OP_RWKV_WKV: case GGML_OP_RWKV_WKV6:
case GGML_OP_MAP_UNARY: case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY: case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1_F32: case GGML_OP_MAP_CUSTOM1_F32: