Help clang produce fma instructions
This commit is contained in:
parent
9d4d14c9b0
commit
6b220dca32
1 changed file with 39 additions and 15 deletions
54
sgemm.cpp
54
sgemm.cpp
|
@ -107,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
|
||||||
/** Multiplies two float16x8_t vectors via the vmulq_f16 intrinsic. */
inline float16x8_t mul(float16x8_t x, float16x8_t y) {
    const float16x8_t product = vmulq_f16(x, y);
    return product;
}
|
/** Multiplies two float16x8_t vectors via the vmulq_f16 intrinsic. */
inline float16x8_t mul(float16x8_t x, float16x8_t y) {
    const float16x8_t product = vmulq_f16(x, y);
    return product;
}
|
||||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// VECTORIZED FUSED MULTIPLY ADD
|
||||||
|
|
||||||
|
/**
 * Generic multiply-accumulate: evaluates a * b + c.
 *
 * Expressed as a call to mul() followed by a call to add(), so it works
 * for any operand types that have those two helpers available.
 *
 * @param a first factor
 * @param b second factor
 * @param c addend; its type U is also the return type
 * @return add(mul(a, b), c)
 */
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    const auto product = mul(a, b);
    return add(product, c);
}
|
||||||
|
|
||||||
|
#if defined(__FMA__)
|
||||||
|
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||||
|
template <>
|
||||||
|
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
|
||||||
|
return _mm256_fmadd_ps(a, b, c);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#if defined(__AVX512F__)
|
||||||
|
template <>
|
||||||
|
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
|
||||||
|
return _mm512_fmadd_ps(a, b, c);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__ARM_FEATURE_FMA)
|
||||||
|
template <>
|
||||||
|
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
|
||||||
|
return vfmaq_f32(c, b, a);
|
||||||
|
}
|
||||||
|
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
|
||||||
|
template <>
|
||||||
|
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
|
||||||
|
return vfmaq_f16(c, b, a);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// VECTORIZED HORIZONTAL SUM
|
// VECTORIZED HORIZONTAL SUM
|
||||||
|
|
||||||
|
@ -198,21 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
|
||||||
}
|
}
|
||||||
#endif // __AVX512F__
|
#endif // __AVX512F__
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
// ABSTRACTIONS
|
|
||||||
|
|
||||||
/**
 * Evaluates a * b + c.
 *
 * Written as mul() followed by add(). The multiply and add may be fused
 * into a single arithmetic instruction when the hardware supports it,
 * e.g. Intel Haswell+ (c. 2013), AMD Bulldozer+ (c. 2011), etc.
 *
 * @param a first factor
 * @param b second factor
 * @param c addend; its type U is also the return type
 * @return add(mul(a, b), c)
 */
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    const auto scaled = mul(a, b);
    return add(scaled, c);
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// FLOATING POINT MATRIX MULTIPLICATION
|
// FLOATING POINT MATRIX MULTIPLICATION
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue