combine denibble with load

This commit is contained in:
netrunnereve 2024-04-23 23:46:49 -04:00
parent 257391aae3
commit 9facb0f07a

View file

@ -726,14 +726,18 @@ class tinyBLAS_Q0_AVX {
__m256 Cv[RN][RM] = {}; __m256 Cv[RN][RM] = {};
for (int l = 0; l < k; ++l) for (int l = 0; l < k; ++l)
for (int j = 0; j < RN; ++j) for (int j = 0; j < RN; ++j)
for (int i = 0; i < RM; ++i) for (int i = 0; i < RM; ++i) {
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * __m256 udTmp = updot(signepi8(load(A + lda * (ii + i) + l),
unhalf(B[ldb * (jj + j) + l].d)),
updot(signepi8(load(A + lda * (ii + i) + l),
load(A + lda * (ii + i) + l)), load(A + lda * (ii + i) + l)),
signepi8(load(B + ldb * (jj + j) + l), signepi8(load(B + ldb * (jj + j) + l),
load(A + lda * (ii + i) + l))), load(A + lda * (ii + i) + l)));
Cv[j][i]); //_mm256i ali = load(A + lda * (ii + i) + l;
//_mm256i blj = load(B + ldb * (jj + j) + l;
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d)),
udTmp,
Cv[j][i]);
}
for (int j = 0; j < RN; ++j) for (int j = 0; j < RN; ++j)
for (int i = 0; i < RM; ++i) for (int i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
@ -748,8 +752,10 @@ class tinyBLAS_Q0_AVX {
#if defined(__AVX2__) #if defined(__AVX2__)
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
#else #else
const __m128i dn0 = _mm256_extractf128_si256(denibble(b->qs), 0); __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
const __m128i dn1 = _mm256_extractf128_si256(denibble(b->qs), 1); const __m128i dn0 = _mm_and_si128(_mm_set1_epi8(15), x);
const __m128i dn1 = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
return MM256_SET_M128I(_mm_sub_epi8(dn1, _mm_set1_epi8(8)), _mm_sub_epi8(dn0, _mm_set1_epi8(8))); return MM256_SET_M128I(_mm_sub_epi8(dn1, _mm_set1_epi8(8)), _mm_sub_epi8(dn0, _mm_set1_epi8(8)));
#endif #endif
} }
@ -785,15 +791,9 @@ class tinyBLAS_Q0_AVX {
static inline __m256i denibble(const uint8_t *p) { static inline __m256i denibble(const uint8_t *p) {
__m128i x = _mm_loadu_si128((const __m128i *)p); __m128i x = _mm_loadu_si128((const __m128i *)p);
#if defined(__AVX2__)
return _mm256_and_si256(_mm256_set1_epi8(15), return _mm256_and_si256(_mm256_set1_epi8(15),
_mm256_insertf128_si256(_mm256_castsi128_si256(x), _mm256_insertf128_si256(_mm256_castsi128_si256(x),
_mm_srli_epi16(x, 4), 1)); _mm_srli_epi16(x, 4), 1));
#else
const __m128i maskedLow = _mm_and_si128(_mm_set1_epi8(15), x);
const __m128i maskedHigh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
return MM256_SET_M128I(maskedHigh, maskedLow);
#endif
} }
const TA *const A; const TA *const A;