reduce 256 to 128 (and back!) conversions
This commit is contained in:
parent
9facb0f07a
commit
dee9566dc7
1 changed files with 24 additions and 27 deletions
45
sgemm.cpp
45
sgemm.cpp
|
@ -727,12 +727,28 @@ class tinyBLAS_Q0_AVX {
|
|||
for (int l = 0; l < k; ++l)
|
||||
for (int j = 0; j < RN; ++j)
|
||||
for (int i = 0; i < RM; ++i) {
|
||||
__m256 udTmp = updot(signepi8(load(A + lda * (ii + i) + l),
|
||||
#if defined(__AVX2__)
|
||||
__m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
||||
load(A + lda * (ii + i) + l)),
|
||||
signepi8(load(B + ldb * (jj + j) + l),
|
||||
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
|
||||
load(A + lda * (ii + i) + l)));
|
||||
//_mm256i ali = load(A + lda * (ii + i) + l;
|
||||
//_mm256i blj = load(B + ldb * (jj + j) + l;
|
||||
#else
|
||||
__m128i ali0 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 0);
|
||||
__m128i ali1 = _mm256_extractf128_si256(load(A + lda * (ii + i) + l), 1);
|
||||
__m128i blj0 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 0);
|
||||
__m128i blj1 = _mm256_extractf128_si256(load(B + ldb * (jj + j) + l), 1);
|
||||
|
||||
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
|
||||
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
|
||||
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
|
||||
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
|
||||
|
||||
// updot
|
||||
const __m128i oneFill = _mm_set1_epi16(1);
|
||||
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
|
||||
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
|
||||
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
|
||||
#endif
|
||||
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
|
||||
unhalf(B[ldb * (jj + j) + l].d)),
|
||||
udTmp,
|
||||
|
@ -764,31 +780,12 @@ class tinyBLAS_Q0_AVX {
|
|||
__m256i res;
|
||||
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
||||
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
|
||||
#elif defined(__AVX2__)
|
||||
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
|
||||
#else
|
||||
const __m128i usMaddubs0 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 0), _mm256_extractf128_si256(s, 0));
|
||||
const __m128i usMaddubs1 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 1), _mm256_extractf128_si256(s, 1));
|
||||
const __m128i oneFill = _mm_set1_epi16(1);
|
||||
res = MM256_SET_M128I(_mm_madd_epi16(oneFill, usMaddubs1), _mm_madd_epi16(oneFill, usMaddubs0));
|
||||
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
|
||||
#endif
|
||||
return _mm256_cvtepi32_ps(res);
|
||||
}
|
||||
|
||||
#if defined(__AVX2__)
|
||||
inline __m256i signepi8(__m256i a, __m256i b) {
|
||||
return _mm256_sign_epi8(a, b);
|
||||
}
|
||||
#else
|
||||
inline __m256i signepi8(__m256i a, __m256i b) {
|
||||
const __m128i a0 = _mm256_extractf128_si256(a, 0);
|
||||
const __m128i a1 = _mm256_extractf128_si256(a, 1);
|
||||
const __m128i b0 = _mm256_extractf128_si256(b, 0);
|
||||
const __m128i b1 = _mm256_extractf128_si256(b, 1);
|
||||
return MM256_SET_M128I(_mm_sign_epi8(a1, b1), _mm_sign_epi8(a0, b0));
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline __m256i denibble(const uint8_t *p) {
|
||||
__m128i x = _mm_loadu_si128((const __m128i *)p);
|
||||
return _mm256_and_si256(_mm256_set1_epi8(15),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue