diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 84f0e8201..347e1289e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -16645,360 +16645,143 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     #ifdef __AVX2__
-    // AVX2 uses 256-bit vectors = 8 float32
-    const int vec_size = 8;
-    const size_t vec_count = head_size / vec_size;
-
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = head_size * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-        for (size_t h = h_start; h < h_end; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
-
-            for (size_t i = 0; i < head_size; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                float time_decay_val = time_decay[t_h_i_offset];
-
-                // Broadcast scalar values to vectors
-                __m256 k_vec = _mm256_set1_ps(k_val);
-                __m256 r_vec = _mm256_set1_ps(r_val);
-                __m256 time_faaaa_vec = _mm256_set1_ps(time_faaaa_val);
-                __m256 time_decay_vec = _mm256_set1_ps(time_decay_val);
-
-                // Vector processing for chunks of 8 floats
-                for (size_t j = 0; j < vec_count; j++) {
-                    size_t base_j = j * vec_size;
-                    size_t t_h_j_offset = t_h_offset + base_j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
-
-                    // Load 8 elements at once
-                    __m256 v_vec = _mm256_loadu_ps(&v[t_h_j_offset]);
-                    __m256 prev_state_vec = _mm256_loadu_ps(&state_prev[h_2d_i_j_offset]);
-                    __m256 dst_vec = _mm256_loadu_ps(&dst_data[t_h_j_offset]);
-
-                    // Compute kv = v * k
-                    __m256 kv_vec = _mm256_mul_ps(v_vec, k_vec);
-
-                    // Compute temp = kv * time_faaaa + prev_state
-                    __m256 kv_time_vec = _mm256_mul_ps(kv_vec, time_faaaa_vec);
-                    __m256 temp_vec = _mm256_add_ps(kv_time_vec, prev_state_vec);
-
-                    // Update dst: dst += temp * r
-                    __m256 result_vec = _mm256_mul_ps(temp_vec, r_vec);
-                    dst_vec = _mm256_add_ps(dst_vec, result_vec);
-                    _mm256_storeu_ps(&dst_data[t_h_j_offset], dst_vec);
-
-                    // Update state: state = prev_state * time_decay + kv
-                    __m256 decay_state_vec = _mm256_mul_ps(prev_state_vec, time_decay_vec);
-                    __m256 new_state_vec = _mm256_add_ps(decay_state_vec, kv_vec);
-                    _mm256_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
-                }
-
-                // Handle remaining elements, this will not be used.
-                for (size_t j = vec_count * vec_size; j < head_size; j++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                }
-            }
-        }
-    }
-
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
     #elif __AVX512F__
-    // AVX-512 uses 512-bit vectors = 16 float32
-    const int vec_size = 16;
-    const size_t vec_count = head_size / vec_size;
-    const size_t vec_remain = head_size % vec_size;
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif

-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = head_size * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+    #ifdef WKV_VECTOR_SIZE
+        const size_t vec_count = head_size / WKV_VECTOR_SIZE;

-        for (size_t h = h_start; h < h_end; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
+        for (size_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

-            for (size_t i = 0; i < head_size; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+            for (size_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;

-                // Load scalar values
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                float time_decay_val = time_decay[t_h_i_offset];
+                for (size_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;

-                // Broadcast scalar values to ZMM registers (512-bit)
-                __m512 k_vec = _mm512_set1_ps(k_val);
-                __m512 r_vec = _mm512_set1_ps(r_val);
-                __m512 time_faaaa_vec = _mm512_set1_ps(time_faaaa_val);
-                __m512 time_decay_vec = _mm512_set1_ps(time_decay_val);
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];

-                // Use prefetch to reduce cache misses
-                #define PREFETCH_OFFSET 2
-                if (i + PREFETCH_OFFSET < head_size) {
-                    _mm_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], _MM_HINT_T0);
-                    _mm_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], _MM_HINT_T0);
-                }
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

-                // Vector processing for chunks of 16 floats
-                for (size_t j = 0; j < vec_count; j++) {
-                    size_t base_j = j * vec_size;
-                    size_t t_h_j_offset = t_h_offset + base_j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+                    for (size_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

-                    // Load 16 elements at once
-                    __m512 v_vec = _mm512_loadu_ps(&v[t_h_j_offset]);
-                    __m512 prev_state_vec = _mm512_loadu_ps(&state_prev[h_2d_i_j_offset]);
-                    __m512 dst_vec = _mm512_loadu_ps(&dst_data[t_h_j_offset]);
+                        // Load WKV_VECTOR_SIZE elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);

-                    // Compute kv = v * k using FMA
-                    __m512 kv_vec = _mm512_mul_ps(v_vec, k_vec);
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);

-                    // Compute temp = kv * time_faaaa + prev_state using FMA
-                    __m512 temp_vec = _mm512_fmadd_ps(kv_vec, time_faaaa_vec, prev_state_vec);
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);

-                    // Update dst: dst += temp * r using FMA
-                    dst_vec = _mm512_fmadd_ps(temp_vec, r_vec, dst_vec);
-                    _mm512_storeu_ps(&dst_data[t_h_j_offset], dst_vec);
-
-                    // Update state: state = prev_state * time_decay + kv using FMA
-                    __m512 new_state_vec = _mm512_fmadd_ps(prev_state_vec, time_decay_vec, kv_vec);
-                    _mm512_storeu_ps(&state_cur[h_2d_i_j_offset], new_state_vec);
-                }
-
-                // Handle remaining elements, this will not be used.
-                for (size_t j = vec_count * vec_size; j < head_size; j++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                }
-            }
-        }
-    }
-
-
-    #elif __ARM_FEATURE_SVE
-    // Get vector length for this CPU
-    const size_t vec_size = svcntw();
-
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = head_size * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-        for (size_t h = h_start; h < h_end; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
-
-            for (size_t i = 0; i < head_size; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                float time_decay_val = time_decay[t_h_i_offset];
-
-                // Create predicate for active lanes
-                svbool_t pg = svwhilelt_b32(0, head_size);
-
-                // Process vectors until done
-                size_t j = 0;
-                while (svptest_first(svptrue_b32(), pg)) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    // Load vectors
-                    svfloat32_t v_vec = svld1_f32(pg, &v[t_h_j_offset]);
-                    svfloat32_t prev_state_vec = svld1_f32(pg, &state_prev[h_2d_i_j_offset]);
-                    svfloat32_t dst_vec = svld1_f32(pg, &dst_data[t_h_j_offset]);
-
-                    // Compute kv = v * k
-                    svfloat32_t kv_vec = svmul_n_f32_x(pg, v_vec, k_val);
-
-                    // Compute temp = kv * time_faaaa + prev_state
-                    svfloat32_t temp_vec = svmad_n_f32_x(pg, kv_vec, time_faaaa_val, prev_state_vec);
-
-                    // Update dst: dst += temp * r
-                    svfloat32_t result_vec = svmad_n_f32_x(pg, temp_vec, r_val, dst_vec);
-                    svst1_f32(pg, &dst_data[t_h_j_offset], result_vec);
-
-                    // Update state: state = prev_state * time_decay + kv
-                    svfloat32_t new_state_vec = svmad_n_f32_x(pg, prev_state_vec, time_decay_val, kv_vec);
-                    svst1_f32(pg, &state_cur[h_2d_i_j_offset], new_state_vec);
-
-                    j += vec_size;
-                    pg = svwhilelt_b32(j, head_size);
-                }
-            }
-        }
-    }
-
-    #elif __ARM_NEON
-    // NEON uses 128-bit vectors = 4 float32s
-    const int vec_size = 4;
-    const size_t vec_count = head_size / vec_size;
-
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = head_size * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
-
-        for (size_t h = h_start; h < h_end; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
-
-            for (size_t i = 0; i < head_size; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                float time_decay_val = time_decay[t_h_i_offset];
-
-                // Broadcast scalar values to vectors
-                float32x4_t k_vec = vdupq_n_f32(k_val);
-                float32x4_t r_vec = vdupq_n_f32(r_val);
-                float32x4_t time_faaaa_vec = vdupq_n_f32(time_faaaa_val);
-                float32x4_t time_decay_vec = vdupq_n_f32(time_decay_val);
-
-                // Use prefetch to reduce cache misses
-                #ifdef __ARM_FEATURE_PREFETCH
-                #define PREFETCH_OFFSET 2
-                if (i + PREFETCH_OFFSET < head_size) {
-                    __builtin_prefetch(&v[t_h_offset + i + PREFETCH_OFFSET], 0, 3);
-                    __builtin_prefetch(&state_prev[h_2d_i_offset + PREFETCH_OFFSET * h_stride], 0, 3);
-                }
-                #endif
-
-                // Vector processing for chunks of 4 floats
-                for (size_t j = 0; j < vec_count; j++) {
-                    size_t base_j = j * vec_size;
-                    size_t t_h_j_offset = t_h_offset + base_j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
-
-                    // Load 4 elements at once
-                    float32x4_t v_vec = vld1q_f32(&v[t_h_j_offset]);
-                    float32x4_t prev_state_vec = vld1q_f32(&state_prev[h_2d_i_j_offset]);
-                    float32x4_t dst_vec = vld1q_f32(&dst_data[t_h_j_offset]);
-
-                    // Compute kv = v * k
-                    float32x4_t kv_vec = vmulq_f32(v_vec, k_vec);
-
-                    // Compute temp = kv * time_faaaa + prev_state using FMA
-                    #ifdef __ARM_FEATURE_FMA
-                    float32x4_t temp_vec = vfmaq_f32(prev_state_vec, kv_vec, time_faaaa_vec);
-                    // Update dst: dst += temp * r
-                    dst_vec = vfmaq_f32(dst_vec, temp_vec, r_vec);
-                    // Update state: state = prev_state * time_decay + kv
-                    float32x4_t new_state_vec = vfmaq_f32(kv_vec, prev_state_vec, time_decay_vec);
-                    #else
-                    float32x4_t kv_time = vmulq_f32(kv_vec, time_faaaa_vec);
-                    float32x4_t temp_vec = vaddq_f32(kv_time, prev_state_vec);
-                    float32x4_t result_vec = vmulq_f32(temp_vec, r_vec);
-                    dst_vec = vaddq_f32(dst_vec, result_vec);
-                    float32x4_t decay_state_vec = vmulq_f32(prev_state_vec, time_decay_vec);
-                    float32x4_t new_state_vec = vaddq_f32(decay_state_vec, kv_vec);
-                    #endif
-
-                    vst1q_f32(&dst_data[t_h_j_offset], dst_vec);
-                    vst1q_f32(&state_cur[h_2d_i_j_offset], new_state_vec);
-                }
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }

-                // Handle remaining elements
-                for (size_t j = vec_count * vec_size; j < head_size; j++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
-                }
-            }
-        }
-    }
+                    // Handle remaining elements, this will not be used when head_size is a multiple of WKV_VECTOR_SIZE
+                    for (size_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }

     #else
-    // basically fused operations:
-    // dst = r @ (time_faaaa * (k @ v) + state),
-    // state = time_decay * state + (k @ v),
-    // recursive through each token
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = head_size * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+        // basically fused operations:
+        // dst = r @ (time_faaaa * (k @ v) + state),
+        // state = time_decay * state + (k @ v),
+        // recursive through each token
+        for (size_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

-        for (size_t h = h_start; h < h_end; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
+            for (size_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;

-            for (size_t i = 0; i < head_size; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+                for (size_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;

-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                // RWKV v6: different time_decay for each token.
-                float time_decay_val = time_decay[t_h_i_offset];
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];

-                for (size_t j = 0; j < head_size; j ++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                    for (size_t j = 0; j < head_size; j ++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;

-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
-                }
-            }
-        }
-    }
     #endif
 }
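
For reference, the per-element recurrence that both the vectorized path and the scalar fallback implement is small enough to restate in plain C. The sketch below is illustrative only (wkv6_row_update and f32_fma are not ggml names); the detail worth documenting is the accumulator-first argument order, since as used above GGML_F32X_FMA(a, b, c) computes a + b * c, which is what makes the three FMA calls line up with their comments.

#include <stddef.h>

// a + b * c, mirroring the accumulator-first order of GGML_F32X_FMA
// as the kernel above uses it (illustrative helper, not a ggml function)
static inline float f32_fma(float a, float b, float c) {
    return a + b * c;
}

// One (token t, head h, row i) step of the WKV6 recurrence across a row of
// head_size elements; this is exactly the scalar fallback's inner j loop.
static void wkv6_row_update(float * dst_row, float * state_cur_row,
                            const float * state_prev_row, const float * v_row,
                            float k_val, float r_val,
                            float time_faaaa_val, float time_decay_val,
                            size_t head_size) {
    for (size_t j = 0; j < head_size; j++) {
        float kv_val     = v_row[j] * k_val;                                   // kv = v * k
        float temp_val   = f32_fma(state_prev_row[j], kv_val, time_faaaa_val); // kv * time_faaaa + prev_state
        dst_row[j]       = f32_fma(dst_row[j], temp_val, r_val);               // dst += temp * r
        state_cur_row[j] = f32_fma(kv_val, state_prev_row[j], time_decay_val); // prev_state * time_decay + kv
    }
}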
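
On an AVX2 build the aliases should resolve along the lines of the sketch below. The expansions are an assumption based on reading ggml's GGML_F32x8_* macros rather than a quote of them; notably, ggml only emits a true fused multiply-add where the target supports it and otherwise falls back to a separate multiply and add, which is why the unified kernel no longer needs the hand-written #ifdef __ARM_FEATURE_FMA split the old NEON path carried.

// Hypothetical local aliases (F32X_*) showing the dispatch trick in
// miniature; ggml's GGML_F32x8_* macros are believed to expand equivalently.
// Compile with -mavx2 -mfma.
#include <immintrin.h>
#include <stddef.h>

#define F32X              __m256
#define F32X_SET1(x)      _mm256_set1_ps(x)
#define F32X_LOAD(p)      _mm256_loadu_ps(p)
#define F32X_STORE(p, x)  _mm256_storeu_ps(p, x)
#define F32X_MUL(a, b)    _mm256_mul_ps(a, b)
#define F32X_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)  // a + b * c
#define F32X_N            8

// dst[j] += r * (faaaa * k * v[j] + prev[j]) for n elements (n a multiple
// of 8): the same shape as one slice of the unified loop body above
static void wkv6_like_update(float * dst, const float * v, const float * prev,
                             float k, float r, float faaaa, size_t n) {
    const F32X k_vec = F32X_SET1(k), r_vec = F32X_SET1(r), faaaa_vec = F32X_SET1(faaaa);
    for (size_t j = 0; j < n; j += F32X_N) {
        F32X kv_vec   = F32X_MUL(F32X_LOAD(&v[j]), k_vec);
        F32X temp_vec = F32X_FMA(F32X_LOAD(&prev[j]), kv_vec, faaaa_vec);
        F32X_STORE(&dst[j], F32X_FMA(F32X_LOAD(&dst[j]), temp_vec, r_vec));
    }
}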
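
One subtlety the rewrite preserves unchanged is the state handoff for batched sequences: state_prev aliases the incoming state tensor (dst->src[5]) only at the first token of each of the n_seqs packed sequences, and otherwise points at the running state being written in place. A standalone restatement of that ternary, with hypothetical names:

#include <stddef.h>

// Where the previous state for token t lives, given T tokens packed as
// n_seqs equal-length sequences. Both pointers are assumed to already be
// offset by head_size * C * (t / (T / n_seqs)), as in the kernel above.
static const float * wkv6_prev_state(const float * state_in,      // dst->src[5] data
                                     const float * state_running, // state_cur buffer
                                     size_t t, size_t T, size_t n_seqs) {
    const size_t tokens_per_seq = T / n_seqs;
    // sequence start: read the graph input; mid-sequence: chain through our own output
    return (t % tokens_per_seq == 0) ? state_in : state_running;
}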