wkv6: drop armv9 and transfer to GGML style

Zhiyuan Li 2024-11-03 23:52:25 +11:00
parent 042c3e0fd3
commit 811aa872d6
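
For context on the refactor this commit performs: the four hand-written SIMD branches (AVX2, AVX-512, SVE, NEON) are collapsed into a single loop body written against generic GGML_F32X macros, with each ISA only supplying the macro mapping and its vector width (the SVE/armv9 path is dropped entirely). Below is a minimal sketch of that dispatch pattern; every name prefixed example_/EXAMPLE_ is invented for the sketch and is not part of this commit.

// Illustrative sketch of the "one generic loop, per-ISA macro mapping" pattern.
#include <stddef.h>
#if defined(__AVX2__)
#include <immintrin.h>
typedef __m256 example_f32x;                      // 8 floats per register
#define EXAMPLE_F32X_SET1(x)    _mm256_set1_ps(x)
#define EXAMPLE_F32X_LOAD(p)    _mm256_loadu_ps(p)
#define EXAMPLE_F32X_STORE(p,v) _mm256_storeu_ps(p, v)
#define EXAMPLE_F32X_MUL(a,b)   _mm256_mul_ps(a, b)
#define EXAMPLE_VEC_WIDTH 8
#elif defined(__ARM_NEON)
#include <arm_neon.h>
typedef float32x4_t example_f32x;                 // 4 floats per register
#define EXAMPLE_F32X_SET1(x)    vdupq_n_f32(x)
#define EXAMPLE_F32X_LOAD(p)    vld1q_f32(p)
#define EXAMPLE_F32X_STORE(p,v) vst1q_f32(p, v)
#define EXAMPLE_F32X_MUL(a,b)   vmulq_f32(a, b)
#define EXAMPLE_VEC_WIDTH 4
#endif

// The compute loop is written once against the generic names; a scalar tail
// covers lengths that are not a multiple of the vector width.
static void example_scale(float * dst, const float * src, float s, size_t n) {
    size_t i = 0;
#ifdef EXAMPLE_VEC_WIDTH
    example_f32x vs = EXAMPLE_F32X_SET1(s);
    for (; i + EXAMPLE_VEC_WIDTH <= n; i += EXAMPLE_VEC_WIDTH) {
        EXAMPLE_F32X_STORE(&dst[i], EXAMPLE_F32X_MUL(EXAMPLE_F32X_LOAD(&src[i]), vs));
    }
#endif
    for (; i < n; i++) {
        dst[i] = src[i] * s;
    }
}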


@@ -16645,360 +16645,143 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
Removed (per-architecture intrinsics paths):

#ifdef __AVX2__
// AVX2 uses 256-bit vectors = 8 float32
const int vec_size = 8;
const size_t vec_count = head_size / vec_size;
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
float time_decay_val = time_decay[t_h_i_offset];
// Broadcast scalar values to vectors
__m256 k_vec = _mm256_set1_ps(k_val);
__m256 r_vec = _mm256_set1_ps(r_val);
__m256 time_faaaa_vec = _mm256_set1_ps(time_faaaa_val);
__m256 time_decay_vec = _mm256_set1_ps(time_decay_val);
// Vector processing for chunks of 8 floats
for (size_t j = 0; j < vec_count; j++) {
size_t base_j = j * vec_size;
size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
// Load 8 elements at once
__m256 v_vec = _mm256_loadu_ps(&v[t_h_j_offset]);
__m256 prev_state_vec = _mm256_loadu_ps(&state_prev[h_2d_i_j_offset]);
__m256 dst_vec = _mm256_loadu_ps(&dst_data[t_h_j_offset]);
// Compute kv = v * k
__m256 kv_vec = _mm256_mul_ps(v_vec, k_vec);
// Compute temp = kv * time_faaaa + prev_state
__m256 kv_time_vec = _mm256_mul_ps(kv_vec, time_faaaa_vec);
__m256 temp_vec = _mm256_add_ps(kv_time_vec, prev_state_vec);
// Update dst: dst += temp * r
__m256 result_vec = _mm256_mul_ps(temp_vec, r_vec);
dst_vec = _mm256_add_ps(dst_vec, result_vec);
_mm256_storeu_ps(&dst_data[t_h_j_offset], dst_vec);
// Update state: state = prev_state * time_decay + kv
__m256 decay_state_vec = _mm256_mul_ps(prev_state_vec, time_decay_vec);
__m256 new_state_vec = _mm256_add_ps(decay_state_vec, kv_vec);
_mm256_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
}
// Handle remaining elements, this will not be used.
for (size_t j = vec_count * vec_size; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
}
}
}
}
#elif __AVX512F__
// AVX-512 uses 512-bit vectors = 16 float32
const int vec_size = 16;
const size_t vec_count = head_size / vec_size;
const size_t vec_remain = head_size % vec_size;
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
// Load scalar values
float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
float time_decay_val = time_decay[t_h_i_offset];
// Broadcast scalar values to ZMM registers (512-bit)
__m512 k_vec = _mm512_set1_ps(k_val);
__m512 r_vec = _mm512_set1_ps(r_val);
__m512 time_faaaa_vec = _mm512_set1_ps(time_faaaa_val);
__m512 time_decay_vec = _mm512_set1_ps(time_decay_val);
// Use prefetch to reduce cache misses
#define PREFETCH_OFFSET 2
if (i + PREFETCH_OFFSET < head_size) {
_mm_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], _MM_HINT_T0);
_mm_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], _MM_HINT_T0);
}
// Vector processing for chunks of 16 floats
for (size_t j = 0; j < vec_count; j++) {
size_t base_j = j * vec_size;
size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
// Load 16 elements at once
__m512 v_vec = _mm512_loadu_ps(&v[t_h_j_offset]);
__m512 prev_state_vec = _mm512_loadu_ps(&state_prev[h_2d_i_j_offset]);
__m512 dst_vec = _mm512_loadu_ps(&dst_data[t_h_j_offset]);
// Compute kv = v * k
__m512 kv_vec = _mm512_mul_ps(v_vec, k_vec);
// Compute temp = kv * time_faaaa + prev_state using FMA
__m512 temp_vec = _mm512_fmadd_ps(kv_vec, time_faaaa_vec, prev_state_vec);
// Update dst: dst += temp * r using FMA
dst_vec = _mm512_fmadd_ps(temp_vec, r_vec, dst_vec);
_mm512_storeu_ps(&dst_data[t_h_j_offset], dst_vec);
// Update state: state = prev_state * time_decay + kv using FMA
__m512 new_state_vec = _mm512_fmadd_ps(prev_state_vec, time_decay_vec, kv_vec);
_mm512_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
}
// Handle remaining elements, this will not be used.
for (size_t j = vec_count * vec_size; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
}
}
}
}
#elif __ARM_FEATURE_SVE
// Get vector length for this CPU
const size_t vec_size = svcntw();
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
float time_decay_val = time_decay[t_h_i_offset];
// Create predicate for active lanes
svbool_t pg = svwhilelt_b32(0, head_size);
// Process vectors until done
size_t j = 0;
while (svptest_first(svptrue_b32(), pg)) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
// Load vectors
svfloat32_t v_vec = svld1_f32(pg, &v[t_h_j_offset]);
svfloat32_t prev_state_vec = svld1_f32(pg, &state_prev[h_2d_i_j_offset]);
svfloat32_t dst_vec = svld1_f32(pg, &dst_data[t_h_j_offset]);
// Compute kv = v * k
svfloat32_t kv_vec = svmul_n_f32_x(pg, v_vec, k_val);
// Compute temp = kv * time_faaaa + prev_state
svfloat32_t temp_vec = svmad_n_f32_x(pg, kv_vec, time_faaaa_val, prev_state_vec);
// Update dst: dst += temp * r
svfloat32_t result_vec = svmad_n_f32_x(pg, temp_vec, r_val, dst_vec);
svst1_f32(pg, &dst_data[t_h_j_offset], result_vec);
// Update state: state = prev_state * time_decay + kv
svfloat32_t new_state_vec = svmad_n_f32_x(pg, prev_state_vec, time_decay_val, kv_vec);
svst1_f32(pg, &state_cur[h_2d_i_j_offset], new_state_vec);
j += vec_size;
pg = svwhilelt_b32(j, head_size);
}
}
}
}
#elif __ARM_NEON
// NEON uses 128-bit vectors = 4 float32s
const int vec_size = 4;
const size_t vec_count = head_size / vec_size;
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
float time_decay_val = time_decay[t_h_i_offset];
// Broadcast scalar values to vectors
float32x4_t k_vec = vdupq_n_f32(k_val);
float32x4_t r_vec = vdupq_n_f32(r_val);
float32x4_t time_faaaa_vec = vdupq_n_f32(time_faaaa_val);
float32x4_t time_decay_vec = vdupq_n_f32(time_decay_val);
// Use prefetch to reduce cache misses
#ifdef __ARM_FEATURE_PREFETCH
#define PREFETCH_OFFSET 2
if (i + PREFETCH_OFFSET < head_size) {
__builtin_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], 0, 3);
__builtin_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], 0, 3);
}
#endif
// Vector processing for chunks of 4 floats
for (size_t j = 0; j < vec_count; j++) {
size_t base_j = j * vec_size;
size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
// Load 4 elements at once
float32x4_t v_vec = vld1q_f32(&v[t_h_j_offset]);
float32x4_t prev_state_vec = vld1q_f32(&state_prev[h_2d_i_j_offset]);
float32x4_t dst_vec = vld1q_f32(&dst_data[t_h_j_offset]);
// Compute kv = v * k
float32x4_t kv_vec = vmulq_f32(v_vec, k_vec);
// Compute temp = kv * time_faaaa + prev_state using FMA
#ifdef __ARM_FEATURE_FMA
float32x4_t temp_vec = vfmaq_f32(prev_state_vec, kv_vec, time_faaaa_vec);
// Update dst: dst += temp * r
dst_vec = vfmaq_f32(dst_vec, temp_vec, r_vec);
// Update state: state = prev_state * time_decay + kv
float32x4_t new_state_vec = vfmaq_f32(kv_vec, prev_state_vec, time_decay_vec);
#else
float32x4_t kv_time = vmulq_f32(kv_vec, time_faaaa_vec);
float32x4_t temp_vec = vaddq_f32(kv_time, prev_state_vec);
float32x4_t result_vec = vmulq_f32(temp_vec, r_vec);
dst_vec = vaddq_f32(dst_vec, result_vec);
float32x4_t decay_state_vec = vmulq_f32(prev_state_vec, time_decay_vec);
float32x4_t new_state_vec = vaddq_f32(decay_state_vec, kv_vec);
#endif
vst1q_f32(&dst_data[t_h_j_offset], dst_vec);
vst1q_f32(&state_cur[h_2d_i_j_offset], new_state_vec);
}
// Handle remaining elements
for (size_t j = vec_count * vec_size; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
}
}
}
}
#else
// basically fused operations:
// dst = r @ (time_faaaa * (k @ v) + state),
// state = time_decay * state + (k @ v),
// recursive through each token
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
// RWKV v6: different time_decay for each token.
float time_decay_val = time_decay[t_h_i_offset];
for (size_t j = 0; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
}
}
}
}
#endif
}

Added (unified GGML-style vector path replacing the four branches above; the scalar #else fallback is shared by the old and new versions):

#ifdef __AVX2__
#define GGML_F32X GGML_F32x8
#define GGML_F32X_SET1 GGML_F32x8_SET1
#define GGML_F32X_LOAD GGML_F32x8_LOAD
#define GGML_F32X_STORE GGML_F32x8_STORE
#define GGML_F32X_MUL GGML_F32x8_MUL
#define GGML_F32X_FMA GGML_F32x8_FMA
#define VECTOR_SIZE 8
#elif __AVX512F__
#define GGML_F32X GGML_F32x16
#define GGML_F32X_SET1 GGML_F32x16_SET1
#define GGML_F32X_LOAD GGML_F32x16_LOAD
#define GGML_F32X_STORE GGML_F32x16_STORE
#define GGML_F32X_MUL GGML_F32x16_MUL
#define GGML_F32X_FMA GGML_F32x16_FMA
#define WKV_VECTOR_SIZE 16
#elif defined(__ARM_NEON) && defined(__aarch64__)
#define GGML_F32X GGML_F32x4
#define GGML_F32X_SET1 GGML_F32x4_SET1
#define GGML_F32X_LOAD GGML_F32x4_LOAD
#define GGML_F32X_STORE GGML_F32x4_STORE
#define GGML_F32X_MUL GGML_F32x4_MUL
#define GGML_F32X_FMA GGML_F32x4_FMA
#define WKV_VECTOR_SIZE 4
#endif

#ifdef WKV_VECTOR_SIZE
const size_t vec_count = head_size / WKV_VECTOR_SIZE;
for (size_t t = 0; t < T; t++) {
size_t t_offset = t * t_stride;
size_t state_offset = head_size * C * (t / (T / n_seqs));
float * state_cur = state + state_offset;
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
for (size_t h = h_start; h < h_end; h++) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;
for (size_t i = 0; i < head_size; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_i_offset = h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
float k_val = k[t_h_i_offset];
float r_val = r[t_h_i_offset];
float time_faaaa_val = time_faaaa[h_i_offset];
float time_decay_val = time_decay[t_h_i_offset];
// Broadcast scalar values to vectors
GGML_F32X k_vec = GGML_F32X_SET1(k_val);
GGML_F32X r_vec = GGML_F32X_SET1(r_val);
GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
for (size_t j = 0; j < vec_count; j++) {
size_t base_j = j * WKV_VECTOR_SIZE;
size_t t_h_j_offset = t_h_offset + base_j;
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
// Load x elements at once
GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
// Compute kv = v * k
GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
// Compute temp = kv * time_faaaa + prev_state
GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
// Update dst: dst += temp * r
dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
// Update state: state = prev_state * time_decay + kv
GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
}
// Handle remaining elements, this will not be used.
for (size_t j = vec_count * VECTOR_SIZE; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];
float kv_val = v_val * k_val;
float prev_state_val = state_prev[h_2d_i_j_offset];
float temp_val = kv_val * time_faaaa_val + prev_state_val;
dst_data[t_h_j_offset] += temp_val * r_val;
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
}
}
}
}
#else
// ... scalar fallback, identical to the #else branch shown above ...
#endif
}
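
A note on the FMA operand order: the old AVX-512 path computed temp with _mm512_fmadd_ps(kv_vec, time_faaaa_vec, prev_state_vec), i.e. kv * time_faaaa + prev_state, while the new code writes GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec). That is consistent with reading GGML_F32X_FMA(a, b, c) as a + b * c, which is an assumption here rather than something shown in this diff. Under that reading, one inner-loop step of the new vector body matches the scalar #else branch; a hypothetical scalar restatement:

// Hypothetical helper (not in the commit): one per-element step of the wkv6
// update, assuming GGML_F32X_FMA(a, b, c) == a + b * c. kv is v*k, computed by
// the caller, as in the loops above.
static inline void wkv6_step_scalar(float r, float kv, float time_faaaa,
                                    float time_decay, float prev_state,
                                    float * dst, float * new_state) {
    float temp = prev_state + kv * time_faaaa;   // GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec)
    *dst += temp * r;                            // GGML_F32X_FMA(dst_vec, temp_vec, r_vec)
    *new_state = kv + prev_state * time_decay;   // GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec)
}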