basic avx implementation

netrunnereve 2024-04-22 23:35:02 -04:00
parent 4e96a812b3
commit 86d1d84642
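
In short: plain AVX (e.g. Sandy Bridge) has 256-bit floating-point instructions but no 256-bit integer instructions, so the renamed tinyBLAS_Q0_AVX class emulates each AVX2 integer intrinsic by splitting its operands into two 128-bit halves, applying the SSE equivalent to each half, and recombining with ggml's MM256_SET_M128I macro. A minimal sketch of that idiom, assuming the macro definition below matches the one in this tree (quoted from memory of ggml, so the exact form may differ):

#include <immintrin.h>

// ggml defines MM256_SET_M128I roughly like this (assumption: the tree
// this commit targets may spell it slightly differently):
#ifndef MM256_SET_M128I
#define MM256_SET_M128I(a, b) \
    _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
#endif

// AVX-only stand-in for the AVX2 _mm256_sub_epi8: do each 128-bit half
// with SSE2, then stitch the halves back into one 256-bit vector.
static inline __m256i sub_epi8_avx(__m256i a, __m256i b) {
    const __m128i lo = _mm_sub_epi8(_mm256_extractf128_si256(a, 0),
                                    _mm256_extractf128_si256(b, 0));
    const __m128i hi = _mm_sub_epi8(_mm256_extractf128_si256(a, 1),
                                    _mm256_extractf128_si256(b, 1));
    return MM256_SET_M128I(hi, lo);  // hi lands in lane 1, lo in lane 0
}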


@@ -1,6 +1,3 @@
-// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
-// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
-//
 // Copyright 2024 Mozilla Foundation
 //
 // Permission is hereby granted, free of charge, to any person obtaining
@@ -586,15 +583,15 @@ class tinyBLAS_Q0_ARM {
 };
 #endif // __ARM_FEATURE_DOTPROD
-#if defined(__AVX2__) || defined(__AVX512F__)
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
 template <typename TA, typename TB, typename TC>
-class tinyBLAS_Q0_AVX2 {
+class tinyBLAS_Q0_AVX {
   public:
-    tinyBLAS_Q0_AVX2(int k,
+    tinyBLAS_Q0_AVX(int k,
                      const TA *A, int lda,
                      const TB *B, int ldb,
                      TC *C, int ldc,
                      int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
@@ -732,9 +729,9 @@ class tinyBLAS_Q0_AVX2 {
             for (int i = 0; i < RM; ++i)
                 Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                unhalf(B[ldb * (jj + j) + l].d)),
-                                updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                updot(signepi8(load(A + lda * (ii + i) + l),
                                                        load(A + lda * (ii + i) + l)),
-                                      _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+                                      signepi8(load(B + ldb * (jj + j) + l),
                                                        load(A + lda * (ii + i) + l))),
                                 Cv[j][i]);
         for (int j = 0; j < RN; ++j)
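
The inner product now goes through a signepi8 wrapper (added further down in this diff) instead of raw _mm256_sign_epi8, so the same expression compiles on both AVX2 and AVX-only targets. The underlying sign trick is unchanged: signepi8(a, a) yields |a|, which can be treated as the unsigned operand of maddubs, while signepi8(b, a) transfers a's sign onto b, so the byte products still equal a*b. A scalar sketch of the semantics (an illustration, not code from this commit):

#include <cstdint>

// Per-byte semantics of _mm_sign_epi8 / _mm256_sign_epi8:
// x if y > 0, -x if y < 0, 0 if y == 0.
static inline int8_t sign8(int8_t x, int8_t y) {
    return y > 0 ? x : (y < 0 ? (int8_t)-x : (int8_t)0);
}
// sign8(a, a) == |a| and sign8(b, a) carries a's sign, so
//   (uint8_t)sign8(a, a) * sign8(b, a) == a * b
// (both sides are 0 when a == 0; -128 wraps to -128, which reads as
// 128 when reinterpreted as unsigned, preserving the magnitude).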
@@ -748,24 +745,55 @@ class tinyBLAS_Q0_AVX2 {
     }
 
     inline __m256i load(const block_q4_0 *b) {
+#if defined(__AVX2__)
         return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
+#else
+        const __m128i dn0 = _mm256_extractf128_si256(denibble(b->qs), 0);
+        const __m128i dn1 = _mm256_extractf128_si256(denibble(b->qs), 1);
+        return MM256_SET_M128I(_mm_sub_epi8(dn1, _mm_set1_epi8(8)), _mm_sub_epi8(dn0, _mm_set1_epi8(8)));
+#endif
     }
 
     inline __m256 updot(__m256i u, __m256i s) {
         __m256i res;
 #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
         res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
-#else
+#elif defined(__AVX2__)
         res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
+#else
+        const __m128i usMaddubs0 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 0), _mm256_extractf128_si256(s, 0));
+        const __m128i usMaddubs1 = _mm_maddubs_epi16(_mm256_extractf128_si256(u, 1), _mm256_extractf128_si256(s, 1));
+        const __m128i onefill = _mm_set1_epi16(1);
+        res = MM256_SET_M128I(_mm_madd_epi16(onefill, usMaddubs1), _mm_madd_epi16(onefill, usMaddubs0));
 #endif
         return _mm256_cvtepi32_ps(res);
     }
 
+#if defined(__AVX2__)
+    inline __m256i signepi8(__m256i a, __m256i b) {
+        return _mm256_sign_epi8(a, b);
+    }
+#else
+    inline __m256i signepi8(__m256i a, __m256i b) {
+        const __m128i a0 = _mm256_extractf128_si256(a, 0);
+        const __m128i a1 = _mm256_extractf128_si256(a, 1);
+        const __m128i b0 = _mm256_extractf128_si256(b, 0);
+        const __m128i b1 = _mm256_extractf128_si256(b, 1);
+        return MM256_SET_M128I(_mm_sign_epi8(a1, b1), _mm_sign_epi8(a0, b0));
+    }
+#endif
+
     static inline __m256i denibble(const uint8_t *p) {
         __m128i x = _mm_loadu_si128((const __m128i *)p);
+#if defined(__AVX2__)
         return _mm256_and_si256(_mm256_set1_epi8(15),
                                 _mm256_insertf128_si256(_mm256_castsi128_si256(x),
                                                         _mm_srli_epi16(x, 4), 1));
+#else
+        const __m128i maskedLow = _mm_and_si128(_mm_set1_epi8(15), x);
+        const __m128i maskedHigh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
+        return MM256_SET_M128I(maskedHigh, maskedLow);
+#endif
     }
 
     const TA *const A;
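
Taken together: denibble unpacks a 16-byte Q4_0 payload into 32 nibbles, load recenters them to signed values in [-8, 7], and updot reduces each group of four unsigned-times-signed byte products into a 32-bit lane, via _mm256_dpbusd_epi32 where VNNI exists, via the maddubs/madd pair on AVX2, and via the same pair applied to each 128-bit half on plain AVX. A scalar reference for what one 256-bit updot computes before the float conversion (a sketch, not code from the commit):

#include <cstdint>

// res[i] accumulates four unsigned*signed byte products, matching
// maddubs (paired 16-bit sums) followed by madd against all-ones
// (paired 32-bit sums).
static inline void updot_ref(const uint8_t u[32], const int8_t s[32],
                             int32_t res[8]) {
    for (int i = 0; i < 8; ++i) {
        int32_t acc = 0;
        for (int j = 0; j < 4; ++j)
            acc += (int32_t)u[4 * i + j] * (int32_t)s[4 * i + j];
        res[i] = acc;  // the kernel then converts via _mm256_cvtepi32_ps
    }
}

One caveat: _mm_maddubs_epi16 saturates its 16-bit pair sums, but with quantized magnitudes of at most 127 on each side the pair sums stay within int16 range, so the scalar reference should match.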
@@ -778,7 +806,7 @@ class tinyBLAS_Q0_AVX2 {
     const int ith;
     const int nth;
 };
-#endif // __AVX2__
+#endif // __AVX__
 
 } // namespace
@@ -932,8 +960,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
     case GGML_TYPE_Q8_0: {
         if (Btype != GGML_TYPE_Q8_0)
             return false;
-#if defined(__AVX2__) || defined(__AVX512F__)
-        tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
@@ -956,8 +984,8 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B,
     case GGML_TYPE_Q4_0: {
         if (Btype != GGML_TYPE_Q8_0)
             return false;
-#if defined(__AVX2__) || defined(__AVX512F__)
-        tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
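
All of this selection is compile-time, driven by the same feature macros that guard the class. A quick standalone probe of which updot path a given set of compiler flags will build (a hypothetical helper, not part of the commit):

#include <cstdio>

int main() {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
    std::puts("updot: VNNI path (_mm256_dpbusd_epi32)");
#elif defined(__AVX2__)
    std::puts("updot: AVX2 path (_mm256_maddubs_epi16)");
#elif defined(__AVX__)
    std::puts("updot: AVX fallback (two 128-bit halves)");
#else
    std::puts("tinyBLAS_Q0_AVX not compiled in");
#endif
    return 0;
}

Compiling with -mavx but without -mavx2 (or whatever equivalent flags the project's build system passes) should report the new fallback.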