Fixes for clang AVX VNNI

This commit is contained in:
Srihari-mcw 2024-12-18 23:19:36 -08:00
parent 7909e8588d
commit 2a4e792ed1
3 changed files with 12 additions and 3 deletions

View file

@ -194,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
}
static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbusd_epi32(zero, ax, sy);
#elif defined(__AVXVNNI__)
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbusd_avx_epi32(zero, ax, sy);
#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

View file

@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
}
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
#elif defined(__AVXVNNI__)
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

View file

@ -993,8 +993,10 @@ class tinyBLAS_Q0_AVX {
inline __m256 updot(__m256i u, __m256i s) {
__m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#elif defined(__AVXVNNI__)
res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
#else
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif