Optimize RWKV6 Operator Naming and Implement Multi-core CPU/SYCL Acceleration (#10133)

* rwkv6: rename to wkv6

* rwkv6: support avx2 avx512 armv8 armv9

* rwkv6: update cuda file name

* rwkv6: rename params

* wkv on sycl

* sycl: add some ops

* sycl: improve op support checks

* wkv6: drop armv9 and convert to GGML style

ggml-ci

* sync : ggml

* update the function to use appropriate types

* fix define error

* Update ggml/src/ggml-cpu.c

* add appropriate asserts

* move element-wise functions outside

* put the declaration outside the loop

* rewrite to be more in line with the common pattern for distributing threads (see the sketch after this list)

* use the recommended GGML_TENSOR_LOCALS macro
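
The sketch referenced above: the reworked CPU kernel distributes the RWKV heads across threads with ggml's usual contiguous range split, thread ith of nth taking heads [h_start, h_end). A minimal standalone illustration in C (HEADS and nth are made-up values; only the split formula is taken from the diff below):

    #include <stdio.h>

    int main(void) {
        const int HEADS = 32; // assumed head count, for illustration only
        const int nth   = 5;  // assumed number of threads

        for (int ith = 0; ith < nth; ith++) {
            // same split as in ggml_compute_forward_rwkv_wkv6_f32 below
            const int h_start = (HEADS * ith) / nth;
            const int h_end   = ((HEADS * (ith + 1)) / nth < HEADS) ?
                                (HEADS * (ith + 1)) / nth : HEADS;
            printf("thread %d -> heads [%d, %d)\n", ith, h_start, h_end);
        }
        return 0;
    }

With this split every head is covered exactly once, and threads whose index is >= HEADS simply return early, as the kernel below does.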

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
Co-authored-by: Plamen Minev <pacominev@gmail.com>
Co-authored-by: Yuri Khrustalev <ykhrustalev@users.noreply.github.com>
Co-authored-by: Meng, Hengyu <airdldl@163.com>
Author: Zhiyuan Li, 2024-11-07 18:19:10 +11:00, committed by GitHub
parent 5c333e0140
commit 3bcd40b3c5
22 changed files with 1977 additions and 1027 deletions
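
The CPU hunks that follow fuse the WKV6 recurrence per head. As a reading aid, one scalar step of the inner loop amounts to the sketch below; the variable names follow the diff, but wkv6_step itself is a hypothetical helper, not part of the patch:

    // Scalar model of one inner-loop step of the fused WKV6 op (illustrative only).
    // r_val, k_val, v_val are the current token's values; tf = time_faaaa, td = time_decay.
    static inline void wkv6_step(float r_val, float k_val, float v_val,
                                 float tf, float td,
                                 float * state, float * dst) {
        const float kv   = v_val * k_val;    // one element of the k (x) v outer product
        const float temp = kv * tf + *state; // bonus-weighted kv plus previous state
        *dst  += temp * r_val;               // receptance-weighted accumulation into dst
        *state = *state * td + kv;           // decay the previous state, then add kv
    }

The vectorized path applies the same arithmetic 4, 8, or 16 lanes at a time through the GGML_F32X_* macros, where (per the comments in the diff) GGML_F32X_FMA(a, b, c) computes a + b * c per lane.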

ggml/src/ggml-cpu.c

@@ -11642,24 +11642,30 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
-// ggml_compute_forward_rwkv_wkv
+// ggml_compute_forward_rwkv_wkv6
-static void ggml_compute_forward_rwkv_wkv_f32(
+static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const size_t T = dst->src[1]->ne[3];
-    const size_t C = dst->ne[0];
-    const size_t H = dst->src[1]->ne[2];
-    const size_t n_seqs = dst->src[5]->ne[1];
+    const int64_t T = dst->src[1]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;
     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;
-    if (params->ith != 0) {
+    const int ith = params->ith;
+    const int nth = params->nth;
+    if (ith >= HEADS) {
         return;
     }
-    memset(dst_data, 0, T * C * sizeof(float));
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                      (HEADS * (ith + 1)) / nth : HEADS;
     float * k = (float *) dst->src[0]->data;
     float * v = (float *) dst->src[1]->data;
@@ -11667,54 +11673,160 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
-    size_t t_stride = H * (C / H);
+    size_t t_stride = HEADS * head_size; // Same to C
-    size_t h_stride = C / H;
-    size_t h_stride_2d = (C / H) * (C / H);
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
-    // basically fused operations:
-    // dst = r @ (time_faaaa * (k @ v) + state),
-    // state = time_decay * state + (k @ v),
-    // recursive through each token
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
-        for (size_t h = 0; h < H; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
-            for (size_t i = 0; i < C / H; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                // RWKV v6: different time_decay for each token.
-                float time_decay_val = time_decay[t_h_i_offset];
+    #ifdef WKV_VECTOR_SIZE
+        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
-                for (size_t j = 0; j < C / H; j ++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
                 }
             }
         }
-    }
+    #else
+        // basically fused operations:
+        // dst = r @ (time_faaaa * (k @ v) + state),
+        // state = time_decay * state + (k @ v),
+        // recursive through each token
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+    #endif
 }
-static void ggml_compute_forward_rwkv_wkv(
+static void ggml_compute_forward_rwkv_wkv6(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
             } break;
         default:
             {
@@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             {
-                ggml_compute_forward_rwkv_wkv(params, tensor);
+                ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
         case GGML_OP_MAP_UNARY:
             {
@@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32: