From dee9566dc7e7d38a0c43cbed8c5e5678b8d4e46a Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 24 Apr 2024 00:22:38 -0400 Subject: [PATCH] reduce 256 to 128 (and back!) conversions --- sgemm.cpp | 51 ++++++++++++++++++++++++--------------------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index 059674d1d..f87652c82 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -727,17 +727,33 @@ class tinyBLAS_Q0_AVX { for (int l = 0; l < k; ++l) for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) { - __m256 udTmp = updot(signepi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - signepi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))); - //_mm256i ali = load(A + lda * (ii + i) + l; - //_mm256i blj = load(B + ldb * (jj + j) + l; +#if defined(__AVX2__) + __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), + load(A + lda * (ii + i) + l))); +#else + __m128i ali0 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 0); + __m128i ali1 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 1); + __m128i blj0 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 0); + __m128i blj1 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 1); + + __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); + __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); + __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); + __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); + + // updot + const __m128i oneFill = _mm_set1_epi16(1); + __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); + __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); + __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); +#endif Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), udTmp, Cv[j][i]); - } + } for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); @@ -764,31 +780,12 @@ class tinyBLAS_Q0_AVX { __m256i res; #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); -#elif defined(__AVX2__) - res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #else - const __m128i usMaddubs0 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 0), _mm256_extractf128_si256(s, 0)); - const __m128i usMaddubs1 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 1), _mm256_extractf128_si256(s, 1)); - const __m128i oneFill = _mm_set1_epi16(1); - res = MM256_SET_M128I(_mm_madd_epi16(oneFill, usMaddubs1), _mm_madd_epi16(oneFill, usMaddubs0)); + res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #endif return _mm256_cvtepi32_ps(res); } -#if defined(__AVX2__) - inline __m256i signepi8(__m256i a, __m256i b) { - return _mm256_sign_epi8(a, b); - } -#else - inline __m256i signepi8(__m256i a, __m256i b) { - const __m128i a0 = _mm256_extractf128_si256(a, 0); - const __m128i a1 = _mm256_extractf128_si256(a, 1); - const __m128i b0 = _mm256_extractf128_si256(b, 0); - const __m128i b1 = _mm256_extractf128_si256(b, 1); - return MM256_SET_M128I(_mm_sign_epi8(a1, b1), _mm_sign_epi8(a0, b0)); - } -#endif - static inline __m256i denibble(const uint8_t *p) { __m128i x = _mm_loadu_si128((const __m128i *)p); return _mm256_and_si256(_mm256_set1_epi8(15),