From 53ad4bd89fb41fdaf2cd458f42043c545e5dcb4b Mon Sep 17 00:00:00 2001 From: Tianzhengshuyuan <298208490@qq.com> Date: Thu, 25 Jul 2024 22:32:41 +0800 Subject: [PATCH] Add support for loongarch backend in sgemm.cpp --- ggml/src/llamafile/sgemm.cpp | 323 ++++++++++++++++++++++++++++++++++- 1 file changed, 322 insertions(+), 1 deletion(-) diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp index 6626ceb26..3f869ffcc 100644 --- a/ggml/src/llamafile/sgemm.cpp +++ b/ggml/src/llamafile/sgemm.cpp @@ -58,7 +58,7 @@ #define NOINLINE __attribute__((__noinline__)) #endif -#if defined(__ARM_NEON) || defined(__AVX512F__) +#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__loongarch_asx) #define VECTOR_REGISTERS 32 #else #define VECTOR_REGISTERS 16 @@ -66,6 +66,15 @@ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +#ifdef __clang__ +#define VREGS_PREFIX "$vr" +#define XREGS_PREFIX "$xr" +#else // GCC +#define VREGS_PREFIX "$f" +#define XREGS_PREFIX "$f" +#endif +#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" + namespace { inline float unhalf(ggml_fp16_t d) { @@ -105,6 +114,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if defined(__loongarch_asx) +inline __m256 add(__m256 x, __m256 y) { return __lasx_xvfadd_s(x, y); } +inline __m256 sub(__m256 x, __m256 y) { return __lasx_xvfsub_s(x, y); } +inline __m256 mul(__m256 x, __m256 y) { return __lasx_xvfmul_s(x, y); } +#endif // __loongarch_asx + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD @@ -144,6 +159,12 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { #endif #endif +#if defined(__loongarch_asx) +template <> +inline __m256 madd(__m256 a, __m256 b, __m256 c) { + return __lasx_xvfmadd_s(a, b, c); +} +#endif //__loongarch_asx //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM @@ -189,6 +210,68 @@ inline float hsum(__m512 x) { } #endif // __AVX512F__ +#if defined(__loongarch_asx) +static inline __m128i lasx_extracti128_lo(__m256i in) { + __m128i out; + __asm__ volatile ( + ".ifnc %[out], %[in] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " vori.b $vr\\i, $vr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} + +static inline __m128i lasx_extracti128_hi(__m256i in) { + __m128i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} + +static __m128 lasx_extractf128( __m256 a, int pos) { + __m128 ret; + if( pos == 0) + { + ret = (__m128)lasx_extracti128_lo((__m256i)a); + } else { + ret = (__m128)lasx_extracti128_hi((__m256i)a); + } + return ret; +} +static inline float hsum_float_8(const __m256 x) { + __m128 res = lasx_extractf128(x, 1); + ft_union tmp; + res = __lsx_vfadd_s(res, 
lasx_extractf128(x, 0));
+    res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
+    res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
+    tmp.i = __lsx_vpickve2gr_w(res, 0);
+    return tmp.f;
+}
+
+inline float hsum(__m256 x) {
+    return hsum_float_8(x);
+}
+#endif // __loongarch_asx
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED MEMORY LOADING
 
@@ -235,6 +318,47 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
 }
 #endif // __AVX512F__
 
+#if defined(__loongarch_asx)
+template <> inline __m256 load(const float *p) {
+    return (__m256)__lasx_xvld(p, 0);
+}
+static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { // build a 256-bit vector from two 128-bit halves (cf. MM256_SET_M128I)
+    __m256i out;
+    __asm__ volatile (
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[hi], " VREGS_PREFIX "\\i    \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[lo], " VREGS_PREFIX "\\j  \n\t"
+        "    xvpermi.q $xr\\i, $xr\\j, 0x20  \n\t" // concatenate hi and lo
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        ".ifnc %[out], %[hi]                 \n\t"
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[out], " XREGS_PREFIX "\\i   \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[hi], " VREGS_PREFIX "\\j  \n\t"
+        "    xvori.b $xr\\i, $xr\\j, 0       \n\t" // copy hi to out
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        ".endif                              \n\t"
+        : [out] "=f" (out), [hi] "+f" (inhi)
+        : [lo] "f" (inlo)
+    );
+    return out;
+}
+template <> inline __m256 load(const ggml_fp16_t *p) {
+    __m128i vector16 = __lsx_vld((const __m128i *)p, 0);
+    __m128i hi = (__m128i)__lsx_vfcvth_s_h(vector16);
+    __m128i lo = (__m128i)__lsx_vfcvtl_s_h(vector16);
+    return (__m256)lasx_set_q(hi, lo);
+
+}
+#endif // __loongarch_asx
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
@@ -813,6 +937,165 @@ class tinyBLAS_Q0_AVX {
 };
 #endif // __AVX__
 
+#if defined(__loongarch_asx)
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_Q0_LOONGARCH {
+  public:
+    tinyBLAS_Q0_LOONGARCH(int64_t k,
+                          const TA *A, int64_t lda,
+                          const TB *B, int64_t ldb,
+                          TC *C, int64_t ldc,
+                          int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 1)) {
+        case 0x44:
+        case 0x43:
+        case 0x42:
+        case 0x41:
+            mc = 4;
+            nc = 1;
+            gemm<4, 1>(m0, m, n0, n);
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x14:
+            mc = 1;
+            nc = 4;
+            gemm<1, 4>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
+            mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
+            nc = 1;
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    inline __m256i load(const block_q8_0 *b) {
+        return __lasx_xvld(((const __m256i *)b->qs), 0);
+    }
+
+    static __m256i lasx_insertf128(__m128i x, __m128i y) {
+        return lasx_set_q(x, y);
+    }
+
+    // Unpack 32 4-bit fields into 32 bytes
+    // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+    static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
+        const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
+        __m128i hi = __lsx_vsrli_h(lo, 4);
+        return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
+    }
+
+    inline __m256i load(const block_q4_0 *b) {
+        return __lasx_xvsub_b(bytes_from_nibbles_32(b->qs), __lasx_xvreplgr2vr_b(8));
+    }
+
+    // add int16_t pairwise and return as float vector
+    static inline __m256 sum_i16_pairs_float(const __m256i x) {
+        __m256i v = __lasx_xvpackod_h(x, x);
+        __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
+        return __lasx_xvffint_s_w(summed_pairs);
+    }
+
+    static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
+        __m256i tmp1, tmp2;
+        tmp1 = __lasx_xvmulwev_h_b(a, b);
+        tmp2 = __lasx_xvmulwod_h_b(a, b);
+        return __lasx_xvsadd_h(tmp1, tmp2);
+    }
+    static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+        const __m256i dot = lasx_maddubs_h(ax, sy);
+        return sum_i16_pairs_float(dot);
+    }
+
+    static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+        const __m256i ax = __lasx_xvsigncov_b(x, x);
+        const __m256i sy = __lasx_xvsigncov_b(x, y);
+        return mul_sum_us8_pairs_float(ax, sy);
+    }
+
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][RM] = {};
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i) {
+                        __m256 udTmp = mul_sum_i8_pairs_float(load(B + ldb * (jj + j) + l),
+                                                              load(A + lda * (ii + i) + l));
+
+                        Cv[j][i] = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(unhalf(A[lda * (ii + i) + l].d) *
                                                                         unhalf(B[ldb * (jj + j) + l].d)),
+                                                    udTmp,
+                                                    Cv[j][i]);
+                    }
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum_float_8(Cv[j][i]);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *const C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+#endif // __loongarch_asx
 } // namespace
 
 /**
@@ -897,6 +1180,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             ith, nth};
         tb.matmul(m, n);
         return true;
+// #elif defined(__loongarch_asx)
+//         if (k % 8)
+//             return false;
+//         tinyBLAS<8, __m256, __m256, float, float, float> tb{
+//             k, (const float *)A, lda,
+//             (const float *)B, ldb,
+//             (float *)C, ldc,
+//             ith, nth};
+//         tb.matmul(m, n);
+//         return true;
 #else
         return false;
 #endif
@@ -953,6 +1246,18 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             ith, nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__loongarch_asx)
+        if (k % 8)
+            return false;
+        if (Btype != GGML_TYPE_F32)
+            return false;
+        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif
@@ -977,6 +1282,14 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             ith, nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__loongarch_asx)
+        tinyBLAS_Q0_LOONGARCH<block_q8_0, block_q8_0, float> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif
@@ -1001,6 +1314,14 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             ith, nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__loongarch_asx)
+        tinyBLAS_Q0_LOONGARCH<block_q4_0, block_q8_0, float> tb{
+            k, (const block_q4_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif
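
Reviewer note (not part of the diff): the Q8_0 and Q4_0 branches above are a LASX vectorization of the standard ggml block dot product. As a reference for checking the new kernel, here is a minimal scalar sketch of what one (A block, B block) pair contributes to an output element in the Q4_0 x Q8_0 case. The struct and function names are illustrative stand-ins (plain float scales instead of ggml_fp16_t), not the real block_q4_0/block_q8_0 declarations.

    #include <cstdint>

    // Simplified stand-ins: 32 quants per block, one scale per block.
    struct q4_0_ref { float d; uint8_t qs[16]; };  // two 4-bit quants packed per byte
    struct q8_0_ref { float d; int8_t  qs[32]; };

    // Scalar equivalent of one accumulator update: d_a * d_b * sum_i (q4[i] - 8) * q8[i].
    // The nibble order matches bytes_from_nibbles_32 above (low nibbles in lanes 0..15,
    // high nibbles in lanes 16..31); mul_sum_i8_pairs_float reduces the 32 lane products
    // to 8 floats and hsum_float_8 collapses them once the k loop finishes.
    static float block_dot_q4_0_q8_0(const q4_0_ref &a, const q8_0_ref &b) {
        int sum = 0;
        for (int i = 0; i < 16; ++i) {
            const int lo = (a.qs[i] & 0x0f) - 8;   // quant i
            const int hi = (a.qs[i] >> 4) - 8;     // quant i + 16
            sum += lo * b.qs[i] + hi * b.qs[i + 16];
        }
        return a.d * b.d * (float)sum;
    }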