This commit is contained in:
junchao-loongson 2024-05-19 16:32:40 +08:00
parent 3b6199ba3c
commit 8a0d9a304f
4 changed files with 442 additions and 463 deletions

View file

@ -1132,7 +1132,7 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
#TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "loongarch64")
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
message(STATUS "loongarch64 detected")
list(APPEND ARCH_FLAGS -march=loongarch64)

View file

@ -452,6 +452,419 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
#endif
#endif
#if defined(__loongarch_asx)
typedef union {
int32_t i;
float f;
} ft_union;
/* float type data load instructions */
static __m128 __lsx_vreplfr2vr_s(float val) {
ft_union fi_tmpval = {.f = val};
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
}
static __m256 __lasx_xvreplfr2vr_s(float val) {
ft_union fi_tmpval = {.f = val};
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
}
#ifdef __clang__
#define VREGS_PREFIX "$vr"
#define XREGS_PREFIX "$xr"
#else // GCC
#define VREGS_PREFIX "$f"
#define XREGS_PREFIX "$f"
#endif
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
// Convert __m128i to __m256i
static inline __m256i ____m256i(__m128i in) {
__m256i out = __lasx_xvldi(0);
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " XREGS_PREFIX"\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[in], " VREGS_PREFIX "\\j \n\t"
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
: [out] "+f" (out) : [in] "f" (in)
);
return out;
}
// Convert two __m128i to __m256i
static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
__m256i out;
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
" .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
".ifnc %[out], %[hi] \n\t"
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " XREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
" xvori.b $xr\\i, $xr\\j, 0 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
".endif \n\t"
: [out] "=f" (out), [hi] "+f" (inhi)
: [lo] "f" (inlo)
);
return out;
}
// Convert __m256i low part to __m128i
static inline __m128i lasx_extracti128_lo(__m256i in) {
__m128i out;
__asm__ volatile (
".ifnc %[out], %[in] \n\t"
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
" vori.b $vr\\i, $vr\\j, 0 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
".endif \n\t"
: [out] "=f" (out) : [in] "f" (in)
);
return out;
}
// Convert __m256i high part to __m128i
static inline __m128i lasx_extracti128_hi(__m256i in) {
__m128i out;
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
" xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
: [out] "=f" (out) : [in] "f" (in)
);
return out;
}
static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
return (__m256i)__ret;
}
static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
v4i32 __ret = {d, c, b, a};
return (__m128i)__ret;
}
static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
v4i64 __ret = {d, c, b, a};
return (__m256i)__ret;
}
static __m256i lasx_insertf128( __m128i x, __m128i y) {
return lasx_set_q(x, y);
}
static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
__m128i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lsx_vreplgr2vr_b(f);
zero = __lsx_vldi(0);
tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
return __lsx_vshuf_b(a, zero, tmp2);
}
static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
__m256i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lasx_xvreplgr2vr_b(f);
zero = __lasx_xvldi(0);
tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits
tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask
tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones
return __lasx_xvshuf_b(a, zero, tmp2);
}
static __m256i lasx_extu8_16(__m128i a) {
__m128i zero = __lsx_vldi(0);
__m128i vlo = __lsx_vilvl_b(zero, a);
__m128i vhi = __lsx_vilvh_b(zero, a);
return lasx_set_q(vhi, vlo);
}
static __m256i lasx_ext8_16(__m128i a) {
__m128i sign = __lsx_vslti_b(a, 0);
__m128i vlo = __lsx_vilvl_b(sign, a);
__m128i vhi = __lsx_vilvh_b(sign, a);
return lasx_set_q(vhi, vlo);
}
static __m256i lasx_ext16_32(__m128i a) {
__m256i tmp1;
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
return tmp1;
}
static __m128i lasx_extracti128( __m256i a, int pos) {
__m128i ret;
if( pos == 0)
{
ret = lasx_extracti128_lo(a);
} else {
ret = lasx_extracti128_hi(a);
}
return ret;
}
static __m128 lasx_extractf128( __m256 a, int pos) {
__m128 ret;
if( pos == 0)
{
ret = (__m128)lasx_extracti128_lo((__m256i)a);
} else {
ret = (__m128)lasx_extracti128_hi((__m256i)a);
}
return ret;
}
static __m128i lsx_hadd_h(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_h(b, a);
__m128i tmp2 = __lsx_vpickod_h(b, a);
return __lsx_vadd_h(tmp1, tmp2);
}
static __m128i lsx_hadd_w(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_w(b, a);
__m128i tmp2 = __lsx_vpickod_w(b, a);
return __lsx_vadd_w(tmp1, tmp2);
}
static __m128 lsx_hadd_s(__m128 a, __m128 b) {
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
return __lsx_vfadd_s(tmp1, tmp2);
}
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
__m256i tmp1, tmp2;
tmp1 = __lasx_xvmulwev_h_b(a, b);
tmp2 = __lasx_xvmulwod_h_b(a, b);
return __lasx_xvsadd_h(tmp1, tmp2);
}
static __m256i lasx_madd_h(__m256i a, __m256i b) {
__m256i tmp1, tmp2;
tmp1 = __lasx_xvmulwev_w_h(a, b);
tmp2 = __lasx_xvmulwod_w_h(a, b);
return __lasx_xvadd_w(tmp1, tmp2);
}
static __m256i lasx_packs_w(__m256i a, __m256i b) {
__m256i tmp, tmp1;
tmp = __lasx_xvsat_w(a, 15);
tmp1 = __lasx_xvsat_w(b, 15);
return __lasx_xvpickev_h(tmp1, tmp);
}
static __m256i lasx_packs_h(__m256i a, __m256i b) {
__m256i tmp, tmp1;
tmp = __lasx_xvsat_h(a, 7);
tmp1 = __lasx_xvsat_h(b, 7);
return __lasx_xvpickev_b(tmp1, tmp);
}
static __m128i lsx_packs_w(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_w(a, 15);
tmp1 = __lsx_vsat_w(b, 15);
return __lsx_vpickev_h(tmp1, tmp);
}
static __m128i lsx_packs_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_h(a, 7);
tmp1 = __lsx_vsat_h(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}
static __m128i lsx_packus_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_hu(a, 7);
tmp1 = __lsx_vsat_hu(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}
static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_h_b(a, b);
tmp2 = __lsx_vmulwod_h_b(a, b);
return __lsx_vsadd_h(tmp1, tmp2);
}
static __m128i lsx_madd_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_w_h(a, b);
tmp2 = __lsx_vmulwod_w_h(a, b);
return __lsx_vadd_w(tmp1, tmp2);
}
// multiply int8_t, add results pairwise twice
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// Get absolute values of x vectors
const __m128i ax = __lsx_vsigncov_b(x, x);
// Sign the values of the y vectors
const __m128i sy = __lsx_vsigncov_b(x, y);
// Perform multiplication and create 16-bit values
const __m128i dot = lsx_maddubs_h(ax, sy);
const __m128i ones = __lsx_vreplgr2vr_h(1);
return lsx_madd_h(ones, dot);
}
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = lasx_extractf128(x, 1);
ft_union tmp;
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
tmp.i = __lsx_vpickve2gr_w(res, 0);
return tmp.f;
}
// horizontally add 8 int32_t
static inline int hsum_i32_8(const __m256i a) {
__m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
__m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
__m128i tmp1_128 = lasx_extracti128_lo(tmp1);
__m128i tmp2_128 = lasx_extracti128_lo(tmp2);
__m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
__m128i ev = __lsx_vpickev_w(sum128, sum128);
__m128i od = __lsx_vpickod_w(sum128, sum128);
__m128i sum64 = __lsx_vadd_w(ev, od);
int sum64_1, sum64_2;
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
return sum64_1 + sum64_2;
}
// horizontally add 4 int32_t
static inline int hsum_i32_4(const __m128i a) {
__m128i ev = __lsx_vpickev_w(a, a);
__m128i od = __lsx_vpickod_w(a, a);
__m128i sum64 = __lsx_vadd_w(ev, od);
int sum64_1, sum64_2;
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
return sum64_1 + sum64_2;
}
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
memcpy(&x32, x, sizeof(uint32_t));
const __m256i shuf_mask = lasx_set_d(
0x0303030303030303, 0x0202020202020202,
0x0101010101010101, 0x0000000000000000);
__m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
bytes = __lasx_xvor_v(bytes, bit_mask);
return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
}
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
__m128i hi = __lsx_vsrli_h(lo, 4);
return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
}
// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m256i x) {
__m256i v = __lasx_xvpackod_h(x, x);
__m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
return __lasx_xvffint_s_w(summed_pairs);
}
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
// Perform multiplication and create 16-bit values
const __m256i dot = lasx_maddubs_h(ax, sy);
return sum_i16_pairs_float(dot);
}
// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
// Get absolute values of x vectors
const __m256i ax = __lasx_xvsigncov_b(x, x);
// Sign the values of the y vectors
const __m256i sy = __lasx_xvsigncov_b(x, y);
return mul_sum_us8_pairs_float(ax, sy);
}
static inline __m128i packNibbles( __m256i bytes ) {
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
__m256i high = __lasx_xvandn_v(lowByte, bytes);
__m256i low = __lasx_xvand_v(lowByte, bytes);
high = __lasx_xvsrli_h(high, 4);
bytes = __lasx_xvor_v(low, high);
// Compress uint16_t lanes into bytes
__m128i *r0 = (__m128i *)&bytes;
__m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
__m128i *r1 = (__m128i *)&tmp_h128;
__m128i zero = __lsx_vldi(0);
__m128i tmp, tmp2, tmp3;
tmp = __lsx_vmax_h(zero, *r0);
tmp2 = __lsx_vsat_hu(tmp, 7);
tmp = __lsx_vmax_h(zero, *r1);
tmp3 = __lsx_vsat_hu(tmp, 7);
return __lsx_vpickev_b(tmp3, tmp2);
}
#endif
#ifdef __F16C__
#ifdef _MSC_VER

View file

@ -262,421 +262,6 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif
#if defined(__loongarch_asx)
typedef union {
int32_t i;
float f;
} ft_union;
/* float type data load instructions */
static __m128 __lsx_vreplfr2vr_s(float val) {
ft_union fi_tmpval = {.f = val};
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
}
static __m256 __lasx_xvreplfr2vr_s(float val) {
ft_union fi_tmpval = {.f = val};
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
}
#ifdef __clang__
#define VREGS_PREFIX "$vr"
#define XREGS_PREFIX "$xr"
#else // GCC
#define VREGS_PREFIX "$f"
#define XREGS_PREFIX "$f"
#endif
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
// Convert __m128i to __m256i
static inline __m256i ____m256i(__m128i in) {
__m256i out = __lasx_xvldi(0);
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " XREGS_PREFIX"\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[in], " VREGS_PREFIX "\\j \n\t"
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
: [out] "+f" (out) : [in] "f" (in)
);
return out;
}
// Convert two __m128i to __m256i
static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
__m256i out;
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
" .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
".ifnc %[out], %[hi] \n\t"
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " XREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
" xvori.b $xr\\i, $xr\\j, 0 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
".endif \n\t"
: [out] "=f" (out), [hi] "+f" (inhi)
: [lo] "f" (inlo)
);
return out;
}
// Convert __m256i low part to __m128i
static inline __m128i lasx_extracti128_lo(__m256i in) {
__m128i out;
__asm__ volatile (
".ifnc %[out], %[in] \n\t"
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
" vori.b $vr\\i, $vr\\j, 0 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
".endif \n\t"
: [out] "=f" (out) : [in] "f" (in)
);
return out;
}
// Convert __m256i high part to __m128i
static inline __m128i lasx_extracti128_hi(__m256i in) {
__m128i out;
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
" xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
: [out] "=f" (out) : [in] "f" (in)
);
return out;
}
static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
return (__m256i)__ret;
}
static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
v4i32 __ret = {d, c, b, a};
return (__m128i)__ret;
}
static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
v4i64 __ret = {d, c, b, a};
return (__m256i)__ret;
}
static __m256i lasx_insertf128( __m128i x, __m128i y) {
return lasx_set_q(x, y);
}
#undef MM256_SET_M128I
#define MM256_SET_M128I(a, b) lasx_insertf128((a), (b))
static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
__m128i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lsx_vreplgr2vr_b(f);
zero = __lsx_vldi(0);
tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
return __lsx_vshuf_b(a, zero, tmp2);
}
static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
__m256i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lasx_xvreplgr2vr_b(f);
zero = __lasx_xvldi(0);
tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits
tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask
tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones
return __lasx_xvshuf_b(a, zero, tmp2);
}
static __m256i lasx_extu8_16(__m128i a) {
__m128i zero = __lsx_vldi(0);
__m128i vlo = __lsx_vilvl_b(zero, a);
__m128i vhi = __lsx_vilvh_b(zero, a);
return lasx_set_q(vhi, vlo);
}
static __m256i lasx_ext8_16(__m128i a) {
__m128i sign = __lsx_vslti_b(a, 0);
__m128i vlo = __lsx_vilvl_b(sign, a);
__m128i vhi = __lsx_vilvh_b(sign, a);
return lasx_set_q(vhi, vlo);
}
static __m256i lasx_ext16_32(__m128i a) {
__m256i tmp1;
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
return tmp1;
}
static __m128i lasx_extracti128( __m256i a, int pos) {
__m128i ret;
if( pos == 0)
{
ret = lasx_extracti128_lo(a);
} else {
ret = lasx_extracti128_hi(a);
}
return ret;
}
static __m128 lasx_extractf128( __m256 a, int pos) {
__m128 ret;
if( pos == 0)
{
ret = (__m128)lasx_extracti128_lo((__m256i)a);
} else {
ret = (__m128)lasx_extracti128_hi((__m256i)a);
}
return ret;
}
static __m128i lsx_hadd_h(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_h(b, a);
__m128i tmp2 = __lsx_vpickod_h(b, a);
return __lsx_vadd_h(tmp1, tmp2);
}
static __m128i lsx_hadd_w(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_w(b, a);
__m128i tmp2 = __lsx_vpickod_w(b, a);
return __lsx_vadd_w(tmp1, tmp2);
}
static __m128 lsx_hadd_s(__m128 a, __m128 b) {
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
return __lsx_vfadd_s(tmp1, tmp2);
}
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
__m256i tmp1, tmp2;
tmp1 = __lasx_xvmulwev_h_b(a, b);
tmp2 = __lasx_xvmulwod_h_b(a, b);
return __lasx_xvsadd_h(tmp1, tmp2);
}
static __m256i lasx_madd_h(__m256i a, __m256i b) {
__m256i tmp1, tmp2;
tmp1 = __lasx_xvmulwev_w_h(a, b);
tmp2 = __lasx_xvmulwod_w_h(a, b);
return __lasx_xvadd_w(tmp1, tmp2);
}
static __m256i lasx_packs_w(__m256i a, __m256i b) {
__m256i tmp, tmp1;
tmp = __lasx_xvsat_w(a, 15);
tmp1 = __lasx_xvsat_w(b, 15);
return __lasx_xvpickev_h(tmp1, tmp);
}
static __m256i lasx_packs_h(__m256i a, __m256i b) {
__m256i tmp, tmp1;
tmp = __lasx_xvsat_h(a, 7);
tmp1 = __lasx_xvsat_h(b, 7);
return __lasx_xvpickev_b(tmp1, tmp);
}
static __m128i lsx_packs_w(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_w(a, 15);
tmp1 = __lsx_vsat_w(b, 15);
return __lsx_vpickev_h(tmp1, tmp);
}
static __m128i lsx_packs_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_h(a, 7);
tmp1 = __lsx_vsat_h(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}
static __m128i lsx_packus_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_hu(a, 7);
tmp1 = __lsx_vsat_hu(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}
static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_h_b(a, b);
tmp2 = __lsx_vmulwod_h_b(a, b);
return __lsx_vsadd_h(tmp1, tmp2);
}
static __m128i lsx_madd_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_w_h(a, b);
tmp2 = __lsx_vmulwod_w_h(a, b);
return __lsx_vadd_w(tmp1, tmp2);
}
// multiply int8_t, add results pairwise twice
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// Get absolute values of x vectors
const __m128i ax = __lsx_vsigncov_b(x, x);
// Sign the values of the y vectors
const __m128i sy = __lsx_vsigncov_b(x, y);
// Perform multiplication and create 16-bit values
const __m128i dot = lsx_maddubs_h(ax, sy);
const __m128i ones = __lsx_vreplgr2vr_h(1);
return lsx_madd_h(ones, dot);
}
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = lasx_extractf128(x, 1);
ft_union tmp;
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
tmp.i = __lsx_vpickve2gr_w(res, 0);
return tmp.f;
}
// horizontally add 8 int32_t
static inline int hsum_i32_8(const __m256i a) {
__m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
__m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
__m128i tmp1_128 = lasx_extracti128_lo(tmp1);
__m128i tmp2_128 = lasx_extracti128_lo(tmp2);
__m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
__m128i ev = __lsx_vpickev_w(sum128, sum128);
__m128i od = __lsx_vpickod_w(sum128, sum128);
__m128i sum64 = __lsx_vadd_w(ev, od);
int sum64_1, sum64_2;
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
return sum64_1 + sum64_2;
}
// horizontally add 4 int32_t
static inline int hsum_i32_4(const __m128i a) {
__m128i ev = __lsx_vpickev_w(a, a);
__m128i od = __lsx_vpickod_w(a, a);
__m128i sum64 = __lsx_vadd_w(ev, od);
int sum64_1, sum64_2;
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
return sum64_1 + sum64_2;
}
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
memcpy(&x32, x, sizeof(uint32_t));
const __m256i shuf_mask = lasx_set_d(
0x0303030303030303, 0x0202020202020202,
0x0101010101010101, 0x0000000000000000);
__m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
bytes = __lasx_xvor_v(bytes, bit_mask);
return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
}
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
__m128i hi = __lsx_vsrli_h(lo, 4);
return __lasx_xvandi_b(MM256_SET_M128I(hi, lo), 0xf);
}
// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m256i x) {
__m256i v = __lasx_xvpackod_h(x, x);
__m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
return __lasx_xvffint_s_w(summed_pairs);
}
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
// Perform multiplication and create 16-bit values
const __m256i dot = lasx_maddubs_h(ax, sy);
return sum_i16_pairs_float(dot);
}
// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
// Get absolute values of x vectors
const __m256i ax = __lasx_xvsigncov_b(x, x);
// Sign the values of the y vectors
const __m256i sy = __lasx_xvsigncov_b(x, y);
return mul_sum_us8_pairs_float(ax, sy);
}
static inline __m128i packNibbles( __m256i bytes ) {
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
__m256i high = __lasx_xvandn_v(lowByte, bytes);
__m256i low = __lasx_xvand_v(lowByte, bytes);
high = __lasx_xvsrli_h(high, 4);
bytes = __lasx_xvor_v(low, high);
// Compress uint16_t lanes into bytes
__m128i *r0 = (__m128i *)&bytes;
__m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
__m128i *r1 = (__m128i *)&tmp_h128;
__m128i zero = __lsx_vldi(0);
__m128i tmp, tmp2, tmp3;
tmp = __lsx_vmax_h(zero, *r0);
tmp2 = __lsx_vsat_hu(tmp, 7);
tmp = __lsx_vmax_h(zero, *r1);
tmp3 = __lsx_vsat_hu(tmp, 7);
return __lsx_vpickev_b(tmp3, tmp2);
}
#endif
// reference implementation for deterministic creation of model files
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
static const int qk = QK4_0;
@ -6368,7 +5953,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m256i all_scales = lasx_ext8_16(scales8);
const __m128i l_scales = lasx_extracti128(all_scales, 0);
const __m128i h_scales = lasx_extracti128(all_scales, 1);
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
__m256i sumi = __lasx_xvldi(0);
@ -6790,8 +6375,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
summs += dmin * smin;
const __m128i q2bits = __lsx_vld((const __m128i*)q2, 0);
const __m256i q2_0 = __lasx_xvand_v(MM256_SET_M128I(__lsx_vsrli_h(q2bits, 2), q2bits), m3);
const __m256i q2_1 = __lasx_xvand_v(MM256_SET_M128I(__lsx_vsrli_h(q2bits, 6), __lsx_vsrli_h(q2bits, 4)), m3);
const __m256i q2_0 = __lasx_xvand_v(lasx_insertf128(__lsx_vsrli_h(q2bits, 2), q2bits), m3);
const __m256i q2_1 = __lasx_xvand_v(lasx_insertf128(__lsx_vsrli_h(q2bits, 6), __lsx_vsrli_h(q2bits, 4)), m3);
const __m256i q8_0 = __lasx_xvld((const __m256i*)(q8+ 0), 0);
const __m256i q8_1 = __lasx_xvld((const __m256i*)(q8+32), 0);
@ -7491,7 +7076,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m256i all_scales = lasx_ext8_16(scales128);
const __m128i l_scales = lasx_extracti128(all_scales, 0);
const __m128i h_scales = lasx_extracti128(all_scales, 1);
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
// high bit
const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
@ -8041,14 +7626,14 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const uint8_t * restrict q3 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
const __m256i scale_0 = MM256_SET_M128I(__lasx_xvreplgr2vr_h(aux8[2] - 8), __lasx_xvreplgr2vr_h(aux8[0] - 8));
const __m256i scale_1 = MM256_SET_M128I(__lasx_xvreplgr2vr_h(aux8[3] - 8), __lasx_xvreplgr2vr_h(aux8[1] - 8));
const __m256i scale_0 = lasx_insertf128(__lasx_xvreplgr2vr_h(aux8[2] - 8), __lasx_xvreplgr2vr_h(aux8[0] - 8));
const __m256i scale_1 = lasx_insertf128(__lasx_xvreplgr2vr_h(aux8[3] - 8), __lasx_xvreplgr2vr_h(aux8[1] - 8));
memcpy(&aux64, x[i].hmask, 8);
__m128i haux = __lsx_vinsgr2vr_d(haux, aux64, 0);
haux = __lsx_vinsgr2vr_d(haux, aux64 >> 1, 1);
__m256i q3h_0 = MM256_SET_M128I(__lsx_vsrli_h(haux, 2), haux);
__m256i q3h_0 = lasx_insertf128(__lsx_vsrli_h(haux, 2), haux);
__m256i q3h_1 = __lasx_xvsrli_h(q3h_0, 4);
q3h_0 = __lasx_xvslli_h(__lasx_xvandn_v(q3h_0, m1), 2);
q3h_1 = __lasx_xvslli_h(__lasx_xvandn_v(q3h_1, m1), 2);
@ -8057,7 +7642,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m128i q3bits = __lsx_vld((const __m128i*)q3, 0);
// prepare low and high bits
const __m256i q3aux = MM256_SET_M128I(__lsx_vsrli_h(q3bits, 2), q3bits);
const __m256i q3aux = lasx_insertf128(__lsx_vsrli_h(q3bits, 2), q3bits);
const __m256i q3l_0 = __lasx_xvand_v(q3aux, m3);
const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3aux, 4), m3);
@ -8602,7 +8187,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
const __m256i scales = MM256_SET_M128I(sc128, sc128);
const __m256i scales = lasx_insertf128(sc128, sc128);
__m256i sumi = __lasx_xvldi(0);
@ -9589,7 +9174,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
const __m256i scales = MM256_SET_M128I(sc128, sc128);
const __m256i scales = lasx_insertf128(sc128, sc128);
const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
__m256i hmask = mone;
@ -10023,14 +9608,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0);
const __m256i scale_l = MM256_SET_M128I(__lsx_vreplgr2vr_h(x[i].scales[1]), __lsx_vreplgr2vr_h(x[i].scales[0]));
const __m256i scale_h = MM256_SET_M128I(__lsx_vreplgr2vr_h(x[i].scales[3]), __lsx_vreplgr2vr_h(x[i].scales[2]));
const __m256i scale_l = lasx_insertf128(__lsx_vreplgr2vr_h(x[i].scales[1]), __lsx_vreplgr2vr_h(x[i].scales[0]));
const __m256i scale_h = lasx_insertf128(__lsx_vreplgr2vr_h(x[i].scales[3]), __lsx_vreplgr2vr_h(x[i].scales[2]));
int64_t aux64;
memcpy(&aux64, x[i].qh, 8);
__m128i haux128 = __lsx_vinsgr2vr_d(haux128, aux64, 0);
haux128 = __lsx_vinsgr2vr_d(haux128, aux64 >> 1, 1);
const __m256i haux256 = MM256_SET_M128I(__lsx_vsrli_h(haux128, 2), haux128);
const __m256i haux256 = lasx_insertf128(__lsx_vsrli_h(haux128, 2), haux128);
const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvandn_v(haux256, mone), 4);
const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvandn_v(__lasx_xvsrli_h(haux256, 4), mone), 4);
@ -11122,8 +10707,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0);
const __m128i q4bitsH = __lsx_vld((const __m128i*)qh, 0);
const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(MM256_SET_M128I(__lasx_xvsrli_h(q4bitsH, 2), q4bitsH), m2), 4);
const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(MM256_SET_M128I(__lasx_xvsrli_h(q4bitsH, 6), __lasx_xvsrli_h(q4bitsH, 4)), m2), 4);
const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(lasx_insertf128(__lasx_xvsrli_h(q4bitsH, 2), q4bitsH), m2), 4);
const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(lasx_insertf128(__lasx_xvsrli_h(q4bitsH, 6), __lasx_xvsrli_h(q4bitsH, 4)), m2), 4);
const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_1);
@ -11782,7 +11367,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i odd_bits = lsx_shuffle_b(bit_helper, partial_sign_bits_for_counting);
const __m128i full_sign_bits = __lsx_vor_v(partial_sign_bits, odd_bits);
const __m256i full_signs = MM256_SET_M128I(full_sign_bits, full_sign_bits);
const __m256i full_signs = lasx_insertf128(full_sign_bits, full_sign_bits);
const __m256i q8_1 = __lasx_xvld((const __m256i *)y[i].qs, 0);
const __m256i q8_2 = __lasx_xvld((const __m256i *)(y[i].qs+32), 0);
@ -11803,8 +11388,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
const __m256i sc1 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
const __m256i sc2 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
const __m256i sc1 = lasx_insertf128(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), __lsx_vreplgr2vr_h(2*(x[i].scales[0] & 0xf)+1));
const __m256i sc2 = lasx_insertf128(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), __lsx_vreplgr2vr_h(2*(x[i].scales[1] & 0xf)+1));
const __m256i sum = __lasx_xvadd_w(lasx_madd_h(sc1, dot1), lasx_madd_h(sc2, dot2));
@ -11870,8 +11455,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0);
const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1);
const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l);
const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h);
__m256i signs;
signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1);
@ -13862,9 +13447,9 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[1].qs, 0);
const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[0].qs, 0);
const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[1].qs, 0);
const __m256i q4b_1 = MM256_SET_M128I(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
const __m256i q4b_2 = MM256_SET_M128I(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b)));
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
@ -14126,7 +13711,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
tmp4 = __lsx_vand_v(tmp0, mask);
tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
const __m256i q4b_1 = MM256_SET_M128I(tmp3, tmp4);
const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
tmp0 = __lsx_vori_b(tmp2, 0x10);
@ -14140,7 +13725,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
tmp4 = __lsx_vand_v(tmp0, mask);
tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
const __m256i q4b_2 = MM256_SET_M128I(tmp3, tmp4);
const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);

27
ggml.c
View file

@ -1528,24 +1528,6 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
#define GGML_SIMD
// F32 LASX
typedef union
{
int32_t i;
float f;
} FloatInt;
/* float type data load instructions */
static __m128 __lsx_vreplfr2vr_s(float val)
{
FloatInt fi_tmpval = {.f = val};
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
}
static __m256 __lasx_xvreplfr2vr_s(float val)
{
FloatInt fi_tmpval = {.f = val};
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
}
#define GGML_F32_STEP 32
#define GGML_F32_EPR 8
@ -1597,7 +1579,7 @@ do { \
#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
float tmp[8];
for (int i = 0; i < 8; i++) {
@ -1606,7 +1588,7 @@ static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
return (__m256)__lasx_xvld(tmp, 0);
}
static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
float arr[8];
__lasx_xvst(y, arr, 0);
@ -1614,8 +1596,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
for (int i = 0; i < 8; i++)
x[i] = GGML_FP32_TO_FP16(arr[i]);
}
#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
#define GGML_F32Cx8_FMA GGML_F32x8_FMA
#define GGML_F32Cx8_ADD __lasx_xvfadd_s
@ -1632,7 +1614,6 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
#elif defined(__loongarch_sx)
#define GGML_SIMD