From 9facb0f07a99383ae0eef8de2f044ba52902649b Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Tue, 23 Apr 2024 23:46:49 -0400 Subject: [PATCH] combine denibble with load --- sgemm.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index 5fd18549f..059674d1d 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -726,14 +726,18 @@ class tinyBLAS_Q0_AVX { __m256 Cv[RN][RM] = {}; for (int l = 0; l < k; ++l) for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) - Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)), - updot(signepi8(load(A + lda * (ii + i) + l), + for (int i = 0; i < RM; ++i) { + __m256 udTmp = updot(signepi8(load(A + lda * (ii + i) + l), load(A + lda * (ii + i) + l)), signepi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))), - Cv[j][i]); + load(A + lda * (ii + i) + l))); + //_mm256i ali = load(A + lda * (ii + i) + l; + //_mm256i blj = load(B + ldb * (jj + j) + l; + Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * + unhalf(B[ldb * (jj + j) + l].d)), + udTmp, + Cv[j][i]); + } for (int j = 0; j < RN; ++j) for (int i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); @@ -748,8 +752,10 @@ class tinyBLAS_Q0_AVX { #if defined(__AVX2__) return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); #else - const __m128i dn0 = _mm256_extractf128_si256(denibble(b->qs), 0); - const __m128i dn1 = _mm256_extractf128_si256(denibble(b->qs), 1); + __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + const __m128i dn0 = _mm_and_si128(_mm_set1_epi8(15), x); + const __m128i dn1 = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)); + return MM256_SET_M128I(_mm_sub_epi8(dn1, _mm_set1_epi8(8)), _mm_sub_epi8(dn0, _mm_set1_epi8(8))); #endif } @@ -785,15 +791,9 @@ class tinyBLAS_Q0_AVX { static inline __m256i denibble(const uint8_t *p) { __m128i x = _mm_loadu_si128((const __m128i *)p); -#if defined(__AVX2__) return _mm256_and_si256(_mm256_set1_epi8(15), _mm256_insertf128_si256(_mm256_castsi128_si256(x), _mm_srli_epi16(x, 4), 1)); -#else - const __m128i maskedLow = _mm_and_si128(_mm_set1_epi8(15), x); - const __m128i maskedHigh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)); - return MM256_SET_M128I(maskedHigh, maskedLow); -#endif } const TA *const A;