From 8a0d9a304f33a705a07865712b55ff511fb88b5a Mon Sep 17 00:00:00 2001 From: junchao-loongson Date: Sun, 19 May 2024 16:32:40 +0800 Subject: [PATCH] update --- CMakeLists.txt | 2 +- ggml-impl.h | 413 +++++++++++++++++++++++++++++++++++++++++++ ggml-quants.c | 463 +++---------------------------------------------- ggml.c | 27 +-- 4 files changed, 442 insertions(+), 463 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 62da34ded..f46716216 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1132,7 +1132,7 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") list(APPEND ARCH_FLAGS -mcpu=native -mtune=native) #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) endif() -elseif (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "loongarch64") +elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") message(STATUS "loongarch64 detected") list(APPEND ARCH_FLAGS -march=loongarch64) diff --git a/ggml-impl.h b/ggml-impl.h index 3854c2985..fdc8e0006 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -452,6 +452,419 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #endif #endif +#if defined(__loongarch_asx) + +typedef union { + int32_t i; + float f; +} ft_union; + +/* float type data load instructions */ +static __m128 __lsx_vreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); +} + +static __m256 __lasx_xvreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); +} + +#ifdef __clang__ +#define VREGS_PREFIX "$vr" +#define XREGS_PREFIX "$xr" +#else // GCC +#define VREGS_PREFIX "$f" +#define XREGS_PREFIX "$f" +#endif +#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" +// Convert __m128i to __m256i +static inline __m256i ____m256i(__m128i in) { + __m256i out = __lasx_xvldi(0); + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX"\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "+f" (out) : [in] "f" (in) + ); + return out; +} +// Convert two __m128i to __m256i +static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { + __m256i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[lo], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".ifnc %[out], %[hi] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\j \n\t" + " xvori.b $xr\\i, $xr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out), [hi] "+f" (inhi) + : [lo] "f" (inlo) + ); + return out; +} +// Convert __m256i low part to __m128i +static inline __m128i lasx_extracti128_lo(__m256i in) { + __m128i out; + __asm__ volatile ( + ".ifnc %[out], %[in] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " vori.b $vr\\i, $vr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} +// Convert __m256i high part to __m128i +static inline __m128i lasx_extracti128_hi(__m256i in) { + __m128i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} + +static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) { + v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7}; + return (__m256i)__ret; +} + +static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { + v4i32 __ret = {d, c, b, a}; + return (__m128i)__ret; +} + +static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { + v4i64 __ret = {d, c, b, a}; + return (__m256i)__ret; +} + +static __m256i lasx_insertf128( __m128i x, __m128i y) { + return lasx_set_q(x, y); +} + +static __m128i lsx_shuffle_b(__m128i a, __m128i b) { + __m128i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lsx_vreplgr2vr_b(f); + zero = __lsx_vldi(0); + tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones + return __lsx_vshuf_b(a, zero, tmp2); +} + +static __m256i lasx_shuffle_b(__m256i a, __m256i b) { + __m256i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lasx_xvreplgr2vr_b(f); + zero = __lasx_xvldi(0); + tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones + return __lasx_xvshuf_b(a, zero, tmp2); +} + +static __m256i lasx_extu8_16(__m128i a) { + __m128i zero = __lsx_vldi(0); + __m128i vlo = __lsx_vilvl_b(zero, a); + __m128i vhi = __lsx_vilvh_b(zero, a); + return lasx_set_q(vhi, vlo); +} + +static __m256i lasx_ext8_16(__m128i a) { + __m128i sign = __lsx_vslti_b(a, 0); + __m128i vlo = __lsx_vilvl_b(sign, a); + __m128i vhi = __lsx_vilvh_b(sign, a); + return lasx_set_q(vhi, vlo); +} + +static __m256i lasx_ext16_32(__m128i a) { + __m256i tmp1; + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7); + return tmp1; +} + +static __m128i lasx_extracti128( __m256i a, int pos) { + __m128i ret; + if( pos == 0) + { + ret = lasx_extracti128_lo(a); + } else { + ret = lasx_extracti128_hi(a); + } + return ret; +} + +static __m128 lasx_extractf128( __m256 a, int pos) { + __m128 ret; + if( pos == 0) + { + ret = (__m128)lasx_extracti128_lo((__m256i)a); + } else { + ret = (__m128)lasx_extracti128_hi((__m256i)a); + } + return ret; +} + +static __m128i lsx_hadd_h(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_h(b, a); + __m128i tmp2 = __lsx_vpickod_h(b, a); + return __lsx_vadd_h(tmp1, tmp2); +} + +static __m128i lsx_hadd_w(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_w(b, a); + __m128i tmp2 = __lsx_vpickod_w(b, a); + return __lsx_vadd_w(tmp1, tmp2); +} + +static __m128 lsx_hadd_s(__m128 a, __m128 b) { + __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); + __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); + + return __lsx_vfadd_s(tmp1, tmp2); +} + +static __m256i lasx_maddubs_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_h_b(a, b); + tmp2 = __lasx_xvmulwod_h_b(a, b); + return __lasx_xvsadd_h(tmp1, tmp2); +} + +static __m256i lasx_madd_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_w_h(a, b); + tmp2 = __lasx_xvmulwod_w_h(a, b); + return __lasx_xvadd_w(tmp1, tmp2); +} + +static __m256i lasx_packs_w(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_w(a, 15); + tmp1 = __lasx_xvsat_w(b, 15); + return __lasx_xvpickev_h(tmp1, tmp); +} + +static __m256i lasx_packs_h(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_h(a, 7); + tmp1 = __lasx_xvsat_h(b, 7); + return __lasx_xvpickev_b(tmp1, tmp); +} + +static __m128i lsx_packs_w(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_w(a, 15); + tmp1 = __lsx_vsat_w(b, 15); + return __lsx_vpickev_h(tmp1, tmp); +} + +static __m128i lsx_packs_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_h(a, 7); + tmp1 = __lsx_vsat_h(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + +static __m128i lsx_packus_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_hu(a, 7); + tmp1 = __lsx_vsat_hu(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + + +static __m128i lsx_maddubs_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_h_b(a, b); + tmp2 = __lsx_vmulwod_h_b(a, b); + return __lsx_vsadd_h(tmp1, tmp2); +} + +static __m128i lsx_madd_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_w_h(a, b); + tmp2 = __lsx_vmulwod_w_h(a, b); + return __lsx_vadd_w(tmp1, tmp2); +} + +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = __lsx_vsigncov_b(x, x); + // Sign the values of the y vectors + const __m128i sy = __lsx_vsigncov_b(x, y); + // Perform multiplication and create 16-bit values + const __m128i dot = lsx_maddubs_h(ax, sy); + const __m128i ones = __lsx_vreplgr2vr_h(1); + return lsx_madd_h(ones, dot); +} + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = lasx_extractf128(x, 1); + ft_union tmp; + res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); + tmp.i = __lsx_vpickve2gr_w(res, 0); + return tmp.f; +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + + __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); + __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); + + __m128i tmp1_128 = lasx_extracti128_lo(tmp1); + __m128i tmp2_128 = lasx_extracti128_lo(tmp2); + + __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); + + __m128i ev = __lsx_vpickev_w(sum128, sum128); + __m128i od = __lsx_vpickod_w(sum128, sum128); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + __m128i ev = __lsx_vpickev_w(a, a); + __m128i od = __lsx_vpickod_w(a, a); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = lasx_set_d( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + + __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask); + const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe); + bytes = __lasx_xvor_v(bytes, bit_mask); + return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { + const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); + __m128i hi = __lsx_vsrli_h(lo, 4); + return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + __m256i v = __lasx_xvpackod_h(x, x); + __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); + return __lasx_xvffint_s_w(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + // Perform multiplication and create 16-bit values + const __m256i dot = lasx_maddubs_h(ax, sy); + return sum_i16_pairs_float(dot); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + + // Get absolute values of x vectors + const __m256i ax = __lasx_xvsigncov_b(x, x); + // Sign the values of the y vectors + const __m256i sy = __lasx_xvsigncov_b(x, y); + + return mul_sum_us8_pairs_float(ax, sy); +} + +static inline __m128i packNibbles( __m256i bytes ) { + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF); + __m256i high = __lasx_xvandn_v(lowByte, bytes); + __m256i low = __lasx_xvand_v(lowByte, bytes); + high = __lasx_xvsrli_h(high, 4); + bytes = __lasx_xvor_v(low, high); + // Compress uint16_t lanes into bytes + __m128i *r0 = (__m128i *)&bytes; + __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11); + __m128i *r1 = (__m128i *)&tmp_h128; + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(zero, *r0); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(zero, *r1); + tmp3 = __lsx_vsat_hu(tmp, 7); + return __lsx_vpickev_b(tmp3, tmp2); +} +#endif + #ifdef __F16C__ #ifdef _MSC_VER diff --git a/ggml-quants.c b/ggml-quants.c index e10315678..585286322 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -262,421 +262,6 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 #endif -#if defined(__loongarch_asx) - -typedef union { - int32_t i; - float f; -} ft_union; - -/* float type data load instructions */ -static __m128 __lsx_vreplfr2vr_s(float val) { - ft_union fi_tmpval = {.f = val}; - return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); -} - -static __m256 __lasx_xvreplfr2vr_s(float val) { - ft_union fi_tmpval = {.f = val}; - return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); -} - -#ifdef __clang__ -#define VREGS_PREFIX "$vr" -#define XREGS_PREFIX "$xr" -#else // GCC -#define VREGS_PREFIX "$f" -#define XREGS_PREFIX "$f" -#endif -#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" -// Convert __m128i to __m256i -static inline __m256i ____m256i(__m128i in) { - __m256i out = __lasx_xvldi(0); - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " XREGS_PREFIX"\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " VREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - : [out] "+f" (out) : [in] "f" (in) - ); - return out; -} -// Convert two __m128i to __m256i -static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { - __m256i out; - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[hi], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[lo], " VREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".ifnc %[out], %[hi] \n\t" - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " XREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[hi], " VREGS_PREFIX "\\j \n\t" - " xvori.b $xr\\i, $xr\\j, 0 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".endif \n\t" - : [out] "=f" (out), [hi] "+f" (inhi) - : [lo] "f" (inlo) - ); - return out; -} -// Convert __m256i low part to __m128i -static inline __m128i lasx_extracti128_lo(__m256i in) { - __m128i out; - __asm__ volatile ( - ".ifnc %[out], %[in] \n\t" - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " XREGS_PREFIX "\\j \n\t" - " vori.b $vr\\i, $vr\\j, 0 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".endif \n\t" - : [out] "=f" (out) : [in] "f" (in) - ); - return out; -} -// Convert __m256i high part to __m128i -static inline __m128i lasx_extracti128_hi(__m256i in) { - __m128i out; - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " XREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - : [out] "=f" (out) : [in] "f" (in) - ); - return out; -} - -static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) { - v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7}; - return (__m256i)__ret; -} - -static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { - v4i32 __ret = {d, c, b, a}; - return (__m128i)__ret; -} - -static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { - v4i64 __ret = {d, c, b, a}; - return (__m256i)__ret; -} - -static __m256i lasx_insertf128( __m128i x, __m128i y) { - return lasx_set_q(x, y); -} -#undef MM256_SET_M128I -#define MM256_SET_M128I(a, b) lasx_insertf128((a), (b)) - -static __m128i lsx_shuffle_b(__m128i a, __m128i b) { - __m128i mask_f, zero, tmp0, tmp2, mask; - int f = 0x8f; - mask_f = __lsx_vreplgr2vr_b(f); - zero = __lsx_vldi(0); - tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits - tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive - mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask - tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones - return __lsx_vshuf_b(a, zero, tmp2); -} - -static __m256i lasx_shuffle_b(__m256i a, __m256i b) { - __m256i mask_f, zero, tmp0, tmp2, mask; - int f = 0x8f; - mask_f = __lasx_xvreplgr2vr_b(f); - zero = __lasx_xvldi(0); - tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits - tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive - mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask - tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones - return __lasx_xvshuf_b(a, zero, tmp2); -} - -static __m256i lasx_extu8_16(__m128i a) { - __m128i zero = __lsx_vldi(0); - __m128i vlo = __lsx_vilvl_b(zero, a); - __m128i vhi = __lsx_vilvh_b(zero, a); - return lasx_set_q(vhi, vlo); -} - -static __m256i lasx_ext8_16(__m128i a) { - __m128i sign = __lsx_vslti_b(a, 0); - __m128i vlo = __lsx_vilvl_b(sign, a); - __m128i vhi = __lsx_vilvh_b(sign, a); - return lasx_set_q(vhi, vlo); -} - -static __m256i lasx_ext16_32(__m128i a) { - __m256i tmp1; - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7); - return tmp1; -} - -static __m128i lasx_extracti128( __m256i a, int pos) { - __m128i ret; - if( pos == 0) - { - ret = lasx_extracti128_lo(a); - } else { - ret = lasx_extracti128_hi(a); - } - return ret; -} - -static __m128 lasx_extractf128( __m256 a, int pos) { - __m128 ret; - if( pos == 0) - { - ret = (__m128)lasx_extracti128_lo((__m256i)a); - } else { - ret = (__m128)lasx_extracti128_hi((__m256i)a); - } - return ret; -} - -static __m128i lsx_hadd_h(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_h(b, a); - __m128i tmp2 = __lsx_vpickod_h(b, a); - return __lsx_vadd_h(tmp1, tmp2); -} - -static __m128i lsx_hadd_w(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_w(b, a); - __m128i tmp2 = __lsx_vpickod_w(b, a); - return __lsx_vadd_w(tmp1, tmp2); -} - -static __m128 lsx_hadd_s(__m128 a, __m128 b) { - __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); - __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); - - return __lsx_vfadd_s(tmp1, tmp2); -} - -static __m256i lasx_maddubs_h(__m256i a, __m256i b) { - __m256i tmp1, tmp2; - tmp1 = __lasx_xvmulwev_h_b(a, b); - tmp2 = __lasx_xvmulwod_h_b(a, b); - return __lasx_xvsadd_h(tmp1, tmp2); -} - -static __m256i lasx_madd_h(__m256i a, __m256i b) { - __m256i tmp1, tmp2; - tmp1 = __lasx_xvmulwev_w_h(a, b); - tmp2 = __lasx_xvmulwod_w_h(a, b); - return __lasx_xvadd_w(tmp1, tmp2); -} - -static __m256i lasx_packs_w(__m256i a, __m256i b) { - __m256i tmp, tmp1; - tmp = __lasx_xvsat_w(a, 15); - tmp1 = __lasx_xvsat_w(b, 15); - return __lasx_xvpickev_h(tmp1, tmp); -} - -static __m256i lasx_packs_h(__m256i a, __m256i b) { - __m256i tmp, tmp1; - tmp = __lasx_xvsat_h(a, 7); - tmp1 = __lasx_xvsat_h(b, 7); - return __lasx_xvpickev_b(tmp1, tmp); -} - -static __m128i lsx_packs_w(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_w(a, 15); - tmp1 = __lsx_vsat_w(b, 15); - return __lsx_vpickev_h(tmp1, tmp); -} - -static __m128i lsx_packs_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_h(a, 7); - tmp1 = __lsx_vsat_h(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - -static __m128i lsx_packus_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_hu(a, 7); - tmp1 = __lsx_vsat_hu(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - - -static __m128i lsx_maddubs_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_h_b(a, b); - tmp2 = __lsx_vmulwod_h_b(a, b); - return __lsx_vsadd_h(tmp1, tmp2); -} - -static __m128i lsx_madd_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_w_h(a, b); - tmp2 = __lsx_vmulwod_w_h(a, b); - return __lsx_vadd_w(tmp1, tmp2); -} - -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = __lsx_vsigncov_b(x, x); - // Sign the values of the y vectors - const __m128i sy = __lsx_vsigncov_b(x, y); - // Perform multiplication and create 16-bit values - const __m128i dot = lsx_maddubs_h(ax, sy); - const __m128i ones = __lsx_vreplgr2vr_h(1); - return lsx_madd_h(ones, dot); -} - -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = lasx_extractf128(x, 1); - ft_union tmp; - res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); - res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); - res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); - tmp.i = __lsx_vpickve2gr_w(res, 0); - return tmp.f; -} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - - __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); - __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); - - __m128i tmp1_128 = lasx_extracti128_lo(tmp1); - __m128i tmp2_128 = lasx_extracti128_lo(tmp2); - - __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); - - __m128i ev = __lsx_vpickev_w(sum128, sum128); - __m128i od = __lsx_vpickod_w(sum128, sum128); - __m128i sum64 = __lsx_vadd_w(ev, od); - - int sum64_1, sum64_2; - sum64_1 = __lsx_vpickve2gr_w(sum64, 0); - sum64_2 = __lsx_vpickve2gr_w(sum64, 1); - - return sum64_1 + sum64_2; -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - __m128i ev = __lsx_vpickev_w(a, a); - __m128i od = __lsx_vpickod_w(a, a); - __m128i sum64 = __lsx_vadd_w(ev, od); - - int sum64_1, sum64_2; - sum64_1 = __lsx_vpickve2gr_w(sum64, 0); - sum64_2 = __lsx_vpickve2gr_w(sum64, 1); - - return sum64_1 + sum64_2; -} - -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = lasx_set_d( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - - __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask); - const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe); - bytes = __lasx_xvor_v(bytes, bit_mask); - return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { - const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); - __m128i hi = __lsx_vsrli_h(lo, 4); - return __lasx_xvandi_b(MM256_SET_M128I(hi, lo), 0xf); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - __m256i v = __lasx_xvpackod_h(x, x); - __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); - return __lasx_xvffint_s_w(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - // Perform multiplication and create 16-bit values - const __m256i dot = lasx_maddubs_h(ax, sy); - return sum_i16_pairs_float(dot); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - - // Get absolute values of x vectors - const __m256i ax = __lasx_xvsigncov_b(x, x); - // Sign the values of the y vectors - const __m256i sy = __lasx_xvsigncov_b(x, y); - - return mul_sum_us8_pairs_float(ax, sy); -} - -static inline __m128i packNibbles( __m256i bytes ) { - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF); - __m256i high = __lasx_xvandn_v(lowByte, bytes); - __m256i low = __lasx_xvand_v(lowByte, bytes); - high = __lasx_xvsrli_h(high, 4); - bytes = __lasx_xvor_v(low, high); - // Compress uint16_t lanes into bytes - __m128i *r0 = (__m128i *)&bytes; - __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11); - __m128i *r1 = (__m128i *)&tmp_h128; - - __m128i zero = __lsx_vldi(0); - __m128i tmp, tmp2, tmp3; - - tmp = __lsx_vmax_h(zero, *r0); - tmp2 = __lsx_vsat_hu(tmp, 7); - - tmp = __lsx_vmax_h(zero, *r1); - tmp3 = __lsx_vsat_hu(tmp, 7); - return __lsx_vpickev_b(tmp3, tmp2); -} -#endif - // reference implementation for deterministic creation of model files void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -6368,7 +5953,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m256i all_scales = lasx_ext8_16(scales8); const __m128i l_scales = lasx_extracti128(all_scales, 0); const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; __m256i sumi = __lasx_xvldi(0); @@ -6790,8 +6375,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r summs += dmin * smin; const __m128i q2bits = __lsx_vld((const __m128i*)q2, 0); - const __m256i q2_0 = __lasx_xvand_v(MM256_SET_M128I(__lsx_vsrli_h(q2bits, 2), q2bits), m3); - const __m256i q2_1 = __lasx_xvand_v(MM256_SET_M128I(__lsx_vsrli_h(q2bits, 6), __lsx_vsrli_h(q2bits, 4)), m3); + const __m256i q2_0 = __lasx_xvand_v(lasx_insertf128(__lsx_vsrli_h(q2bits, 2), q2bits), m3); + const __m256i q2_1 = __lasx_xvand_v(lasx_insertf128(__lsx_vsrli_h(q2bits, 6), __lsx_vsrli_h(q2bits, 4)), m3); const __m256i q8_0 = __lasx_xvld((const __m256i*)(q8+ 0), 0); const __m256i q8_1 = __lasx_xvld((const __m256i*)(q8+32), 0); @@ -7491,7 +7076,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m256i all_scales = lasx_ext8_16(scales128); const __m128i l_scales = lasx_extracti128(all_scales, 0); const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; // high bit const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); @@ -8041,14 +7626,14 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const uint8_t * restrict q3 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const __m256i scale_0 = MM256_SET_M128I(__lasx_xvreplgr2vr_h(aux8[2] - 8), __lasx_xvreplgr2vr_h(aux8[0] - 8)); - const __m256i scale_1 = MM256_SET_M128I(__lasx_xvreplgr2vr_h(aux8[3] - 8), __lasx_xvreplgr2vr_h(aux8[1] - 8)); + const __m256i scale_0 = lasx_insertf128(__lasx_xvreplgr2vr_h(aux8[2] - 8), __lasx_xvreplgr2vr_h(aux8[0] - 8)); + const __m256i scale_1 = lasx_insertf128(__lasx_xvreplgr2vr_h(aux8[3] - 8), __lasx_xvreplgr2vr_h(aux8[1] - 8)); memcpy(&aux64, x[i].hmask, 8); __m128i haux = __lsx_vinsgr2vr_d(haux, aux64, 0); haux = __lsx_vinsgr2vr_d(haux, aux64 >> 1, 1); - __m256i q3h_0 = MM256_SET_M128I(__lsx_vsrli_h(haux, 2), haux); + __m256i q3h_0 = lasx_insertf128(__lsx_vsrli_h(haux, 2), haux); __m256i q3h_1 = __lasx_xvsrli_h(q3h_0, 4); q3h_0 = __lasx_xvslli_h(__lasx_xvandn_v(q3h_0, m1), 2); q3h_1 = __lasx_xvslli_h(__lasx_xvandn_v(q3h_1, m1), 2); @@ -8057,7 +7642,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m128i q3bits = __lsx_vld((const __m128i*)q3, 0); // prepare low and high bits - const __m256i q3aux = MM256_SET_M128I(__lsx_vsrli_h(q3bits, 2), q3bits); + const __m256i q3aux = lasx_insertf128(__lsx_vsrli_h(q3bits, 2), q3bits); const __m256i q3l_0 = __lasx_xvand_v(q3aux, m3); const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3aux, 4), m3); @@ -8602,7 +8187,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); + const __m256i scales = lasx_insertf128(sc128, sc128); __m256i sumi = __lasx_xvldi(0); @@ -9589,7 +9174,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); + const __m256i scales = lasx_insertf128(sc128, sc128); const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); __m256i hmask = mone; @@ -10023,14 +9608,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); - const __m256i scale_l = MM256_SET_M128I(__lsx_vreplgr2vr_h(x[i].scales[1]), __lsx_vreplgr2vr_h(x[i].scales[0])); - const __m256i scale_h = MM256_SET_M128I(__lsx_vreplgr2vr_h(x[i].scales[3]), __lsx_vreplgr2vr_h(x[i].scales[2])); + const __m256i scale_l = lasx_insertf128(__lsx_vreplgr2vr_h(x[i].scales[1]), __lsx_vreplgr2vr_h(x[i].scales[0])); + const __m256i scale_h = lasx_insertf128(__lsx_vreplgr2vr_h(x[i].scales[3]), __lsx_vreplgr2vr_h(x[i].scales[2])); int64_t aux64; memcpy(&aux64, x[i].qh, 8); __m128i haux128 = __lsx_vinsgr2vr_d(haux128, aux64, 0); haux128 = __lsx_vinsgr2vr_d(haux128, aux64 >> 1, 1); - const __m256i haux256 = MM256_SET_M128I(__lsx_vsrli_h(haux128, 2), haux128); + const __m256i haux256 = lasx_insertf128(__lsx_vsrli_h(haux128, 2), haux128); const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvandn_v(haux256, mone), 4); const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvandn_v(__lasx_xvsrli_h(haux256, 4), mone), 4); @@ -11122,8 +10707,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); const __m128i q4bitsH = __lsx_vld((const __m128i*)qh, 0); - const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(MM256_SET_M128I(__lasx_xvsrli_h(q4bitsH, 2), q4bitsH), m2), 4); - const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(MM256_SET_M128I(__lasx_xvsrli_h(q4bitsH, 6), __lasx_xvsrli_h(q4bitsH, 4)), m2), 4); + const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(lasx_insertf128(__lasx_xvsrli_h(q4bitsH, 2), q4bitsH), m2), 4); + const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(lasx_insertf128(__lasx_xvsrli_h(q4bitsH, 6), __lasx_xvsrli_h(q4bitsH, 4)), m2), 4); const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0); const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_1); @@ -11782,7 +11367,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * const __m128i odd_bits = lsx_shuffle_b(bit_helper, partial_sign_bits_for_counting); const __m128i full_sign_bits = __lsx_vor_v(partial_sign_bits, odd_bits); - const __m256i full_signs = MM256_SET_M128I(full_sign_bits, full_sign_bits); + const __m256i full_signs = lasx_insertf128(full_sign_bits, full_sign_bits); const __m256i q8_1 = __lasx_xvld((const __m256i *)y[i].qs, 0); const __m256i q8_2 = __lasx_xvld((const __m256i *)(y[i].qs+32), 0); @@ -11803,8 +11388,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const __m256i sc1 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1)); - const __m256i sc2 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1)); + const __m256i sc1 = lasx_insertf128(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), __lsx_vreplgr2vr_h(2*(x[i].scales[0] & 0xf)+1)); + const __m256i sc2 = lasx_insertf128(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), __lsx_vreplgr2vr_h(2*(x[i].scales[1] & 0xf)+1)); const __m256i sum = __lasx_xvadd_w(lasx_madd_h(sc1, dot1), lasx_madd_h(sc2, dot2)); @@ -11870,8 +11455,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0); const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1); - const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l); - const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h); + const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l); + const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h); __m256i signs; signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1); @@ -13862,9 +13447,9 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[1].qs, 0); const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[0].qs, 0); const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[1].qs, 0); - const __m256i q4b_1 = MM256_SET_M128I(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), + const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); - const __m256i q4b_2 = MM256_SET_M128I(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), + const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b))); const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); @@ -14126,7 +13711,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * tmp4 = __lsx_vand_v(tmp0, mask); tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - const __m256i q4b_1 = MM256_SET_M128I(tmp3, tmp4); + const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4); tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f); tmp0 = __lsx_vori_b(tmp2, 0x10); @@ -14140,7 +13725,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * tmp4 = __lsx_vand_v(tmp0, mask); tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - const __m256i q4b_2 = MM256_SET_M128I(tmp3, tmp4); + const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4); const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); diff --git a/ggml.c b/ggml.c index 04e626453..0d2c3c86a 100644 --- a/ggml.c +++ b/ggml.c @@ -1528,24 +1528,6 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { #define GGML_SIMD // F32 LASX - -typedef union -{ - int32_t i; - float f; -} FloatInt; -/* float type data load instructions */ -static __m128 __lsx_vreplfr2vr_s(float val) -{ - FloatInt fi_tmpval = {.f = val}; - return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); -} - -static __m256 __lasx_xvreplfr2vr_s(float val) -{ - FloatInt fi_tmpval = {.f = val}; - return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); -} #define GGML_F32_STEP 32 #define GGML_F32_EPR 8 @@ -1597,7 +1579,7 @@ do { \ #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0) #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x)) -static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { +static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) { float tmp[8]; for (int i = 0; i < 8; i++) { @@ -1606,7 +1588,7 @@ static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { return (__m256)__lasx_xvld(tmp, 0); } -static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { +static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) { float arr[8]; __lasx_xvst(y, arr, 0); @@ -1614,8 +1596,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { for (int i = 0; i < 8; i++) x[i] = GGML_FP32_TO_FP16(arr[i]); } -#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) -#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) +#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x) +#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y) #define GGML_F32Cx8_FMA GGML_F32x8_FMA #define GGML_F32Cx8_ADD __lasx_xvfadd_s @@ -1632,7 +1614,6 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE - #elif defined(__loongarch_sx) #define GGML_SIMD