add fall back for axpy mulmat
This commit is contained in:
parent
a3c295a2ae
commit
6f997d299a
1 changed files with 8 additions and 1 deletions
9
ggml.c
9
ggml.c
|
@ -14298,7 +14298,7 @@ static void ggml_axpy_normal_f16(const int n, const ggml_fp16_t * vx, const ggml
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
static void ggml_axpy_avx_f16(const int n, const ggml_fp16_t * restrict vx, const ggml_fp16_t * restrict vy, void* restrict vz, ggml_fp16_t alpha) {
|
static void ggml_axpy_avx_f16(const int n, const ggml_fp16_t * restrict vx, const ggml_fp16_t * restrict vy, void* restrict vz, ggml_fp16_t alpha) {
|
||||||
|
#if defined(__AVX2__)
|
||||||
float *result = (float *)vz;
|
float *result = (float *)vz;
|
||||||
float alpha_f32 = GGML_FP16_TO_FP32(alpha);
|
float alpha_f32 = GGML_FP16_TO_FP32(alpha);
|
||||||
__m256 scale = _mm256_set1_ps(alpha_f32); // 创建scale向量
|
__m256 scale = _mm256_set1_ps(alpha_f32); // 创建scale向量
|
||||||
|
@ -14309,6 +14309,13 @@ static void ggml_axpy_avx_f16(const int n, const ggml_fp16_t * restrict vx, cons
|
||||||
__m256 res = _mm256_fmadd_ps(vx_f32, scale, vy_f32); // 执行向量加法和乘法操作
|
__m256 res = _mm256_fmadd_ps(vx_f32, scale, vy_f32); // 执行向量加法和乘法操作
|
||||||
_mm256_storeu_ps((float*)(&result[i]), res); // 存储结果
|
_mm256_storeu_ps((float*)(&result[i]), res); // 存储结果
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
float *res = (float *)vz;
|
||||||
|
float alpha_convert = GGML_FP16_TO_FP32(alpha);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
res[i] = res[i] + (GGML_FP16_TO_FP32(vx[i])*alpha_convert);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
atomic_flag g_axpy_dense_lock = ATOMIC_FLAG_INIT;
|
atomic_flag g_axpy_dense_lock = ATOMIC_FLAG_INIT;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue