update 2
This commit is contained in:
parent
8a0d9a304f
commit
ee26b8ff10
2 changed files with 397 additions and 394 deletions
394
ggml-impl.h
394
ggml-impl.h
|
@ -469,400 +469,6 @@ static __m256 __lasx_xvreplfr2vr_s(float val) {
|
||||||
ft_union fi_tmpval = {.f = val};
|
ft_union fi_tmpval = {.f = val};
|
||||||
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
|
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef __clang__
|
|
||||||
#define VREGS_PREFIX "$vr"
|
|
||||||
#define XREGS_PREFIX "$xr"
|
|
||||||
#else // GCC
|
|
||||||
#define VREGS_PREFIX "$f"
|
|
||||||
#define XREGS_PREFIX "$f"
|
|
||||||
#endif
|
|
||||||
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
|
|
||||||
// Convert __m128i to __m256i
|
|
||||||
static inline __m256i ____m256i(__m128i in) {
|
|
||||||
__m256i out = __lasx_xvldi(0);
|
|
||||||
__asm__ volatile (
|
|
||||||
".irp i," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[out], " XREGS_PREFIX"\\i \n\t"
|
|
||||||
" .irp j," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[in], " VREGS_PREFIX "\\j \n\t"
|
|
||||||
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
" .endr \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
".endr \n\t"
|
|
||||||
: [out] "+f" (out) : [in] "f" (in)
|
|
||||||
);
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
// Convert two __m128i to __m256i
|
|
||||||
static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
|
|
||||||
__m256i out;
|
|
||||||
__asm__ volatile (
|
|
||||||
".irp i," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
|
|
||||||
" .irp j," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
|
|
||||||
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
" .endr \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
".endr \n\t"
|
|
||||||
".ifnc %[out], %[hi] \n\t"
|
|
||||||
".irp i," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[out], " XREGS_PREFIX "\\i \n\t"
|
|
||||||
" .irp j," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
|
|
||||||
" xvori.b $xr\\i, $xr\\j, 0 \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
" .endr \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
".endr \n\t"
|
|
||||||
".endif \n\t"
|
|
||||||
: [out] "=f" (out), [hi] "+f" (inhi)
|
|
||||||
: [lo] "f" (inlo)
|
|
||||||
);
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
// Convert __m256i low part to __m128i
|
|
||||||
static inline __m128i lasx_extracti128_lo(__m256i in) {
|
|
||||||
__m128i out;
|
|
||||||
__asm__ volatile (
|
|
||||||
".ifnc %[out], %[in] \n\t"
|
|
||||||
".irp i," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
|
||||||
" .irp j," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
|
||||||
" vori.b $vr\\i, $vr\\j, 0 \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
" .endr \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
".endr \n\t"
|
|
||||||
".endif \n\t"
|
|
||||||
: [out] "=f" (out) : [in] "f" (in)
|
|
||||||
);
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
// Convert __m256i high part to __m128i
|
|
||||||
static inline __m128i lasx_extracti128_hi(__m256i in) {
|
|
||||||
__m128i out;
|
|
||||||
__asm__ volatile (
|
|
||||||
".irp i," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
|
||||||
" .irp j," __ALL_REGS "\n\t"
|
|
||||||
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
|
||||||
" xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
" .endr \n\t"
|
|
||||||
" .endif \n\t"
|
|
||||||
".endr \n\t"
|
|
||||||
: [out] "=f" (out) : [in] "f" (in)
|
|
||||||
);
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
|
|
||||||
v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
|
|
||||||
return (__m256i)__ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
|
|
||||||
v4i32 __ret = {d, c, b, a};
|
|
||||||
return (__m128i)__ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
|
|
||||||
v4i64 __ret = {d, c, b, a};
|
|
||||||
return (__m256i)__ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_insertf128( __m128i x, __m128i y) {
|
|
||||||
return lasx_set_q(x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
|
|
||||||
__m128i mask_f, zero, tmp0, tmp2, mask;
|
|
||||||
int f = 0x8f;
|
|
||||||
mask_f = __lsx_vreplgr2vr_b(f);
|
|
||||||
zero = __lsx_vldi(0);
|
|
||||||
tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
|
|
||||||
tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
|
|
||||||
mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
|
|
||||||
tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
|
|
||||||
return __lsx_vshuf_b(a, zero, tmp2);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
|
|
||||||
__m256i mask_f, zero, tmp0, tmp2, mask;
|
|
||||||
int f = 0x8f;
|
|
||||||
mask_f = __lasx_xvreplgr2vr_b(f);
|
|
||||||
zero = __lasx_xvldi(0);
|
|
||||||
tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits
|
|
||||||
tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
|
|
||||||
mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask
|
|
||||||
tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones
|
|
||||||
return __lasx_xvshuf_b(a, zero, tmp2);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_extu8_16(__m128i a) {
|
|
||||||
__m128i zero = __lsx_vldi(0);
|
|
||||||
__m128i vlo = __lsx_vilvl_b(zero, a);
|
|
||||||
__m128i vhi = __lsx_vilvh_b(zero, a);
|
|
||||||
return lasx_set_q(vhi, vlo);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_ext8_16(__m128i a) {
|
|
||||||
__m128i sign = __lsx_vslti_b(a, 0);
|
|
||||||
__m128i vlo = __lsx_vilvl_b(sign, a);
|
|
||||||
__m128i vhi = __lsx_vilvh_b(sign, a);
|
|
||||||
return lasx_set_q(vhi, vlo);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_ext16_32(__m128i a) {
|
|
||||||
__m256i tmp1;
|
|
||||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
|
|
||||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
|
|
||||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
|
|
||||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
|
|
||||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
|
|
||||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
|
|
||||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
|
|
||||||
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
|
|
||||||
return tmp1;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128i lasx_extracti128( __m256i a, int pos) {
|
|
||||||
__m128i ret;
|
|
||||||
if( pos == 0)
|
|
||||||
{
|
|
||||||
ret = lasx_extracti128_lo(a);
|
|
||||||
} else {
|
|
||||||
ret = lasx_extracti128_hi(a);
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128 lasx_extractf128( __m256 a, int pos) {
|
|
||||||
__m128 ret;
|
|
||||||
if( pos == 0)
|
|
||||||
{
|
|
||||||
ret = (__m128)lasx_extracti128_lo((__m256i)a);
|
|
||||||
} else {
|
|
||||||
ret = (__m128)lasx_extracti128_hi((__m256i)a);
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128i lsx_hadd_h(__m128i a, __m128i b) {
|
|
||||||
__m128i tmp1 = __lsx_vpickev_h(b, a);
|
|
||||||
__m128i tmp2 = __lsx_vpickod_h(b, a);
|
|
||||||
return __lsx_vadd_h(tmp1, tmp2);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128i lsx_hadd_w(__m128i a, __m128i b) {
|
|
||||||
__m128i tmp1 = __lsx_vpickev_w(b, a);
|
|
||||||
__m128i tmp2 = __lsx_vpickod_w(b, a);
|
|
||||||
return __lsx_vadd_w(tmp1, tmp2);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128 lsx_hadd_s(__m128 a, __m128 b) {
|
|
||||||
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
|
|
||||||
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
|
|
||||||
|
|
||||||
return __lsx_vfadd_s(tmp1, tmp2);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
|
|
||||||
__m256i tmp1, tmp2;
|
|
||||||
tmp1 = __lasx_xvmulwev_h_b(a, b);
|
|
||||||
tmp2 = __lasx_xvmulwod_h_b(a, b);
|
|
||||||
return __lasx_xvsadd_h(tmp1, tmp2);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_madd_h(__m256i a, __m256i b) {
|
|
||||||
__m256i tmp1, tmp2;
|
|
||||||
tmp1 = __lasx_xvmulwev_w_h(a, b);
|
|
||||||
tmp2 = __lasx_xvmulwod_w_h(a, b);
|
|
||||||
return __lasx_xvadd_w(tmp1, tmp2);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_packs_w(__m256i a, __m256i b) {
|
|
||||||
__m256i tmp, tmp1;
|
|
||||||
tmp = __lasx_xvsat_w(a, 15);
|
|
||||||
tmp1 = __lasx_xvsat_w(b, 15);
|
|
||||||
return __lasx_xvpickev_h(tmp1, tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m256i lasx_packs_h(__m256i a, __m256i b) {
|
|
||||||
__m256i tmp, tmp1;
|
|
||||||
tmp = __lasx_xvsat_h(a, 7);
|
|
||||||
tmp1 = __lasx_xvsat_h(b, 7);
|
|
||||||
return __lasx_xvpickev_b(tmp1, tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128i lsx_packs_w(__m128i a, __m128i b) {
|
|
||||||
__m128i tmp, tmp1;
|
|
||||||
tmp = __lsx_vsat_w(a, 15);
|
|
||||||
tmp1 = __lsx_vsat_w(b, 15);
|
|
||||||
return __lsx_vpickev_h(tmp1, tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128i lsx_packs_h(__m128i a, __m128i b) {
|
|
||||||
__m128i tmp, tmp1;
|
|
||||||
tmp = __lsx_vsat_h(a, 7);
|
|
||||||
tmp1 = __lsx_vsat_h(b, 7);
|
|
||||||
return __lsx_vpickev_b(tmp1, tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128i lsx_packus_h(__m128i a, __m128i b) {
|
|
||||||
__m128i tmp, tmp1;
|
|
||||||
tmp = __lsx_vsat_hu(a, 7);
|
|
||||||
tmp1 = __lsx_vsat_hu(b, 7);
|
|
||||||
return __lsx_vpickev_b(tmp1, tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
|
|
||||||
__m128i tmp1, tmp2;
|
|
||||||
tmp1 = __lsx_vmulwev_h_b(a, b);
|
|
||||||
tmp2 = __lsx_vmulwod_h_b(a, b);
|
|
||||||
return __lsx_vsadd_h(tmp1, tmp2);
|
|
||||||
}
|
|
||||||
|
|
||||||
static __m128i lsx_madd_h(__m128i a, __m128i b) {
|
|
||||||
__m128i tmp1, tmp2;
|
|
||||||
tmp1 = __lsx_vmulwev_w_h(a, b);
|
|
||||||
tmp2 = __lsx_vmulwod_w_h(a, b);
|
|
||||||
return __lsx_vadd_w(tmp1, tmp2);
|
|
||||||
}
|
|
||||||
|
|
||||||
// multiply int8_t, add results pairwise twice
|
|
||||||
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
|
||||||
// Get absolute values of x vectors
|
|
||||||
const __m128i ax = __lsx_vsigncov_b(x, x);
|
|
||||||
// Sign the values of the y vectors
|
|
||||||
const __m128i sy = __lsx_vsigncov_b(x, y);
|
|
||||||
// Perform multiplication and create 16-bit values
|
|
||||||
const __m128i dot = lsx_maddubs_h(ax, sy);
|
|
||||||
const __m128i ones = __lsx_vreplgr2vr_h(1);
|
|
||||||
return lsx_madd_h(ones, dot);
|
|
||||||
}
|
|
||||||
|
|
||||||
// horizontally add 8 floats
|
|
||||||
static inline float hsum_float_8(const __m256 x) {
|
|
||||||
__m128 res = lasx_extractf128(x, 1);
|
|
||||||
ft_union tmp;
|
|
||||||
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
|
|
||||||
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
|
|
||||||
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
|
|
||||||
tmp.i = __lsx_vpickve2gr_w(res, 0);
|
|
||||||
return tmp.f;
|
|
||||||
}
|
|
||||||
|
|
||||||
// horizontally add 8 int32_t
|
|
||||||
static inline int hsum_i32_8(const __m256i a) {
|
|
||||||
|
|
||||||
__m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
|
|
||||||
__m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
|
|
||||||
|
|
||||||
__m128i tmp1_128 = lasx_extracti128_lo(tmp1);
|
|
||||||
__m128i tmp2_128 = lasx_extracti128_lo(tmp2);
|
|
||||||
|
|
||||||
__m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
|
|
||||||
|
|
||||||
__m128i ev = __lsx_vpickev_w(sum128, sum128);
|
|
||||||
__m128i od = __lsx_vpickod_w(sum128, sum128);
|
|
||||||
__m128i sum64 = __lsx_vadd_w(ev, od);
|
|
||||||
|
|
||||||
int sum64_1, sum64_2;
|
|
||||||
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
|
|
||||||
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
|
|
||||||
|
|
||||||
return sum64_1 + sum64_2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// horizontally add 4 int32_t
|
|
||||||
static inline int hsum_i32_4(const __m128i a) {
|
|
||||||
__m128i ev = __lsx_vpickev_w(a, a);
|
|
||||||
__m128i od = __lsx_vpickod_w(a, a);
|
|
||||||
__m128i sum64 = __lsx_vadd_w(ev, od);
|
|
||||||
|
|
||||||
int sum64_1, sum64_2;
|
|
||||||
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
|
|
||||||
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
|
|
||||||
|
|
||||||
return sum64_1 + sum64_2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// spread 32 bits to 32 bytes { 0x00, 0xFF }
|
|
||||||
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
|
|
||||||
|
|
||||||
uint32_t x32;
|
|
||||||
memcpy(&x32, x, sizeof(uint32_t));
|
|
||||||
const __m256i shuf_mask = lasx_set_d(
|
|
||||||
0x0303030303030303, 0x0202020202020202,
|
|
||||||
0x0101010101010101, 0x0000000000000000);
|
|
||||||
|
|
||||||
__m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
|
|
||||||
const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
|
|
||||||
bytes = __lasx_xvor_v(bytes, bit_mask);
|
|
||||||
return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unpack 32 4-bit fields into 32 bytes
|
|
||||||
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
|
||||||
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
|
|
||||||
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
|
|
||||||
__m128i hi = __lsx_vsrli_h(lo, 4);
|
|
||||||
return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
|
|
||||||
}
|
|
||||||
|
|
||||||
// add int16_t pairwise and return as float vector
|
|
||||||
static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
|
||||||
__m256i v = __lasx_xvpackod_h(x, x);
|
|
||||||
__m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
|
|
||||||
return __lasx_xvffint_s_w(summed_pairs);
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
|
||||||
// Perform multiplication and create 16-bit values
|
|
||||||
const __m256i dot = lasx_maddubs_h(ax, sy);
|
|
||||||
return sum_i16_pairs_float(dot);
|
|
||||||
}
|
|
||||||
|
|
||||||
// multiply int8_t, add results pairwise twice and return as float vector
|
|
||||||
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
|
||||||
|
|
||||||
// Get absolute values of x vectors
|
|
||||||
const __m256i ax = __lasx_xvsigncov_b(x, x);
|
|
||||||
// Sign the values of the y vectors
|
|
||||||
const __m256i sy = __lasx_xvsigncov_b(x, y);
|
|
||||||
|
|
||||||
return mul_sum_us8_pairs_float(ax, sy);
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline __m128i packNibbles( __m256i bytes ) {
|
|
||||||
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
|
||||||
const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
|
|
||||||
__m256i high = __lasx_xvandn_v(lowByte, bytes);
|
|
||||||
__m256i low = __lasx_xvand_v(lowByte, bytes);
|
|
||||||
high = __lasx_xvsrli_h(high, 4);
|
|
||||||
bytes = __lasx_xvor_v(low, high);
|
|
||||||
// Compress uint16_t lanes into bytes
|
|
||||||
__m128i *r0 = (__m128i *)&bytes;
|
|
||||||
__m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
|
|
||||||
__m128i *r1 = (__m128i *)&tmp_h128;
|
|
||||||
|
|
||||||
__m128i zero = __lsx_vldi(0);
|
|
||||||
__m128i tmp, tmp2, tmp3;
|
|
||||||
|
|
||||||
tmp = __lsx_vmax_h(zero, *r0);
|
|
||||||
tmp2 = __lsx_vsat_hu(tmp, 7);
|
|
||||||
|
|
||||||
tmp = __lsx_vmax_h(zero, *r1);
|
|
||||||
tmp3 = __lsx_vsat_hu(tmp, 7);
|
|
||||||
return __lsx_vpickev_b(tmp3, tmp2);
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __F16C__
|
#ifdef __F16C__
|
||||||
|
|
397
ggml-quants.c
397
ggml-quants.c
|
@ -262,6 +262,403 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
|
||||||
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
|
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(__loongarch_asx)
|
||||||
|
|
||||||
|
#ifdef __clang__
|
||||||
|
#define VREGS_PREFIX "$vr"
|
||||||
|
#define XREGS_PREFIX "$xr"
|
||||||
|
#else // GCC
|
||||||
|
#define VREGS_PREFIX "$f"
|
||||||
|
#define XREGS_PREFIX "$f"
|
||||||
|
#endif
|
||||||
|
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
|
||||||
|
// Convert __m128i to __m256i
|
||||||
|
static inline __m256i ____m256i(__m128i in) {
|
||||||
|
__m256i out = __lasx_xvldi(0);
|
||||||
|
__asm__ volatile (
|
||||||
|
".irp i," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[out], " XREGS_PREFIX"\\i \n\t"
|
||||||
|
" .irp j," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[in], " VREGS_PREFIX "\\j \n\t"
|
||||||
|
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
" .endr \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
".endr \n\t"
|
||||||
|
: [out] "+f" (out) : [in] "f" (in)
|
||||||
|
);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
// Convert two __m128i to __m256i
|
||||||
|
static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
|
||||||
|
__m256i out;
|
||||||
|
__asm__ volatile (
|
||||||
|
".irp i," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
|
||||||
|
" .irp j," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
|
||||||
|
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
" .endr \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
".endr \n\t"
|
||||||
|
".ifnc %[out], %[hi] \n\t"
|
||||||
|
".irp i," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[out], " XREGS_PREFIX "\\i \n\t"
|
||||||
|
" .irp j," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
|
||||||
|
" xvori.b $xr\\i, $xr\\j, 0 \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
" .endr \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
".endr \n\t"
|
||||||
|
".endif \n\t"
|
||||||
|
: [out] "=f" (out), [hi] "+f" (inhi)
|
||||||
|
: [lo] "f" (inlo)
|
||||||
|
);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
// Convert __m256i low part to __m128i
|
||||||
|
static inline __m128i lasx_extracti128_lo(__m256i in) {
|
||||||
|
__m128i out;
|
||||||
|
__asm__ volatile (
|
||||||
|
".ifnc %[out], %[in] \n\t"
|
||||||
|
".irp i," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
||||||
|
" .irp j," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
||||||
|
" vori.b $vr\\i, $vr\\j, 0 \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
" .endr \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
".endr \n\t"
|
||||||
|
".endif \n\t"
|
||||||
|
: [out] "=f" (out) : [in] "f" (in)
|
||||||
|
);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
// Convert __m256i high part to __m128i
|
||||||
|
static inline __m128i lasx_extracti128_hi(__m256i in) {
|
||||||
|
__m128i out;
|
||||||
|
__asm__ volatile (
|
||||||
|
".irp i," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
||||||
|
" .irp j," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
||||||
|
" xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
" .endr \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
".endr \n\t"
|
||||||
|
: [out] "=f" (out) : [in] "f" (in)
|
||||||
|
);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
|
||||||
|
v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
|
||||||
|
return (__m256i)__ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
|
||||||
|
v4i32 __ret = {d, c, b, a};
|
||||||
|
return (__m128i)__ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
|
||||||
|
v4i64 __ret = {d, c, b, a};
|
||||||
|
return (__m256i)__ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_insertf128( __m128i x, __m128i y) {
|
||||||
|
return lasx_set_q(x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
|
||||||
|
__m128i mask_f, zero, tmp0, tmp2, mask;
|
||||||
|
int f = 0x8f;
|
||||||
|
mask_f = __lsx_vreplgr2vr_b(f);
|
||||||
|
zero = __lsx_vldi(0);
|
||||||
|
tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
|
||||||
|
tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
|
||||||
|
mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
|
||||||
|
tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
|
||||||
|
return __lsx_vshuf_b(a, zero, tmp2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
|
||||||
|
__m256i mask_f, zero, tmp0, tmp2, mask;
|
||||||
|
int f = 0x8f;
|
||||||
|
mask_f = __lasx_xvreplgr2vr_b(f);
|
||||||
|
zero = __lasx_xvldi(0);
|
||||||
|
tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits
|
||||||
|
tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
|
||||||
|
mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask
|
||||||
|
tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones
|
||||||
|
return __lasx_xvshuf_b(a, zero, tmp2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_extu8_16(__m128i a) {
|
||||||
|
__m128i zero = __lsx_vldi(0);
|
||||||
|
__m128i vlo = __lsx_vilvl_b(zero, a);
|
||||||
|
__m128i vhi = __lsx_vilvh_b(zero, a);
|
||||||
|
return lasx_set_q(vhi, vlo);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_ext8_16(__m128i a) {
|
||||||
|
__m128i sign = __lsx_vslti_b(a, 0);
|
||||||
|
__m128i vlo = __lsx_vilvl_b(sign, a);
|
||||||
|
__m128i vhi = __lsx_vilvh_b(sign, a);
|
||||||
|
return lasx_set_q(vhi, vlo);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_ext16_32(__m128i a) {
|
||||||
|
__m256i tmp1;
|
||||||
|
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
|
||||||
|
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
|
||||||
|
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
|
||||||
|
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
|
||||||
|
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
|
||||||
|
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
|
||||||
|
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
|
||||||
|
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
|
||||||
|
return tmp1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128i lasx_extracti128( __m256i a, int pos) {
|
||||||
|
__m128i ret;
|
||||||
|
if( pos == 0)
|
||||||
|
{
|
||||||
|
ret = lasx_extracti128_lo(a);
|
||||||
|
} else {
|
||||||
|
ret = lasx_extracti128_hi(a);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128 lasx_extractf128( __m256 a, int pos) {
|
||||||
|
__m128 ret;
|
||||||
|
if( pos == 0)
|
||||||
|
{
|
||||||
|
ret = (__m128)lasx_extracti128_lo((__m256i)a);
|
||||||
|
} else {
|
||||||
|
ret = (__m128)lasx_extracti128_hi((__m256i)a);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128i lsx_hadd_h(__m128i a, __m128i b) {
|
||||||
|
__m128i tmp1 = __lsx_vpickev_h(b, a);
|
||||||
|
__m128i tmp2 = __lsx_vpickod_h(b, a);
|
||||||
|
return __lsx_vadd_h(tmp1, tmp2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128i lsx_hadd_w(__m128i a, __m128i b) {
|
||||||
|
__m128i tmp1 = __lsx_vpickev_w(b, a);
|
||||||
|
__m128i tmp2 = __lsx_vpickod_w(b, a);
|
||||||
|
return __lsx_vadd_w(tmp1, tmp2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128 lsx_hadd_s(__m128 a, __m128 b) {
|
||||||
|
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
|
||||||
|
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
|
||||||
|
|
||||||
|
return __lsx_vfadd_s(tmp1, tmp2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
|
||||||
|
__m256i tmp1, tmp2;
|
||||||
|
tmp1 = __lasx_xvmulwev_h_b(a, b);
|
||||||
|
tmp2 = __lasx_xvmulwod_h_b(a, b);
|
||||||
|
return __lasx_xvsadd_h(tmp1, tmp2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_madd_h(__m256i a, __m256i b) {
|
||||||
|
__m256i tmp1, tmp2;
|
||||||
|
tmp1 = __lasx_xvmulwev_w_h(a, b);
|
||||||
|
tmp2 = __lasx_xvmulwod_w_h(a, b);
|
||||||
|
return __lasx_xvadd_w(tmp1, tmp2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_packs_w(__m256i a, __m256i b) {
|
||||||
|
__m256i tmp, tmp1;
|
||||||
|
tmp = __lasx_xvsat_w(a, 15);
|
||||||
|
tmp1 = __lasx_xvsat_w(b, 15);
|
||||||
|
return __lasx_xvpickev_h(tmp1, tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_packs_h(__m256i a, __m256i b) {
|
||||||
|
__m256i tmp, tmp1;
|
||||||
|
tmp = __lasx_xvsat_h(a, 7);
|
||||||
|
tmp1 = __lasx_xvsat_h(b, 7);
|
||||||
|
return __lasx_xvpickev_b(tmp1, tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128i lsx_packs_w(__m128i a, __m128i b) {
|
||||||
|
__m128i tmp, tmp1;
|
||||||
|
tmp = __lsx_vsat_w(a, 15);
|
||||||
|
tmp1 = __lsx_vsat_w(b, 15);
|
||||||
|
return __lsx_vpickev_h(tmp1, tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128i lsx_packs_h(__m128i a, __m128i b) {
|
||||||
|
__m128i tmp, tmp1;
|
||||||
|
tmp = __lsx_vsat_h(a, 7);
|
||||||
|
tmp1 = __lsx_vsat_h(b, 7);
|
||||||
|
return __lsx_vpickev_b(tmp1, tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128i lsx_packus_h(__m128i a, __m128i b) {
|
||||||
|
__m128i tmp, tmp1;
|
||||||
|
tmp = __lsx_vsat_hu(a, 7);
|
||||||
|
tmp1 = __lsx_vsat_hu(b, 7);
|
||||||
|
return __lsx_vpickev_b(tmp1, tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
|
||||||
|
__m128i tmp1, tmp2;
|
||||||
|
tmp1 = __lsx_vmulwev_h_b(a, b);
|
||||||
|
tmp2 = __lsx_vmulwod_h_b(a, b);
|
||||||
|
return __lsx_vsadd_h(tmp1, tmp2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128i lsx_madd_h(__m128i a, __m128i b) {
|
||||||
|
__m128i tmp1, tmp2;
|
||||||
|
tmp1 = __lsx_vmulwev_w_h(a, b);
|
||||||
|
tmp2 = __lsx_vmulwod_w_h(a, b);
|
||||||
|
return __lsx_vadd_w(tmp1, tmp2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// multiply int8_t, add results pairwise twice
|
||||||
|
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
||||||
|
// Get absolute values of x vectors
|
||||||
|
const __m128i ax = __lsx_vsigncov_b(x, x);
|
||||||
|
// Sign the values of the y vectors
|
||||||
|
const __m128i sy = __lsx_vsigncov_b(x, y);
|
||||||
|
// Perform multiplication and create 16-bit values
|
||||||
|
const __m128i dot = lsx_maddubs_h(ax, sy);
|
||||||
|
const __m128i ones = __lsx_vreplgr2vr_h(1);
|
||||||
|
return lsx_madd_h(ones, dot);
|
||||||
|
}
|
||||||
|
|
||||||
|
// horizontally add 8 floats
|
||||||
|
static inline float hsum_float_8(const __m256 x) {
|
||||||
|
__m128 res = lasx_extractf128(x, 1);
|
||||||
|
ft_union tmp;
|
||||||
|
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
|
||||||
|
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
|
||||||
|
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
|
||||||
|
tmp.i = __lsx_vpickve2gr_w(res, 0);
|
||||||
|
return tmp.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// horizontally add 8 int32_t
|
||||||
|
static inline int hsum_i32_8(const __m256i a) {
|
||||||
|
|
||||||
|
__m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
|
||||||
|
__m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
|
||||||
|
|
||||||
|
__m128i tmp1_128 = lasx_extracti128_lo(tmp1);
|
||||||
|
__m128i tmp2_128 = lasx_extracti128_lo(tmp2);
|
||||||
|
|
||||||
|
__m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
|
||||||
|
|
||||||
|
__m128i ev = __lsx_vpickev_w(sum128, sum128);
|
||||||
|
__m128i od = __lsx_vpickod_w(sum128, sum128);
|
||||||
|
__m128i sum64 = __lsx_vadd_w(ev, od);
|
||||||
|
|
||||||
|
int sum64_1, sum64_2;
|
||||||
|
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
|
||||||
|
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
|
||||||
|
|
||||||
|
return sum64_1 + sum64_2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// horizontally add 4 int32_t
|
||||||
|
static inline int hsum_i32_4(const __m128i a) {
|
||||||
|
__m128i ev = __lsx_vpickev_w(a, a);
|
||||||
|
__m128i od = __lsx_vpickod_w(a, a);
|
||||||
|
__m128i sum64 = __lsx_vadd_w(ev, od);
|
||||||
|
|
||||||
|
int sum64_1, sum64_2;
|
||||||
|
sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
|
||||||
|
sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
|
||||||
|
|
||||||
|
return sum64_1 + sum64_2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// spread 32 bits to 32 bytes { 0x00, 0xFF }
|
||||||
|
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
|
||||||
|
|
||||||
|
uint32_t x32;
|
||||||
|
memcpy(&x32, x, sizeof(uint32_t));
|
||||||
|
const __m256i shuf_mask = lasx_set_d(
|
||||||
|
0x0303030303030303, 0x0202020202020202,
|
||||||
|
0x0101010101010101, 0x0000000000000000);
|
||||||
|
|
||||||
|
__m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
|
||||||
|
const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
|
||||||
|
bytes = __lasx_xvor_v(bytes, bit_mask);
|
||||||
|
return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unpack 32 4-bit fields into 32 bytes
|
||||||
|
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||||
|
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
|
||||||
|
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
|
||||||
|
__m128i hi = __lsx_vsrli_h(lo, 4);
|
||||||
|
return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
|
||||||
|
}
|
||||||
|
|
||||||
|
// add int16_t pairwise and return as float vector
|
||||||
|
static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
||||||
|
__m256i v = __lasx_xvpackod_h(x, x);
|
||||||
|
__m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
|
||||||
|
return __lasx_xvffint_s_w(summed_pairs);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
||||||
|
// Perform multiplication and create 16-bit values
|
||||||
|
const __m256i dot = lasx_maddubs_h(ax, sy);
|
||||||
|
return sum_i16_pairs_float(dot);
|
||||||
|
}
|
||||||
|
|
||||||
|
// multiply int8_t, add results pairwise twice and return as float vector
|
||||||
|
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
||||||
|
|
||||||
|
// Get absolute values of x vectors
|
||||||
|
const __m256i ax = __lasx_xvsigncov_b(x, x);
|
||||||
|
// Sign the values of the y vectors
|
||||||
|
const __m256i sy = __lasx_xvsigncov_b(x, y);
|
||||||
|
|
||||||
|
return mul_sum_us8_pairs_float(ax, sy);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __m128i packNibbles( __m256i bytes ) {
|
||||||
|
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||||
|
const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
|
||||||
|
__m256i high = __lasx_xvandn_v(lowByte, bytes);
|
||||||
|
__m256i low = __lasx_xvand_v(lowByte, bytes);
|
||||||
|
high = __lasx_xvsrli_h(high, 4);
|
||||||
|
bytes = __lasx_xvor_v(low, high);
|
||||||
|
// Compress uint16_t lanes into bytes
|
||||||
|
__m128i *r0 = (__m128i *)&bytes;
|
||||||
|
__m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
|
||||||
|
__m128i *r1 = (__m128i *)&tmp_h128;
|
||||||
|
|
||||||
|
__m128i zero = __lsx_vldi(0);
|
||||||
|
__m128i tmp, tmp2, tmp3;
|
||||||
|
|
||||||
|
tmp = __lsx_vmax_h(zero, *r0);
|
||||||
|
tmp2 = __lsx_vsat_hu(tmp, 7);
|
||||||
|
|
||||||
|
tmp = __lsx_vmax_h(zero, *r1);
|
||||||
|
tmp3 = __lsx_vsat_hu(tmp, 7);
|
||||||
|
return __lsx_vpickev_b(tmp3, tmp2);
|
||||||
|
}
|
||||||
|
#endif //__loongarch_asx
|
||||||
|
|
||||||
// reference implementation for deterministic creation of model files
|
// reference implementation for deterministic creation of model files
|
||||||
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
|
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
|
||||||
static const int qk = QK4_0;
|
static const int qk = QK4_0;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue