reduce 256 to 128 (and back!) conversions

This commit is contained in:
netrunnereve 2024-04-24 00:22:38 -04:00
parent 9facb0f07a
commit dee9566dc7

View file

@ -727,17 +727,33 @@ class tinyBLAS_Q0_AVX {
for (int l = 0; l < k; ++l) for (int l = 0; l < k; ++l)
for (int j = 0; j < RN; ++j) for (int j = 0; j < RN; ++j)
for (int i = 0; i < RM; ++i) { for (int i = 0; i < RM; ++i) {
__m256 udTmp = updot(signepi8(load(A + lda * (ii + i) + l), #if defined(__AVX2__)
load(A + lda * (ii + i) + l)), __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
signepi8(load(B + ldb * (jj + j) + l), load(A + lda * (ii + i) + l)),
load(A + lda * (ii + i) + l))); _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
//_mm256i ali = load(A + lda * (ii + i) + l; load(A + lda * (ii + i) + l)));
//_mm256i blj = load(B + ldb * (jj + j) + l; #else
__m128i ali0 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 0);
__m128i ali1 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 1);
__m128i blj0 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 0);
__m128i blj1 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 1);
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
// updot
const __m128i oneFill = _mm_set1_epi16(1);
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
#endif
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d)), unhalf(B[ldb * (jj + j) + l].d)),
udTmp, udTmp,
Cv[j][i]); Cv[j][i]);
} }
for (int j = 0; j < RN; ++j) for (int j = 0; j < RN; ++j)
for (int i = 0; i < RM; ++i) for (int i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
@ -764,31 +780,12 @@ class tinyBLAS_Q0_AVX {
__m256i res; __m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#elif defined(__AVX2__)
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#else #else
const __m128i usMaddubs0 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 0), _mm256_extractf128_si256(s, 0)); res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
const __m128i usMaddubs1 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 1), _mm256_extractf128_si256(s, 1));
const __m128i oneFill = _mm_set1_epi16(1);
res = MM256_SET_M128I(_mm_madd_epi16(oneFill, usMaddubs1), _mm_madd_epi16(oneFill, usMaddubs0));
#endif #endif
return _mm256_cvtepi32_ps(res); return _mm256_cvtepi32_ps(res);
} }
#if defined(__AVX2__)
inline __m256i signepi8(__m256i a, __m256i b) {
return _mm256_sign_epi8(a, b);
}
#else
inline __m256i signepi8(__m256i a, __m256i b) {
const __m128i a0 = _mm256_extractf128_si256(a, 0);
const __m128i a1 = _mm256_extractf128_si256(a, 1);
const __m128i b0 = _mm256_extractf128_si256(b, 0);
const __m128i b1 = _mm256_extractf128_si256(b, 1);
return MM256_SET_M128I(_mm_sign_epi8(a1, b1), _mm_sign_epi8(a0, b0));
}
#endif
static inline __m256i denibble(const uint8_t *p) { static inline __m256i denibble(const uint8_t *p) {
__m128i x = _mm_loadu_si128((const __m128i *)p); __m128i x = _mm_loadu_si128((const __m128i *)p);
return _mm256_and_si256(_mm256_set1_epi8(15), return _mm256_and_si256(_mm256_set1_epi8(15),