fix define error

This commit is contained in:
Zhiyuan Li 2024-11-05 01:07:33 +11:00
parent 81cb301224
commit a878502f43

View file

@ -11679,17 +11679,16 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
}
ggml_barrier(params->threadpool);
#ifdef __AVX2__
#if defined(__AVX__) && !defined(__AVX512F__)
#define GGML_F32X GGML_F32x8
#define GGML_F32X_SET1 GGML_F32x8_SET1
#define GGML_F32X_LOAD GGML_F32x8_LOAD
#define GGML_F32X_STORE GGML_F32x8_STORE
#define GGML_F32X_MUL GGML_F32x8_MUL
#define GGML_F32X_FMA GGML_F32x8_FMA
#define VECTOR_SIZE 8
#elif __AVX512F__
#define WKV_VECTOR_SIZE 8
#elif defined(__AVX512F__)
#define GGML_F32X GGML_F32x16
#define GGML_F32X_SET1 GGML_F32x16_SET1
#define GGML_F32X_LOAD GGML_F32x16_LOAD
@ -11763,7 +11762,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
}
// Handle remaining elements, this will not be used.
for (int64_t j = vec_count * VECTOR_SIZE; j < head_size; j++) {
for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
float v_val = v[t_h_j_offset];