add fall back for axpy mulmat
This commit is contained in:
parent
a3c295a2ae
commit
6f997d299a
1 changed files with 8 additions and 1 deletions
9
ggml.c
9
ggml.c
|
@ -14298,7 +14298,7 @@ static void ggml_axpy_normal_f16(const int n, const ggml_fp16_t * vx, const ggml
|
|||
}
|
||||
}
|
||||
static void ggml_axpy_avx_f16(const int n, const ggml_fp16_t * restrict vx, const ggml_fp16_t * restrict vy, void* restrict vz, ggml_fp16_t alpha) {
|
||||
|
||||
#if defined(__AVX2__)
|
||||
float *result = (float *)vz;
|
||||
float alpha_f32 = GGML_FP16_TO_FP32(alpha);
|
||||
__m256 scale = _mm256_set1_ps(alpha_f32); // 创建scale向量
|
||||
|
@ -14309,6 +14309,13 @@ static void ggml_axpy_avx_f16(const int n, const ggml_fp16_t * restrict vx, cons
|
|||
__m256 res = _mm256_fmadd_ps(vx_f32, scale, vy_f32); // 执行向量加法和乘法操作
|
||||
_mm256_storeu_ps((float*)(&result[i]), res); // 存储结果
|
||||
}
|
||||
#else
|
||||
float *res = (float *)vz;
|
||||
float alpha_convert = GGML_FP16_TO_FP32(alpha);
|
||||
for (int i = 0; i < n; i++) {
|
||||
res[i] = res[i] + (GGML_FP16_TO_FP32(vx[i])*alpha_convert);
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
atomic_flag g_axpy_dense_lock = ATOMIC_FLAG_INIT;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue