update

parent 3b6199ba3c
commit 8a0d9a304f

4 changed files with 442 additions and 463 deletions
@@ -1132,7 +1132,7 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
    list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
    #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10 (MMA) and query for big endian systems (ppc64/le/be)
  endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "loongarch64")
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
    message(STATUS "loongarch64 detected")

    list(APPEND ARCH_FLAGS -march=loongarch64)
ggml-impl.h (413 changed lines)
@@ -452,6 +452,419 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
|||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__loongarch_asx)
|
||||
|
||||
typedef union {
|
||||
int32_t i;
|
||||
float f;
|
||||
} ft_union;
|
||||
|
||||
/* float type data load instructions */
|
||||
static __m128 __lsx_vreplfr2vr_s(float val) {
|
||||
ft_union fi_tmpval = {.f = val};
|
||||
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
|
||||
}
|
||||
|
||||
static __m256 __lasx_xvreplfr2vr_s(float val) {
|
||||
ft_union fi_tmpval = {.f = val};
|
||||
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
|
||||
}
|
||||
|
||||
#ifdef __clang__
|
||||
#define VREGS_PREFIX "$vr"
|
||||
#define XREGS_PREFIX "$xr"
|
||||
#else // GCC
|
||||
#define VREGS_PREFIX "$f"
|
||||
#define XREGS_PREFIX "$f"
|
||||
#endif
|
||||
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
|
||||
// Convert __m128i to __m256i
|
||||
static inline __m256i ____m256i(__m128i in) {
|
||||
__m256i out = __lasx_xvldi(0);
|
||||
__asm__ volatile (
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[out], " XREGS_PREFIX"\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[in], " VREGS_PREFIX "\\j \n\t"
|
||||
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
: [out] "+f" (out) : [in] "f" (in)
|
||||
);
|
||||
return out;
|
||||
}
|
||||
// Convert two __m128i to __m256i
|
||||
static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
|
||||
__m256i out;
|
||||
__asm__ volatile (
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
|
||||
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
".ifnc %[out], %[hi] \n\t"
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[out], " XREGS_PREFIX "\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
|
||||
" xvori.b $xr\\i, $xr\\j, 0 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
".endif \n\t"
|
||||
: [out] "=f" (out), [hi] "+f" (inhi)
|
||||
: [lo] "f" (inlo)
|
||||
);
|
||||
return out;
|
||||
}
|
||||
// Convert __m256i low part to __m128i
|
||||
static inline __m128i lasx_extracti128_lo(__m256i in) {
|
||||
__m128i out;
|
||||
__asm__ volatile (
|
||||
".ifnc %[out], %[in] \n\t"
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
||||
" vori.b $vr\\i, $vr\\j, 0 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
".endif \n\t"
|
||||
: [out] "=f" (out) : [in] "f" (in)
|
||||
);
|
||||
return out;
|
||||
}
|
||||
// Convert __m256i high part to __m128i
|
||||
static inline __m128i lasx_extracti128_hi(__m256i in) {
|
||||
__m128i out;
|
||||
__asm__ volatile (
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
||||
" xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
: [out] "=f" (out) : [in] "f" (in)
|
||||
);
|
||||
return out;
|
||||
}
|
||||
|
||||
static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
|
||||
v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
|
||||
return (__m256i)__ret;
|
||||
}
|
||||
|
||||
static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
|
||||
v4i32 __ret = {d, c, b, a};
|
||||
return (__m128i)__ret;
|
||||
}
|
||||
|
||||
static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
|
||||
v4i64 __ret = {d, c, b, a};
|
||||
return (__m256i)__ret;
|
||||
}
|
||||
|
||||
static __m256i lasx_insertf128( __m128i x, __m128i y) {
|
||||
return lasx_set_q(x, y);
|
||||
}
|
||||
|
||||
static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
|
||||
__m128i mask_f, zero, tmp0, tmp2, mask;
|
||||
int f = 0x8f;
|
||||
mask_f = __lsx_vreplgr2vr_b(f);
|
||||
zero = __lsx_vldi(0);
|
||||
tmp0 = __lsx_vand_v(b, mask_f); // keep the low 4 index bits and the sign bit of each selector byte
|
||||
tmp0 = __lsx_vori_b(tmp0, 0x10); // set bit 4 so non-negative selectors become shuffle indices 16..31
|
||||
mask = __lsx_vsle_b(zero, tmp0); // all-ones where the selector byte is non-negative
|
||||
tmp2 = __lsx_vand_v(tmp0, mask); // force the index to 0 for negative selectors so they pick up zero
|
||||
return __lsx_vshuf_b(a, zero, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
|
||||
__m256i mask_f, zero, tmp0, tmp2, mask;
|
||||
int f = 0x8f;
|
||||
mask_f = __lasx_xvreplgr2vr_b(f);
|
||||
zero = __lasx_xvldi(0);
|
||||
tmp0 = __lasx_xvand_v(b, mask_f); // keep the low 4 index bits and the sign bit of each selector byte
|
||||
tmp0 = __lasx_xvori_b(tmp0, 0x10); // set bit 4 so non-negative selectors become shuffle indices 16..31
|
||||
mask = __lasx_xvsle_b(zero, tmp0); // all-ones where the selector byte is non-negative
|
||||
tmp2 = __lasx_xvand_v(tmp0, mask); // force the index to 0 for negative selectors so they pick up zero
|
||||
return __lasx_xvshuf_b(a, zero, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_extu8_16(__m128i a) {
|
||||
__m128i zero = __lsx_vldi(0);
|
||||
__m128i vlo = __lsx_vilvl_b(zero, a);
|
||||
__m128i vhi = __lsx_vilvh_b(zero, a);
|
||||
return lasx_set_q(vhi, vlo);
|
||||
}
|
||||
|
||||
static __m256i lasx_ext8_16(__m128i a) {
|
||||
__m128i sign = __lsx_vslti_b(a, 0);
|
||||
__m128i vlo = __lsx_vilvl_b(sign, a);
|
||||
__m128i vhi = __lsx_vilvh_b(sign, a);
|
||||
return lasx_set_q(vhi, vlo);
|
||||
}
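
// Editor's sketch, not part of this commit: a scalar reference for
// lasx_ext8_16 above. It sign-extends sixteen int8 lanes to int16, which is
// what the vilvl/vilvh interleave with the sign mask achieves. The name
// lasx_ext8_16_ref is hypothetical.
static inline void lasx_ext8_16_ref(const int8_t in[16], int16_t out[16]) {
    for (int i = 0; i < 16; ++i) {
        out[i] = (int16_t)in[i]; // plain sign extension, lane order preserved
    }
}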
|
||||
|
||||
static __m256i lasx_ext16_32(__m128i a) {
|
||||
__m256i tmp1;
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
|
||||
return tmp1;
|
||||
}
|
||||
|
||||
static __m128i lasx_extracti128( __m256i a, int pos) {
|
||||
__m128i ret;
|
||||
if( pos == 0)
|
||||
{
|
||||
ret = lasx_extracti128_lo(a);
|
||||
} else {
|
||||
ret = lasx_extracti128_hi(a);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __m128 lasx_extractf128( __m256 a, int pos) {
|
||||
__m128 ret;
|
||||
if( pos == 0)
|
||||
{
|
||||
ret = (__m128)lasx_extracti128_lo((__m256i)a);
|
||||
} else {
|
||||
ret = (__m128)lasx_extracti128_hi((__m256i)a);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __m128i lsx_hadd_h(__m128i a, __m128i b) {
|
||||
__m128i tmp1 = __lsx_vpickev_h(b, a);
|
||||
__m128i tmp2 = __lsx_vpickod_h(b, a);
|
||||
return __lsx_vadd_h(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m128i lsx_hadd_w(__m128i a, __m128i b) {
|
||||
__m128i tmp1 = __lsx_vpickev_w(b, a);
|
||||
__m128i tmp2 = __lsx_vpickod_w(b, a);
|
||||
return __lsx_vadd_w(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m128 lsx_hadd_s(__m128 a, __m128 b) {
|
||||
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
|
||||
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
|
||||
|
||||
return __lsx_vfadd_s(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
|
||||
__m256i tmp1, tmp2;
|
||||
tmp1 = __lasx_xvmulwev_h_b(a, b);
|
||||
tmp2 = __lasx_xvmulwod_h_b(a, b);
|
||||
return __lasx_xvsadd_h(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_madd_h(__m256i a, __m256i b) {
|
||||
__m256i tmp1, tmp2;
|
||||
tmp1 = __lasx_xvmulwev_w_h(a, b);
|
||||
tmp2 = __lasx_xvmulwod_w_h(a, b);
|
||||
return __lasx_xvadd_w(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_packs_w(__m256i a, __m256i b) {
|
||||
__m256i tmp, tmp1;
|
||||
tmp = __lasx_xvsat_w(a, 15);
|
||||
tmp1 = __lasx_xvsat_w(b, 15);
|
||||
return __lasx_xvpickev_h(tmp1, tmp);
|
||||
}
|
||||
|
||||
static __m256i lasx_packs_h(__m256i a, __m256i b) {
|
||||
__m256i tmp, tmp1;
|
||||
tmp = __lasx_xvsat_h(a, 7);
|
||||
tmp1 = __lasx_xvsat_h(b, 7);
|
||||
return __lasx_xvpickev_b(tmp1, tmp);
|
||||
}
|
||||
|
||||
static __m128i lsx_packs_w(__m128i a, __m128i b) {
|
||||
__m128i tmp, tmp1;
|
||||
tmp = __lsx_vsat_w(a, 15);
|
||||
tmp1 = __lsx_vsat_w(b, 15);
|
||||
return __lsx_vpickev_h(tmp1, tmp);
|
||||
}
|
||||
|
||||
static __m128i lsx_packs_h(__m128i a, __m128i b) {
|
||||
__m128i tmp, tmp1;
|
||||
tmp = __lsx_vsat_h(a, 7);
|
||||
tmp1 = __lsx_vsat_h(b, 7);
|
||||
return __lsx_vpickev_b(tmp1, tmp);
|
||||
}
|
||||
|
||||
static __m128i lsx_packus_h(__m128i a, __m128i b) {
|
||||
__m128i tmp, tmp1;
|
||||
tmp = __lsx_vsat_hu(a, 7);
|
||||
tmp1 = __lsx_vsat_hu(b, 7);
|
||||
return __lsx_vpickev_b(tmp1, tmp);
|
||||
}
|
||||
|
||||
|
||||
static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
|
||||
__m128i tmp1, tmp2;
|
||||
tmp1 = __lsx_vmulwev_h_b(a, b);
|
||||
tmp2 = __lsx_vmulwod_h_b(a, b);
|
||||
return __lsx_vsadd_h(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m128i lsx_madd_h(__m128i a, __m128i b) {
|
||||
__m128i tmp1, tmp2;
|
||||
tmp1 = __lsx_vmulwev_w_h(a, b);
|
||||
tmp2 = __lsx_vmulwod_w_h(a, b);
|
||||
return __lsx_vadd_w(tmp1, tmp2);
|
||||
}
|
||||
|
||||
// multiply int8_t, add results pairwise twice
|
||||
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
||||
// Get absolute values of x vectors
|
||||
const __m128i ax = __lsx_vsigncov_b(x, x);
|
||||
// Sign the values of the y vectors
|
||||
const __m128i sy = __lsx_vsigncov_b(x, y);
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m128i dot = lsx_maddubs_h(ax, sy);
|
||||
const __m128i ones = __lsx_vreplgr2vr_h(1);
|
||||
return lsx_madd_h(ones, dot);
|
||||
}
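
// Editor's sketch, not part of this commit: a scalar reference for
// mul_sum_i8_pairs above (ignoring the 16-bit saturation of __lsx_vsadd_h).
// Each 32-bit output lane is the dot product of four consecutive signed
// bytes of x and y. The name mul_sum_i8_pairs_ref is hypothetical.
static inline void mul_sum_i8_pairs_ref(const int8_t x[16], const int8_t y[16], int32_t out[4]) {
    for (int i = 0; i < 4; ++i) {
        int32_t acc = 0;
        for (int j = 0; j < 4; ++j) {
            acc += (int32_t)x[4*i + j] * (int32_t)y[4*i + j];
        }
        out[i] = acc;
    }
}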
|
||||
|
||||
// horizontally add 8 floats
|
||||
static inline float hsum_float_8(const __m256 x) {
|
||||
__m128 res = lasx_extractf128(x, 1);
|
||||
ft_union tmp;
|
||||
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
|
||||
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
|
||||
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
|
||||
tmp.i = __lsx_vpickve2gr_w(res, 0);
|
||||
return tmp.f;
|
||||
}
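
// Editor's sketch, not part of this commit: a scalar reference for
// hsum_float_8 above -- the shuffles and vfadd_s reductions must end up
// summing all eight f32 lanes. The name hsum_float_8_ref is hypothetical.
static inline float hsum_float_8_ref(const float v[8]) {
    float sum = 0.0f;
    for (int i = 0; i < 8; ++i) {
        sum += v[i];
    }
    return sum;
}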
|
||||
|
||||
// horizontally add 8 int32_t
|
||||
static inline int hsum_i32_8(const __m256i a) {
|
||||
|
||||
__m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
|
||||
__m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
|
||||
|
||||
__m128i tmp1_128 = lasx_extracti128_lo(tmp1);
|
||||
__m128i tmp2_128 = lasx_extracti128_lo(tmp2);
|
||||
|
||||
__m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
|
||||
|
||||
__m128i ev = __lsx_vpickev_w(sum128, sum128);
|
||||
__m128i od = __lsx_vpickod_w(sum128, sum128);
|
||||
__m128i sum64 = __lsx_vadd_w(ev, od);
|
||||
|
||||
int sum64_1, sum64_2;
|
||||
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
|
||||
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
|
||||
|
||||
return sum64_1 + sum64_2;
|
||||
}
|
||||
|
||||
// horizontally add 4 int32_t
|
||||
static inline int hsum_i32_4(const __m128i a) {
|
||||
__m128i ev = __lsx_vpickev_w(a, a);
|
||||
__m128i od = __lsx_vpickod_w(a, a);
|
||||
__m128i sum64 = __lsx_vadd_w(ev, od);
|
||||
|
||||
int sum64_1, sum64_2;
|
||||
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
|
||||
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
|
||||
|
||||
return sum64_1 + sum64_2;
|
||||
}
|
||||
|
||||
// spread 32 bits to 32 bytes { 0x00, 0xFF }
|
||||
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
|
||||
|
||||
uint32_t x32;
|
||||
memcpy(&x32, x, sizeof(uint32_t));
|
||||
const __m256i shuf_mask = lasx_set_d(
|
||||
0x0303030303030303, 0x0202020202020202,
|
||||
0x0101010101010101, 0x0000000000000000);
|
||||
|
||||
__m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
|
||||
const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
|
||||
bytes = __lasx_xvor_v(bytes, bit_mask);
|
||||
return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
|
||||
}
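
// Editor's sketch, not part of this commit: a scalar reference for
// bytes_from_bits_32 above. Bit i of the 32-bit input selects 0xFF or 0x00
// for output byte i, matching the AVX2 version of this helper. The name
// bytes_from_bits_32_ref is hypothetical.
static inline void bytes_from_bits_32_ref(const uint8_t * x, uint8_t out[32]) {
    uint32_t x32;
    memcpy(&x32, x, sizeof(uint32_t));
    for (int i = 0; i < 32; ++i) {
        out[i] = ((x32 >> i) & 1) ? 0xFF : 0x00;
    }
}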
|
||||
|
||||
// Unpack 32 4-bit fields into 32 bytes
|
||||
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
|
||||
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
|
||||
__m128i hi = __lsx_vsrli_h(lo, 4);
|
||||
return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
|
||||
}
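
// Editor's sketch, not part of this commit: a scalar reference for
// bytes_from_nibbles_32 above. The low 16 output bytes carry the low nibbles
// of the 16 input bytes and the high 16 output bytes carry the high nibbles,
// each in [0..15]. The name bytes_from_nibbles_32_ref is hypothetical.
static inline void bytes_from_nibbles_32_ref(const uint8_t * rsi, uint8_t out[32]) {
    for (int i = 0; i < 16; ++i) {
        out[i]      = rsi[i] & 0x0F; // low nibble
        out[i + 16] = rsi[i] >> 4;   // high nibble
    }
}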
|
||||
|
||||
// add int16_t pairwise and return as float vector
|
||||
static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
||||
__m256i v = __lasx_xvpackod_h(x, x);
|
||||
__m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
|
||||
return __lasx_xvffint_s_w(summed_pairs);
|
||||
}
|
||||
|
||||
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m256i dot = lasx_maddubs_h(ax, sy);
|
||||
return sum_i16_pairs_float(dot);
|
||||
}
|
||||
|
||||
// multiply int8_t, add results pairwise twice and return as float vector
|
||||
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
||||
|
||||
// Get absolute values of x vectors
|
||||
const __m256i ax = __lasx_xvsigncov_b(x, x);
|
||||
// Sign the values of the y vectors
|
||||
const __m256i sy = __lasx_xvsigncov_b(x, y);
|
||||
|
||||
return mul_sum_us8_pairs_float(ax, sy);
|
||||
}
|
||||
|
||||
static inline __m128i packNibbles( __m256i bytes ) {
|
||||
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||
const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
|
||||
__m256i high = __lasx_xvandn_v(lowByte, bytes);
|
||||
__m256i low = __lasx_xvand_v(lowByte, bytes);
|
||||
high = __lasx_xvsrli_h(high, 4);
|
||||
bytes = __lasx_xvor_v(low, high);
|
||||
// Compress uint16_t lanes into bytes
|
||||
__m128i *r0 = (__m128i *)&bytes;
|
||||
__m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
|
||||
__m128i *r1 = (__m128i *)&tmp_h128;
|
||||
|
||||
__m128i zero = __lsx_vldi(0);
|
||||
__m128i tmp, tmp2, tmp3;
|
||||
|
||||
tmp = __lsx_vmax_h(zero, *r0);
|
||||
tmp2 = __lsx_vsat_hu(tmp, 7);
|
||||
|
||||
tmp = __lsx_vmax_h(zero, *r1);
|
||||
tmp3 = __lsx_vsat_hu(tmp, 7);
|
||||
return __lsx_vpickev_b(tmp3, tmp2);
|
||||
}
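
// Editor's sketch, not part of this commit: a scalar reference for packNibbles
// above. Each 16-bit lane 0000_abcd_0000_efgh is compressed into one byte
// abcd_efgh (with unsigned saturation in the SIMD path). The name
// packNibbles_ref is hypothetical.
static inline void packNibbles_ref(const uint16_t in[16], uint8_t out[16]) {
    for (int i = 0; i < 16; ++i) {
        out[i] = (uint8_t)(((in[i] >> 4) & 0xF0) | (in[i] & 0x0F));
    }
}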
|
||||
#endif
|
||||
|
||||
#ifdef __F16C__
|
||||
|
||||
#ifdef _MSC_VER
|
||||
|
ggml-quants.c (463 changed lines)
@@ -262,421 +262,6 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
|
|||
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
|
||||
#endif
|
||||
|
||||
#if defined(__loongarch_asx)
|
||||
|
||||
typedef union {
|
||||
int32_t i;
|
||||
float f;
|
||||
} ft_union;
|
||||
|
||||
/* float type data load instructions */
|
||||
static __m128 __lsx_vreplfr2vr_s(float val) {
|
||||
ft_union fi_tmpval = {.f = val};
|
||||
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
|
||||
}
|
||||
|
||||
static __m256 __lasx_xvreplfr2vr_s(float val) {
|
||||
ft_union fi_tmpval = {.f = val};
|
||||
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
|
||||
}
|
||||
|
||||
#ifdef __clang__
|
||||
#define VREGS_PREFIX "$vr"
|
||||
#define XREGS_PREFIX "$xr"
|
||||
#else // GCC
|
||||
#define VREGS_PREFIX "$f"
|
||||
#define XREGS_PREFIX "$f"
|
||||
#endif
|
||||
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
|
||||
// Convert __m128i to __m256i
|
||||
static inline __m256i ____m256i(__m128i in) {
|
||||
__m256i out = __lasx_xvldi(0);
|
||||
__asm__ volatile (
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[out], " XREGS_PREFIX"\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[in], " VREGS_PREFIX "\\j \n\t"
|
||||
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
: [out] "+f" (out) : [in] "f" (in)
|
||||
);
|
||||
return out;
|
||||
}
|
||||
// Convert two __m128i to __m256i
|
||||
static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
|
||||
__m256i out;
|
||||
__asm__ volatile (
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
|
||||
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
".ifnc %[out], %[hi] \n\t"
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[out], " XREGS_PREFIX "\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
|
||||
" xvori.b $xr\\i, $xr\\j, 0 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
".endif \n\t"
|
||||
: [out] "=f" (out), [hi] "+f" (inhi)
|
||||
: [lo] "f" (inlo)
|
||||
);
|
||||
return out;
|
||||
}
|
||||
// Convert __m256i low part to __m128i
|
||||
static inline __m128i lasx_extracti128_lo(__m256i in) {
|
||||
__m128i out;
|
||||
__asm__ volatile (
|
||||
".ifnc %[out], %[in] \n\t"
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
||||
" vori.b $vr\\i, $vr\\j, 0 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
".endif \n\t"
|
||||
: [out] "=f" (out) : [in] "f" (in)
|
||||
);
|
||||
return out;
|
||||
}
|
||||
// Convert __m256i high part to __m128i
|
||||
static inline __m128i lasx_extracti128_hi(__m256i in) {
|
||||
__m128i out;
|
||||
__asm__ volatile (
|
||||
".irp i," __ALL_REGS "\n\t"
|
||||
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
||||
" .irp j," __ALL_REGS "\n\t"
|
||||
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
||||
" xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
|
||||
" .endif \n\t"
|
||||
" .endr \n\t"
|
||||
" .endif \n\t"
|
||||
".endr \n\t"
|
||||
: [out] "=f" (out) : [in] "f" (in)
|
||||
);
|
||||
return out;
|
||||
}
|
||||
|
||||
static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
|
||||
v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
|
||||
return (__m256i)__ret;
|
||||
}
|
||||
|
||||
static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
|
||||
v4i32 __ret = {d, c, b, a};
|
||||
return (__m128i)__ret;
|
||||
}
|
||||
|
||||
static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
|
||||
v4i64 __ret = {d, c, b, a};
|
||||
return (__m256i)__ret;
|
||||
}
|
||||
|
||||
static __m256i lasx_insertf128( __m128i x, __m128i y) {
|
||||
return lasx_set_q(x, y);
|
||||
}
|
||||
#undef MM256_SET_M128I
|
||||
#define MM256_SET_M128I(a, b) lasx_insertf128((a), (b))
|
||||
|
||||
static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
|
||||
__m128i mask_f, zero, tmp0, tmp2, mask;
|
||||
int f = 0x8f;
|
||||
mask_f = __lsx_vreplgr2vr_b(f);
|
||||
zero = __lsx_vldi(0);
|
||||
tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
|
||||
tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
|
||||
mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
|
||||
tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
|
||||
return __lsx_vshuf_b(a, zero, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
|
||||
__m256i mask_f, zero, tmp0, tmp2, mask;
|
||||
int f = 0x8f;
|
||||
mask_f = __lasx_xvreplgr2vr_b(f);
|
||||
zero = __lasx_xvldi(0);
|
||||
tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits
|
||||
tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
|
||||
mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask
|
||||
tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones
|
||||
return __lasx_xvshuf_b(a, zero, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_extu8_16(__m128i a) {
|
||||
__m128i zero = __lsx_vldi(0);
|
||||
__m128i vlo = __lsx_vilvl_b(zero, a);
|
||||
__m128i vhi = __lsx_vilvh_b(zero, a);
|
||||
return lasx_set_q(vhi, vlo);
|
||||
}
|
||||
|
||||
static __m256i lasx_ext8_16(__m128i a) {
|
||||
__m128i sign = __lsx_vslti_b(a, 0);
|
||||
__m128i vlo = __lsx_vilvl_b(sign, a);
|
||||
__m128i vhi = __lsx_vilvh_b(sign, a);
|
||||
return lasx_set_q(vhi, vlo);
|
||||
}
|
||||
|
||||
static __m256i lasx_ext16_32(__m128i a) {
|
||||
__m256i tmp1;
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
|
||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
|
||||
return tmp1;
|
||||
}
|
||||
|
||||
static __m128i lasx_extracti128( __m256i a, int pos) {
|
||||
__m128i ret;
|
||||
if( pos == 0)
|
||||
{
|
||||
ret = lasx_extracti128_lo(a);
|
||||
} else {
|
||||
ret = lasx_extracti128_hi(a);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __m128 lasx_extractf128( __m256 a, int pos) {
|
||||
__m128 ret;
|
||||
if( pos == 0)
|
||||
{
|
||||
ret = (__m128)lasx_extracti128_lo((__m256i)a);
|
||||
} else {
|
||||
ret = (__m128)lasx_extracti128_hi((__m256i)a);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __m128i lsx_hadd_h(__m128i a, __m128i b) {
|
||||
__m128i tmp1 = __lsx_vpickev_h(b, a);
|
||||
__m128i tmp2 = __lsx_vpickod_h(b, a);
|
||||
return __lsx_vadd_h(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m128i lsx_hadd_w(__m128i a, __m128i b) {
|
||||
__m128i tmp1 = __lsx_vpickev_w(b, a);
|
||||
__m128i tmp2 = __lsx_vpickod_w(b, a);
|
||||
return __lsx_vadd_w(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m128 lsx_hadd_s(__m128 a, __m128 b) {
|
||||
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
|
||||
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
|
||||
|
||||
return __lsx_vfadd_s(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
|
||||
__m256i tmp1, tmp2;
|
||||
tmp1 = __lasx_xvmulwev_h_b(a, b);
|
||||
tmp2 = __lasx_xvmulwod_h_b(a, b);
|
||||
return __lasx_xvsadd_h(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_madd_h(__m256i a, __m256i b) {
|
||||
__m256i tmp1, tmp2;
|
||||
tmp1 = __lasx_xvmulwev_w_h(a, b);
|
||||
tmp2 = __lasx_xvmulwod_w_h(a, b);
|
||||
return __lasx_xvadd_w(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m256i lasx_packs_w(__m256i a, __m256i b) {
|
||||
__m256i tmp, tmp1;
|
||||
tmp = __lasx_xvsat_w(a, 15);
|
||||
tmp1 = __lasx_xvsat_w(b, 15);
|
||||
return __lasx_xvpickev_h(tmp1, tmp);
|
||||
}
|
||||
|
||||
static __m256i lasx_packs_h(__m256i a, __m256i b) {
|
||||
__m256i tmp, tmp1;
|
||||
tmp = __lasx_xvsat_h(a, 7);
|
||||
tmp1 = __lasx_xvsat_h(b, 7);
|
||||
return __lasx_xvpickev_b(tmp1, tmp);
|
||||
}
|
||||
|
||||
static __m128i lsx_packs_w(__m128i a, __m128i b) {
|
||||
__m128i tmp, tmp1;
|
||||
tmp = __lsx_vsat_w(a, 15);
|
||||
tmp1 = __lsx_vsat_w(b, 15);
|
||||
return __lsx_vpickev_h(tmp1, tmp);
|
||||
}
|
||||
|
||||
static __m128i lsx_packs_h(__m128i a, __m128i b) {
|
||||
__m128i tmp, tmp1;
|
||||
tmp = __lsx_vsat_h(a, 7);
|
||||
tmp1 = __lsx_vsat_h(b, 7);
|
||||
return __lsx_vpickev_b(tmp1, tmp);
|
||||
}
|
||||
|
||||
static __m128i lsx_packus_h(__m128i a, __m128i b) {
|
||||
__m128i tmp, tmp1;
|
||||
tmp = __lsx_vsat_hu(a, 7);
|
||||
tmp1 = __lsx_vsat_hu(b, 7);
|
||||
return __lsx_vpickev_b(tmp1, tmp);
|
||||
}
|
||||
|
||||
|
||||
static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
|
||||
__m128i tmp1, tmp2;
|
||||
tmp1 = __lsx_vmulwev_h_b(a, b);
|
||||
tmp2 = __lsx_vmulwod_h_b(a, b);
|
||||
return __lsx_vsadd_h(tmp1, tmp2);
|
||||
}
|
||||
|
||||
static __m128i lsx_madd_h(__m128i a, __m128i b) {
|
||||
__m128i tmp1, tmp2;
|
||||
tmp1 = __lsx_vmulwev_w_h(a, b);
|
||||
tmp2 = __lsx_vmulwod_w_h(a, b);
|
||||
return __lsx_vadd_w(tmp1, tmp2);
|
||||
}
|
||||
|
||||
// multiply int8_t, add results pairwise twice
|
||||
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
||||
// Get absolute values of x vectors
|
||||
const __m128i ax = __lsx_vsigncov_b(x, x);
|
||||
// Sign the values of the y vectors
|
||||
const __m128i sy = __lsx_vsigncov_b(x, y);
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m128i dot = lsx_maddubs_h(ax, sy);
|
||||
const __m128i ones = __lsx_vreplgr2vr_h(1);
|
||||
return lsx_madd_h(ones, dot);
|
||||
}
|
||||
|
||||
// horizontally add 8 floats
|
||||
static inline float hsum_float_8(const __m256 x) {
|
||||
__m128 res = lasx_extractf128(x, 1);
|
||||
ft_union tmp;
|
||||
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
|
||||
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
|
||||
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
|
||||
tmp.i = __lsx_vpickve2gr_w(res, 0);
|
||||
return tmp.f;
|
||||
}
|
||||
|
||||
// horizontally add 8 int32_t
|
||||
static inline int hsum_i32_8(const __m256i a) {
|
||||
|
||||
__m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
|
||||
__m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
|
||||
|
||||
__m128i tmp1_128 = lasx_extracti128_lo(tmp1);
|
||||
__m128i tmp2_128 = lasx_extracti128_lo(tmp2);
|
||||
|
||||
__m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
|
||||
|
||||
__m128i ev = __lsx_vpickev_w(sum128, sum128);
|
||||
__m128i od = __lsx_vpickod_w(sum128, sum128);
|
||||
__m128i sum64 = __lsx_vadd_w(ev, od);
|
||||
|
||||
int sum64_1, sum64_2;
|
||||
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
|
||||
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
|
||||
|
||||
return sum64_1 + sum64_2;
|
||||
}
|
||||
|
||||
// horizontally add 4 int32_t
|
||||
static inline int hsum_i32_4(const __m128i a) {
|
||||
__m128i ev = __lsx_vpickev_w(a, a);
|
||||
__m128i od = __lsx_vpickod_w(a, a);
|
||||
__m128i sum64 = __lsx_vadd_w(ev, od);
|
||||
|
||||
int sum64_1, sum64_2;
|
||||
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
|
||||
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
|
||||
|
||||
return sum64_1 + sum64_2;
|
||||
}
|
||||
|
||||
// spread 32 bits to 32 bytes { 0x00, 0xFF }
|
||||
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
|
||||
|
||||
uint32_t x32;
|
||||
memcpy(&x32, x, sizeof(uint32_t));
|
||||
const __m256i shuf_mask = lasx_set_d(
|
||||
0x0303030303030303, 0x0202020202020202,
|
||||
0x0101010101010101, 0x0000000000000000);
|
||||
|
||||
__m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
|
||||
const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
|
||||
bytes = __lasx_xvor_v(bytes, bit_mask);
|
||||
return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
|
||||
}
|
||||
|
||||
// Unpack 32 4-bit fields into 32 bytes
|
||||
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
|
||||
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
|
||||
__m128i hi = __lsx_vsrli_h(lo, 4);
|
||||
return __lasx_xvandi_b(MM256_SET_M128I(hi, lo), 0xf);
|
||||
}
|
||||
|
||||
// add int16_t pairwise and return as float vector
|
||||
static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
||||
__m256i v = __lasx_xvpackod_h(x, x);
|
||||
__m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
|
||||
return __lasx_xvffint_s_w(summed_pairs);
|
||||
}
|
||||
|
||||
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m256i dot = lasx_maddubs_h(ax, sy);
|
||||
return sum_i16_pairs_float(dot);
|
||||
}
|
||||
|
||||
// multiply int8_t, add results pairwise twice and return as float vector
|
||||
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
||||
|
||||
// Get absolute values of x vectors
|
||||
const __m256i ax = __lasx_xvsigncov_b(x, x);
|
||||
// Sign the values of the y vectors
|
||||
const __m256i sy = __lasx_xvsigncov_b(x, y);
|
||||
|
||||
return mul_sum_us8_pairs_float(ax, sy);
|
||||
}
|
||||
|
||||
static inline __m128i packNibbles( __m256i bytes ) {
|
||||
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||
const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
|
||||
__m256i high = __lasx_xvandn_v(lowByte, bytes);
|
||||
__m256i low = __lasx_xvand_v(lowByte, bytes);
|
||||
high = __lasx_xvsrli_h(high, 4);
|
||||
bytes = __lasx_xvor_v(low, high);
|
||||
// Compress uint16_t lanes into bytes
|
||||
__m128i *r0 = (__m128i *)&bytes;
|
||||
__m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
|
||||
__m128i *r1 = (__m128i *)&tmp_h128;
|
||||
|
||||
__m128i zero = __lsx_vldi(0);
|
||||
__m128i tmp, tmp2, tmp3;
|
||||
|
||||
tmp = __lsx_vmax_h(zero, *r0);
|
||||
tmp2 = __lsx_vsat_hu(tmp, 7);
|
||||
|
||||
tmp = __lsx_vmax_h(zero, *r1);
|
||||
tmp3 = __lsx_vsat_hu(tmp, 7);
|
||||
return __lsx_vpickev_b(tmp3, tmp2);
|
||||
}
|
||||
#endif
|
||||
|
||||
// reference implementation for deterministic creation of model files
|
||||
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
|
||||
static const int qk = QK4_0;
|
||||
|
@@ -6368,7 +5953,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
const __m256i all_scales = lasx_ext8_16(scales8);
|
||||
const __m128i l_scales = lasx_extracti128(all_scales, 0);
|
||||
const __m128i h_scales = lasx_extracti128(all_scales, 1);
|
||||
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
|
||||
const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
|
||||
|
||||
__m256i sumi = __lasx_xvldi(0);
|
||||
|
||||
|
@@ -6790,8 +6375,8 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
summs += dmin * smin;
|
||||
|
||||
const __m128i q2bits = __lsx_vld((const __m128i*)q2, 0);
|
||||
const __m256i q2_0 = __lasx_xvand_v(MM256_SET_M128I(__lsx_vsrli_h(q2bits, 2), q2bits), m3);
|
||||
const __m256i q2_1 = __lasx_xvand_v(MM256_SET_M128I(__lsx_vsrli_h(q2bits, 6), __lsx_vsrli_h(q2bits, 4)), m3);
|
||||
const __m256i q2_0 = __lasx_xvand_v(lasx_insertf128(__lsx_vsrli_h(q2bits, 2), q2bits), m3);
|
||||
const __m256i q2_1 = __lasx_xvand_v(lasx_insertf128(__lsx_vsrli_h(q2bits, 6), __lsx_vsrli_h(q2bits, 4)), m3);
|
||||
|
||||
const __m256i q8_0 = __lasx_xvld((const __m256i*)(q8+ 0), 0);
|
||||
const __m256i q8_1 = __lasx_xvld((const __m256i*)(q8+32), 0);
|
||||
|
@@ -7491,7 +7076,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
const __m256i all_scales = lasx_ext8_16(scales128);
|
||||
const __m128i l_scales = lasx_extracti128(all_scales, 0);
|
||||
const __m128i h_scales = lasx_extracti128(all_scales, 1);
|
||||
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
|
||||
const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
|
||||
|
||||
// high bit
|
||||
const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
|
||||
|
@@ -8041,14 +7626,14 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
|
||||
const uint8_t * restrict q3 = x[i].qs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
const __m256i scale_0 = MM256_SET_M128I(__lasx_xvreplgr2vr_h(aux8[2] - 8), __lasx_xvreplgr2vr_h(aux8[0] - 8));
|
||||
const __m256i scale_1 = MM256_SET_M128I(__lasx_xvreplgr2vr_h(aux8[3] - 8), __lasx_xvreplgr2vr_h(aux8[1] - 8));
|
||||
const __m256i scale_0 = lasx_insertf128(__lasx_xvreplgr2vr_h(aux8[2] - 8), __lasx_xvreplgr2vr_h(aux8[0] - 8));
|
||||
const __m256i scale_1 = lasx_insertf128(__lasx_xvreplgr2vr_h(aux8[3] - 8), __lasx_xvreplgr2vr_h(aux8[1] - 8));
|
||||
|
||||
memcpy(&aux64, x[i].hmask, 8);
|
||||
|
||||
__m128i haux = __lsx_vinsgr2vr_d(haux, aux64, 0);
|
||||
haux = __lsx_vinsgr2vr_d(haux, aux64 >> 1, 1);
|
||||
__m256i q3h_0 = MM256_SET_M128I(__lsx_vsrli_h(haux, 2), haux);
|
||||
__m256i q3h_0 = lasx_insertf128(__lsx_vsrli_h(haux, 2), haux);
|
||||
__m256i q3h_1 = __lasx_xvsrli_h(q3h_0, 4);
|
||||
q3h_0 = __lasx_xvslli_h(__lasx_xvandn_v(q3h_0, m1), 2);
|
||||
q3h_1 = __lasx_xvslli_h(__lasx_xvandn_v(q3h_1, m1), 2);
|
||||
|
@@ -8057,7 +7642,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
const __m128i q3bits = __lsx_vld((const __m128i*)q3, 0);
|
||||
|
||||
// prepare low and high bits
|
||||
const __m256i q3aux = MM256_SET_M128I(__lsx_vsrli_h(q3bits, 2), q3bits);
|
||||
const __m256i q3aux = lasx_insertf128(__lsx_vsrli_h(q3bits, 2), q3bits);
|
||||
const __m256i q3l_0 = __lasx_xvand_v(q3aux, m3);
|
||||
const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3aux, 4), m3);
|
||||
|
||||
|
@@ -8602,7 +8187,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
|
||||
|
||||
const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
|
||||
const __m256i scales = MM256_SET_M128I(sc128, sc128);
|
||||
const __m256i scales = lasx_insertf128(sc128, sc128);
|
||||
|
||||
__m256i sumi = __lasx_xvldi(0);
|
||||
|
||||
|
@@ -9589,7 +9174,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
|
||||
|
||||
const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
|
||||
const __m256i scales = MM256_SET_M128I(sc128, sc128);
|
||||
const __m256i scales = lasx_insertf128(sc128, sc128);
|
||||
|
||||
const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
|
||||
__m256i hmask = mone;
|
||||
|
@@ -10023,14 +9608,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
|
||||
const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0);
|
||||
|
||||
const __m256i scale_l = MM256_SET_M128I(__lsx_vreplgr2vr_h(x[i].scales[1]), __lsx_vreplgr2vr_h(x[i].scales[0]));
|
||||
const __m256i scale_h = MM256_SET_M128I(__lsx_vreplgr2vr_h(x[i].scales[3]), __lsx_vreplgr2vr_h(x[i].scales[2]));
|
||||
const __m256i scale_l = lasx_insertf128(__lsx_vreplgr2vr_h(x[i].scales[1]), __lsx_vreplgr2vr_h(x[i].scales[0]));
|
||||
const __m256i scale_h = lasx_insertf128(__lsx_vreplgr2vr_h(x[i].scales[3]), __lsx_vreplgr2vr_h(x[i].scales[2]));
|
||||
|
||||
int64_t aux64;
|
||||
memcpy(&aux64, x[i].qh, 8);
|
||||
__m128i haux128 = __lsx_vinsgr2vr_d(haux128, aux64, 0);
|
||||
haux128 = __lsx_vinsgr2vr_d(haux128, aux64 >> 1, 1);
|
||||
const __m256i haux256 = MM256_SET_M128I(__lsx_vsrli_h(haux128, 2), haux128);
|
||||
const __m256i haux256 = lasx_insertf128(__lsx_vsrli_h(haux128, 2), haux128);
|
||||
|
||||
const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvandn_v(haux256, mone), 4);
|
||||
const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvandn_v(__lasx_xvsrli_h(haux256, 4), mone), 4);
|
||||
|
@@ -11122,8 +10707,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||
const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0);
|
||||
const __m128i q4bitsH = __lsx_vld((const __m128i*)qh, 0);
|
||||
|
||||
const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(MM256_SET_M128I(__lasx_xvsrli_h(q4bitsH, 2), q4bitsH), m2), 4);
|
||||
const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(MM256_SET_M128I(__lasx_xvsrli_h(q4bitsH, 6), __lasx_xvsrli_h(q4bitsH, 4)), m2), 4);
|
||||
const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(lasx_insertf128(__lasx_xvsrli_h(q4bitsH, 2), q4bitsH), m2), 4);
|
||||
const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(lasx_insertf128(__lasx_xvsrli_h(q4bitsH, 6), __lasx_xvsrli_h(q4bitsH, 4)), m2), 4);
|
||||
|
||||
const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
|
||||
const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_1);
|
||||
|
@@ -11782,7 +11367,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|||
|
||||
const __m128i odd_bits = lsx_shuffle_b(bit_helper, partial_sign_bits_for_counting);
|
||||
const __m128i full_sign_bits = __lsx_vor_v(partial_sign_bits, odd_bits);
|
||||
const __m256i full_signs = MM256_SET_M128I(full_sign_bits, full_sign_bits);
|
||||
const __m256i full_signs = lasx_insertf128(full_sign_bits, full_sign_bits);
|
||||
|
||||
const __m256i q8_1 = __lasx_xvld((const __m256i *)y[i].qs, 0);
|
||||
const __m256i q8_2 = __lasx_xvld((const __m256i *)(y[i].qs+32), 0);
|
||||
|
@@ -11803,8 +11388,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|||
const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
|
||||
const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
|
||||
|
||||
const __m256i sc1 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
|
||||
const __m256i sc2 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
|
||||
const __m256i sc1 = lasx_insertf128(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), __lsx_vreplgr2vr_h(2*(x[i].scales[0] & 0xf)+1));
|
||||
const __m256i sc2 = lasx_insertf128(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), __lsx_vreplgr2vr_h(2*(x[i].scales[1] & 0xf)+1));
|
||||
|
||||
const __m256i sum = __lasx_xvadd_w(lasx_madd_h(sc1, dot1), lasx_madd_h(sc2, dot2));
|
||||
|
||||
|
@@ -11870,8 +11455,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|||
|
||||
const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0);
|
||||
const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1);
|
||||
const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
|
||||
const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
|
||||
const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l);
|
||||
const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h);
|
||||
|
||||
__m256i signs;
|
||||
signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1);
|
||||
|
@@ -13862,9 +13447,9 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
|
|||
const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[1].qs, 0);
|
||||
const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[0].qs, 0);
|
||||
const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[1].qs, 0);
|
||||
const __m256i q4b_1 = MM256_SET_M128I(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
|
||||
const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
|
||||
lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
|
||||
const __m256i q4b_2 = MM256_SET_M128I(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
|
||||
const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
|
||||
lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b)));
|
||||
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
|
||||
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
|
||||
|
@@ -14126,7 +13711,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|||
tmp4 = __lsx_vand_v(tmp0, mask);
|
||||
tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
|
||||
|
||||
const __m256i q4b_1 = MM256_SET_M128I(tmp3, tmp4);
|
||||
const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
|
||||
|
||||
tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
|
||||
tmp0 = __lsx_vori_b(tmp2, 0x10);
|
||||
|
@@ -14140,7 +13725,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
|||
tmp4 = __lsx_vand_v(tmp0, mask);
|
||||
tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
|
||||
|
||||
const __m256i q4b_2 = MM256_SET_M128I(tmp3, tmp4);
|
||||
const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
|
||||
|
||||
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
|
||||
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
|
||||
|
ggml.c (27 changed lines)
@@ -1528,24 +1528,6 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
|
|||
#define GGML_SIMD
|
||||
|
||||
// F32 LASX
|
||||
|
||||
typedef union
|
||||
{
|
||||
int32_t i;
|
||||
float f;
|
||||
} FloatInt;
|
||||
/* float type data load instructions */
|
||||
static __m128 __lsx_vreplfr2vr_s(float val)
|
||||
{
|
||||
FloatInt fi_tmpval = {.f = val};
|
||||
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
|
||||
}
|
||||
|
||||
static __m256 __lasx_xvreplfr2vr_s(float val)
|
||||
{
|
||||
FloatInt fi_tmpval = {.f = val};
|
||||
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
|
||||
}
|
||||
#define GGML_F32_STEP 32
|
||||
#define GGML_F32_EPR 8
|
||||
|
||||
|
@@ -1597,7 +1579,7 @@ do { \
|
|||
#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
|
||||
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
|
||||
|
||||
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
|
||||
static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
|
||||
float tmp[8];
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
|
@@ -1606,7 +1588,7 @@ static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
|
|||
|
||||
return (__m256)__lasx_xvld(tmp, 0);
|
||||
}
|
||||
static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
|
||||
static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
|
||||
float arr[8];
|
||||
|
||||
__lasx_xvst(y, arr, 0);
|
||||
|
@@ -1614,8 +1596,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
|
|||
for (int i = 0; i < 8; i++)
|
||||
x[i] = GGML_FP32_TO_FP16(arr[i]);
|
||||
}
|
||||
#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
|
||||
#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
|
||||
#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
|
||||
#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
|
||||
|
||||
#define GGML_F32Cx8_FMA GGML_F32x8_FMA
|
||||
#define GGML_F32Cx8_ADD __lasx_xvfadd_s
|
||||
|
@@ -1632,7 +1614,6 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
|
|||
#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
|
||||
|
||||
|
||||
#elif defined(__loongarch_sx)
|
||||
|
||||
#define GGML_SIMD
|
||||
|
|