From 063a31f7a880ca9810484b36090b47e3933d9202 Mon Sep 17 00:00:00 2001 From: netrunnereve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 24 Apr 2024 23:00:02 -0400 Subject: [PATCH] sse load --- sgemm.cpp | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index f87652c82..c0eb998bc 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -733,10 +733,10 @@ class tinyBLAS_Q0_AVX { _mm256_sign_epi8(load(B + ldb * (jj + j) + l), load(A + lda * (ii + i) + l))); #else - __m128i ali0 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 0); - __m128i ali1 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 1); - __m128i blj0 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 0); - __m128i blj1 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 1); + __m128i ali0 = load0(A + lda * (ii + i) + l); + __m128i ali1 = load1(A + lda * (ii + i) + l); + __m128i blj0 = load0(B + ldb * (jj + j) + l); + __m128i blj1 = load1(B + ldb * (jj + j) + l); __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); @@ -764,16 +764,26 @@ class tinyBLAS_Q0_AVX { return _mm256_loadu_si256((const __m256i *)b->qs); } - inline __m256i load(const block_q4_0 *b) { -#if defined(__AVX2__) - return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); -#else - __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); - const __m128i dn0 = _mm_and_si128(_mm_set1_epi8(15), x); - const __m128i dn1 = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)); + inline __m128i load0(const block_q8_0 *b) { + return _mm_loadu_si128((const __m128i *)b->qs); + } - return MM256_SET_M128I(_mm_sub_epi8(dn1, _mm_set1_epi8(8)), _mm_sub_epi8(dn0, _mm_set1_epi8(8))); -#endif + inline __m128i load1(const block_q8_0 *b) { + return _mm_loadu_si128(((const __m128i *)b->qs) + 1); + } + + inline __m256i load(const block_q4_0 *b) { + return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); + } + + inline __m128i load0(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); + } + + inline __m128i load1(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); } inline __m256 updot(__m256i u, __m256i s) {