Help clang produce fma instructions

This commit is contained in:
Justine Tunney 2024-04-21 12:22:39 -07:00
parent 9d4d14c9b0
commit 6b220dca32
No known key found for this signature in database
GPG key ID: 52965314629936D4

View file

@ -107,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
/** Multiplies two float16x8_t vectors using the NEON vmulq_f16 intrinsic. */
inline float16x8_t mul(float16x8_t x, float16x8_t y) {
    const float16x8_t product = vmulq_f16(x, y);
    return product;
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD
/**
 * Generic multiply-add: evaluates a * b + c.
 *
 * This is the fallback used for every vector type that has no dedicated
 * specialization. It is built from the overloaded mul() and add() helpers
 * for the given types.
 *
 * @param a first factor
 * @param b second factor
 * @param c addend; its type also determines the return type
 * @return add(mul(a, b), c)
 */
template <class T, class U>
inline U madd(T a, T b, U c) {
    const auto product = mul(a, b);
    return add(product, c);
}
// x86 specializations of madd() that call the fused multiply-add intrinsics
// directly instead of relying on the compiler to contract add(mul(a, b), c).
//
// GCC and clang define __FMA__ when FMA code generation is enabled. MSVC never
// defines __FMA__, but its /arch:AVX2 and /arch:AVX512 modes (which define
// __AVX2__ / __AVX512F__) already permit FMA code generation, so those builds
// are accepted too. clang-cl is excluded from the MSVC branch because it
// reports FMA availability through __FMA__ like regular clang does.
#if defined(__FMA__) || \
    (defined(_MSC_VER) && !defined(__clang__) && (defined(__AVX2__) || defined(__AVX512F__)))
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
/**
 * Computes a * b + c on eight packed single-precision floats using one
 * fused instruction (single rounding).
 */
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
/**
 * Computes a * b + c on sixteen packed single-precision floats using one
 * fused instruction (single rounding).
 */
template <>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
    return _mm512_fmadd_ps(a, b, c);
}
#endif
#endif
// ARM NEON specializations of madd(), compiled only when the target
// advertises fused multiply-add support through __ARM_FEATURE_FMA.
#ifdef __ARM_FEATURE_FMA
/**
 * Computes a * b + c on four packed single-precision floats.
 */
template <>
inline float32x4_t madd<float32x4_t, float32x4_t>(float32x4_t a, float32x4_t b, float32x4_t c) {
    // The NEON intrinsic takes the accumulator first: result = c + b * a.
    return vfmaq_f32(c, b, a);
}
// The half-precision variant additionally needs FP16 vector arithmetic and is
// not built under MSVC.
#if !defined(_MSC_VER) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
/**
 * Computes a * b + c on eight packed half-precision floats.
 */
template <>
inline float16x8_t madd<float16x8_t, float16x8_t>(float16x8_t a, float16x8_t b, float16x8_t c) {
    // Accumulator-first argument order, as above.
    return vfmaq_f16(c, b, a);
}
#endif
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM
@ -198,21 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
}
#endif // __AVX512F__
////////////////////////////////////////////////////////////////////////////////////////////////////
// ABSTRACTIONS
/**
 * Multiply-add helper: evaluates a * b + c through the overloaded mul()
 * and add() helpers.
 *
 * On hardware that supports it (e.g. Intel Haswell+, c. 2013, or AMD
 * Bulldozer+, c. 2011) the compiler may fuse the multiply and the add
 * into a single arithmetic instruction.
 */
template <class T, class U>
inline U madd(T a, T b, U c) {
    const auto product = mul(a, b);
    return add(product, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION