format code

junchao-loongson 2024-05-18 15:35:51 +08:00
parent fdef7620be
commit 3b6199ba3c

@@ -264,21 +264,19 @@ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#if defined(__loongarch_asx)
-typedef union
-{
+typedef union {
int32_t i;
float f;
-} FloatInt;
+} ft_union;
/* float type data load instructions */
-static __m128 __lsx_vreplfr2vr_s(float val)
-{
-FloatInt fi_tmpval = {.f = val};
+static __m128 __lsx_vreplfr2vr_s(float val) {
+ft_union fi_tmpval = {.f = val};
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
}
-static __m256 __lasx_xvreplfr2vr_s(float val)
-{
-FloatInt fi_tmpval = {.f = val};
+static __m256 __lasx_xvreplfr2vr_s(float val) {
+ft_union fi_tmpval = {.f = val};
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
}
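For context, the union renamed in this hunk (FloatInt, now ft_union) is the usual C idiom for reading a float's bit pattern as a 32-bit integer so that an integer broadcast such as __lsx_vreplgr2vr_w can replicate it into every lane. A minimal standalone sketch of the same reinterpretation, with memcpy as the strictly portable spelling (the helper name is illustrative, not part of the patch):

#include <stdint.h>
#include <string.h>

/* Reinterpret the bits of a float as int32_t, the same job ft_union does
 * before __lsx_vreplgr2vr_w / __lasx_xvreplgr2vr_w broadcasts the value. */
static int32_t float_bits_to_i32(float f) {
    int32_t i;
    memcpy(&i, &f, sizeof i); /* well-defined, unlike a pointer cast */
    return i;
}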
@@ -291,8 +289,7 @@ static __m256 __lasx_xvreplfr2vr_s(float val)
#endif
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
// Convert __m128i to __m256i
-static inline __m256i ____m256i(__m128i in)
-{
+static inline __m256i ____m256i(__m128i in) {
__m256i out = __lasx_xvldi(0);
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
@@ -309,8 +306,7 @@ static inline __m256i ____m256i(__m128i in)
return out;
}
// Convert two __m128i to __m256i
-static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo)
-{
+static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
__m256i out;
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
@@ -339,8 +335,7 @@ static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo)
return out;
}
// Convert __m256i low part to __m128i
-static inline __m128i lasx_extracti128_lo(__m256i in)
-{
+static inline __m128i lasx_extracti128_lo(__m256i in) {
__m128i out;
__asm__ volatile (
".ifnc %[out], %[in] \n\t"
@@ -359,8 +354,7 @@ static inline __m128i lasx_extracti128_lo(__m256i in)
return out;
}
// Convert __m256i high part to __m128i
-static inline __m128i lasx_extracti128_hi(__m256i in)
-{
+static inline __m128i lasx_extracti128_hi(__m256i in) {
__m128i out;
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
@@ -377,32 +371,28 @@ static inline __m128i lasx_extracti128_hi(__m256i in)
return out;
}
-static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0)
-{
+static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
return (__m256i)__ret;
}
-static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d)
-{
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
v4i32 __ret = {d, c, b, a};
return (__m128i)__ret;
}
-static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d)
-{
+static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
v4i64 __ret = {d, c, b, a};
return (__m256i)__ret;
}
-static __m256i lasx_insertf128( __m128i x, __m128i y)
-{
+static __m256i lasx_insertf128( __m128i x, __m128i y) {
return lasx_set_q(x, y);
}
#undef MM256_SET_M128I
#define MM256_SET_M128I(a, b) lasx_insertf128((a), (b))
-static __m128i lsx_shuffle_b(__m128i a, __m128i b)
-{
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
__m128i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lsx_vreplgr2vr_b(f);
@@ -414,8 +404,7 @@ static __m128i lsx_shuffle_b(__m128i a, __m128i b)
return __lsx_vshuf_b(a, zero, tmp2);
}
-static __m256i lasx_shuffle_b(__m256i a, __m256i b)
-{
+static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
__m256i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lasx_xvreplgr2vr_b(f);
@@ -427,24 +416,21 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b)
return __lasx_xvshuf_b(a, zero, tmp2);
}
-static __m256i lasx_extu8_16(__m128i a)
-{
+static __m256i lasx_extu8_16(__m128i a) {
__m128i zero = __lsx_vldi(0);
__m128i vlo = __lsx_vilvl_b(zero, a);
__m128i vhi = __lsx_vilvh_b(zero, a);
return lasx_set_q(vhi, vlo);
}
-static __m256i lasx_ext8_16(__m128i a)
-{
+static __m256i lasx_ext8_16(__m128i a) {
__m128i sign = __lsx_vslti_b(a, 0);
__m128i vlo = __lsx_vilvl_b(sign, a);
__m128i vhi = __lsx_vilvh_b(sign, a);
return lasx_set_q(vhi, vlo);
}
-static __m256i lasx_ext16_32(__m128i a)
-{
+static __m256i lasx_ext16_32(__m128i a) {
__m256i tmp1;
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
@@ -457,8 +443,7 @@ static __m256i lasx_ext16_32(__m128i a)
return tmp1;
}
-static __m128i lasx_extracti128( __m256i a, int pos)
-{
+static __m128i lasx_extracti128( __m256i a, int pos) {
__m128i ret;
if( pos == 0)
{
@@ -469,8 +454,7 @@ static __m128i lasx_extracti128( __m256i a, int pos)
return ret;
}
-static __m128 lasx_extractf128( __m256 a, int pos)
-{
+static __m128 lasx_extractf128( __m256 a, int pos) {
__m128 ret;
if( pos == 0)
{
@@ -481,78 +465,68 @@ static __m128 lasx_extractf128( __m256 a, int pos)
return ret;
}
-static __m128i lsx_hadd_h(__m128i a, __m128i b)
-{
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_h(b, a);
__m128i tmp2 = __lsx_vpickod_h(b, a);
return __lsx_vadd_h(tmp1, tmp2);
}
-static __m128i lsx_hadd_w(__m128i a, __m128i b)
-{
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_w(b, a);
__m128i tmp2 = __lsx_vpickod_w(b, a);
return __lsx_vadd_w(tmp1, tmp2);
}
-static __m128 lsx_hadd_s(__m128 a, __m128 b)
-{
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
return __lsx_vfadd_s(tmp1, tmp2);
}
-static __m256i lasx_maddubs_h(__m256i a, __m256i b)
-{
+static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
__m256i tmp1, tmp2;
tmp1 = __lasx_xvmulwev_h_b(a, b);
tmp2 = __lasx_xvmulwod_h_b(a, b);
return __lasx_xvsadd_h(tmp1, tmp2);
}
-static __m256i lasx_madd_h(__m256i a, __m256i b)
-{
+static __m256i lasx_madd_h(__m256i a, __m256i b) {
__m256i tmp1, tmp2;
tmp1 = __lasx_xvmulwev_w_h(a, b);
tmp2 = __lasx_xvmulwod_w_h(a, b);
return __lasx_xvadd_w(tmp1, tmp2);
}
-static __m256i lasx_packs_w(__m256i a, __m256i b)
-{
+static __m256i lasx_packs_w(__m256i a, __m256i b) {
__m256i tmp, tmp1;
tmp = __lasx_xvsat_w(a, 15);
tmp1 = __lasx_xvsat_w(b, 15);
return __lasx_xvpickev_h(tmp1, tmp);
}
-static __m256i lasx_packs_h(__m256i a, __m256i b)
-{
+static __m256i lasx_packs_h(__m256i a, __m256i b) {
__m256i tmp, tmp1;
tmp = __lasx_xvsat_h(a, 7);
tmp1 = __lasx_xvsat_h(b, 7);
return __lasx_xvpickev_b(tmp1, tmp);
}
-static __m128i lsx_packs_w(__m128i a, __m128i b)
-{
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_w(a, 15);
tmp1 = __lsx_vsat_w(b, 15);
return __lsx_vpickev_h(tmp1, tmp);
}
-static __m128i lsx_packs_h(__m128i a, __m128i b)
-{
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_h(a, 7);
tmp1 = __lsx_vsat_h(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}
-static __m128i lsx_packus_h(__m128i a, __m128i b)
-{
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_hu(a, 7);
tmp1 = __lsx_vsat_hu(b, 7);
@@ -560,16 +534,14 @@ static __m128i lsx_packus_h(__m128i a, __m128i b)
}
-static __m128i lsx_maddubs_h(__m128i a, __m128i b)
-{
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_h_b(a, b);
tmp2 = __lsx_vmulwod_h_b(a, b);
return __lsx_vsadd_h(tmp1, tmp2);
}
-static __m128i lsx_madd_h(__m128i a, __m128i b)
-{
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_w_h(a, b);
tmp2 = __lsx_vmulwod_w_h(a, b);
@@ -591,7 +563,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = lasx_extractf128(x, 1);
-FloatInt tmp;
+ft_union tmp;
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
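hsum_float_8 above folds a 256-bit vector down to a single float by repeated pairwise adds. A scalar sketch of the same reduction, assuming the eight lanes have been stored to an array (the helper name is illustrative):

/* Sum the 8 float lanes of a 256-bit vector after storing them to memory;
 * this mirrors the pairwise folds done with lasx_extractf128 and __lsx_vfadd_s. */
static float hsum_float_8_scalar(const float lane[8]) {
    float sum = 0.0f;
    for (int i = 0; i < 8; ++i) {
        sum += lane[i];
    }
    return sum;
}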
@@ -651,8 +623,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
-{
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
__m128i hi = __lsx_vsrli_h(lo, 4);
return __lasx_xvandi_b(MM256_SET_M128I(hi, lo), 0xf);
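A scalar sketch of what bytes_from_nibbles_32 produces, assuming the usual 4-bit layout in which the low nibbles become elements 0..15 and the high nibbles elements 16..31 (names are illustrative):

#include <stdint.h>

/* Unpack 16 packed bytes into 32 bytes, each in [0, 15]:
 * low nibble of in[i] -> out[i], high nibble of in[i] -> out[i + 16]. */
static void bytes_from_nibbles_32_scalar(const uint8_t in[16], uint8_t out[32]) {
    for (int i = 0; i < 16; ++i) {
        out[i]      = in[i] & 0x0F;
        out[i + 16] = in[i] >> 4;
    }
}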
@@ -682,8 +653,7 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
return mul_sum_us8_pairs_float(ax, sy);
}
-static inline __m128i packNibbles( __m256i bytes )
-{
+static inline __m128i packNibbles( __m256i bytes ) {
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
__m256i high = __lasx_xvandn_v(lowByte, bytes);
@@ -1129,7 +1099,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
#elif defined(__loongarch_asx)
for (int i = 0; i < nb; i++) {
-FloatInt fi;
+ft_union fi;
__m256 v0 = (__m256)__lasx_xvld( x , 0);
__m256 v1 = (__m256)__lasx_xvld( x , 32);
__m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -1137,23 +1107,23 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
x += 32;
// Compute max(abs(e)) for the block
-const __m256 signBit = __lasx_xvreplfr2vr_s( -0.0f );
-__m256 maxAbs = (__m256)__lasx_xvandn_v( (__m256i)signBit, (__m256i)v0 );
-maxAbs = __lasx_xvfmax_s( maxAbs, (__m256)__lasx_xvandn_v( (__m256i)signBit, (__m256i)v1 ) );
-maxAbs = __lasx_xvfmax_s( maxAbs, (__m256)__lasx_xvandn_v( (__m256i)signBit, (__m256i)v2 ) );
-maxAbs = __lasx_xvfmax_s( maxAbs, (__m256)__lasx_xvandn_v( (__m256i)signBit, (__m256i)v3 ) );
+const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
+__m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
+max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
+max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
+max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
-__m128 max4 = __lsx_vfmax_s( lasx_extractf128( maxAbs, 1 ), lasx_extractf128( maxAbs , 0) );
+__m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) );
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
__m128 tmp = max4;
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-const float maxScalar = fi.f;
+const float max_scalar = fi.f;
// Quantize these floats
-const float d = maxScalar / 127.f;
+const float d = max_scalar / 127.f;
y[i].d = GGML_FP32_TO_FP16(d);
-const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id );
// Apply the multiplier
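Per block of 32 floats, the loop above finds max|x|, derives the scale d = max/127 and its inverse id, and scales every input. A scalar sketch of that math, leaving out the exact rounding and packing done by the LASX convert instructions (the helper name is illustrative):

#include <math.h>
#include <stdint.h>

/* Per-block q8_0-style math: the scale is chosen so the largest magnitude maps to 127. */
static void quantize_block_32_scalar(const float x[32], float *d_out, int8_t q[32]) {
    float max_abs = 0.0f;
    for (int j = 0; j < 32; ++j) {
        const float a = fabsf(x[j]);
        if (a > max_abs) max_abs = a;
    }
    const float d  = max_abs / 127.0f;
    const float id = (max_abs != 0.0f) ? 127.0f / max_abs : 0.0f;
    for (int j = 0; j < 32; ++j) {
        q[j] = (int8_t)roundf(x[j] * id);
    }
    *d_out = d;
}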
@@ -1337,12 +1307,12 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
-const float maxScalar = _mm_cvtss_f32( max4 );
+const float max_scalar = _mm_cvtss_f32( max4 );
// Quantize these floats
-const float d = maxScalar / 127.f;
+const float d = max_scalar / 127.f;
y[i].d = GGML_FP32_TO_FP16(d);
-const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier
@@ -1488,7 +1458,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#elif defined(__loongarch_asx)
for (int i = 0; i < nb; i++) {
-FloatInt ft;
+ft_union ft;
__m256 v0 = (__m256)__lasx_xvld( x , 0 );
__m256 v1 = (__m256)__lasx_xvld( x , 32 );
__m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1496,23 +1466,23 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
x += 32;
// Compute max(abs(e)) for the block
-const __m256 signBit = __lasx_xvreplfr2vr_s( -0.0f );
-__m256 maxAbs = (__m256)__lasx_xvandn_v( (__m256i)signBit, (__m256i)v0 );
-maxAbs = __lasx_xvfmax_s( maxAbs, (__m256)__lasx_xvandn_v( (__m256i)signBit, (__m256i)v1 ) );
-maxAbs = __lasx_xvfmax_s( maxAbs, (__m256)__lasx_xvandn_v( (__m256i)signBit, (__m256i)v2 ) );
-maxAbs = __lasx_xvfmax_s( maxAbs, (__m256)__lasx_xvandn_v( (__m256i)signBit, (__m256i)v3 ) );
+const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
+__m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
+max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
+max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
+max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
-__m128 max4 = __lsx_vfmax_s( lasx_extractf128( maxAbs, 1 ), lasx_extractf128( maxAbs, 0) );
+__m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
__m128 tmp = max4;
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-const float maxScalar = ft.f;
+const float max_scalar = ft.f;
// Quantize these floats
-const float d = maxScalar / 127.f;
+const float d = max_scalar / 127.f;
y[i].d = GGML_FP32_TO_FP16(d);
-const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
const __m256 mul = __lasx_xvreplfr2vr_s( id );
// Apply the multiplier
@@ -4501,7 +4471,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc);
#elif defined(__loongarch_sx)
// set constants
-const __m128i lowMask = __lsx_vreplgr2vr_b(0xF);
+const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
const __m128i off = __lsx_vreplgr2vr_b(8);
// Initialize accumulator with zeros
@@ -4520,30 +4490,27 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[0].qs, 0);
-__m128i bx_0 = __lsx_vand_v(lowMask, tmp_0_1);
+__m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
__m128i by_0 = __lsx_vld((const __m128i *)y[0].qs, 0);
bx_0 = __lsx_vsub_b(bx_0, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-__m128i bx_1 = __lsx_vand_v(lowMask, __lsx_vsrli_d(tmp_0_1, 4));
+__m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
__m128i by_1 = __lsx_vld((const __m128i *)(y[0].qs + 16), 0);
bx_1 = __lsx_vsub_b(bx_1, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
-//_mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
-//_mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
// Compute combined scale for the block 2 and 3
const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[1].qs, 0);
-__m128i bx_2 = __lsx_vand_v(lowMask, tmp_2_3);
+__m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
__m128i by_2 = __lsx_vld((const __m128i *)y[1].qs, 0);
bx_2 = __lsx_vsub_b(bx_2, off);
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
-__m128i bx_3 = __lsx_vand_v(lowMask, __lsx_vsrli_d(tmp_2_3, 4));
+__m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
__m128i by_3 = __lsx_vld((const __m128i *)(y[1].qs + 16), 0);
bx_3 = __lsx_vsub_b(bx_3, off);
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
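A scalar sketch of the per-block work in this loop: unpack the 4-bit weights, subtract the offset of 8, take the dot product with the int8 block, and scale by the product of the two block scales. The helper name is illustrative and the layout assumes q4_0's low-nibbles-first ordering:

#include <stdint.h>

/* One q4_0 x q8_0 block: 16 packed bytes hold 32 4-bit weights stored with an offset of 8. */
static float dot_q4_0_q8_0_block_scalar(const uint8_t qs[16], const int8_t q8[32],
                                        float d_x, float d_y) {
    int sumi = 0;
    for (int j = 0; j < 16; ++j) {
        const int v0 = (qs[j] & 0x0F) - 8; /* low nibble  -> element j      */
        const int v1 = (qs[j] >> 4)   - 8; /* high nibble -> element j + 16 */
        sumi += v0 * q8[j] + v1 * q8[j + 16];
    }
    return d_x * d_y * (float)sumi;
}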
@@ -4565,20 +4532,18 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
// Main loop
for (int i = 2; i < nb; i+=2) {
-//_mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
-//_mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
// Compute combined scale for the block 0 and 1
const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[i].qs, 0);
-__m128i bx_0 = __lsx_vand_v(lowMask, tmp_0_1);
+__m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
__m128i by_0 = __lsx_vld((const __m128i *)y[i].qs, 0);
bx_0 = __lsx_vsub_b(bx_0, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-__m128i bx_1 = __lsx_vand_v(lowMask, __lsx_vsrli_d(tmp_0_1, 4));
+__m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
__m128i by_1 = __lsx_vld((const __m128i *)(y[i].qs + 16), 0);
bx_1 = __lsx_vsub_b(bx_1, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
@@ -4591,12 +4556,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[i + 1].qs, 0);
-__m128i bx_2 = __lsx_vand_v(lowMask, tmp_2_3);
+__m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
__m128i by_2 = __lsx_vld((const __m128i *)y[i + 1].qs, 0);
bx_2 = __lsx_vsub_b(bx_2, off);
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
-__m128i bx_3 = __lsx_vand_v(lowMask, __lsx_vsrli_d(tmp_2_3, 4));
+__m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
__m128i by_3 = __lsx_vld((const __m128i *)(y[i + 1].qs + 16), 0);
bx_3 = __lsx_vsub_b(bx_3, off);
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
@@ -4896,7 +4861,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
-const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); //FIXME
+const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );
const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );
// Compute combined scales
@@ -6839,7 +6804,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m256i p_2 = lasx_ext16_32(lasx_extracti128(p1, 0));
const __m256i p_3 = lasx_ext16_32(lasx_extracti128(p1, 1));
-FloatInt t0, t1, t2, t3;
+ft_union t0, t1, t2, t3;
t0.f = d * db[0];
t1.f = d * db[1];
t2.f = d * db[2];
@@ -7595,12 +7560,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
p16_0 = __lasx_xvadd_w(p16_0, p16_1);
p16_2 = __lasx_xvadd_w(p16_2, p16_3);
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
}
// multiply with block scale and accumulate
acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
}
*s = hsum_float_8(acc);
@@ -8123,7 +8085,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
// multiply with block scale and accumulate
acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(p16_0), acc);
}
*s = hsum_float_8(acc);
@@ -8668,21 +8629,18 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
__m256 vd = __lasx_xvreplfr2vr_s(d);
acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
}
acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
__m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
-FloatInt fi;
+ft_union fi;
fi.i = __lsx_vpickve2gr_w(acc_m, 0);
*s = hsum_float_8(acc) + fi.f ;
#else
const uint8_t * scales = (const uint8_t*)&utmp[0];
const uint8_t * mins = (const uint8_t*)&utmp[2];
@@ -9068,7 +9026,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m256i p32h = lasx_madd_h(__lasx_xvreplgr2vr_h(scales[1]), p16h);
acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(p32h), acc);
}
*s = hsum_float_8(acc) - summs;
@@ -9668,12 +9625,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
p16_1 = lasx_madd_h(scale_1, p16_1);
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
}
__m256 vd = __lasx_xvreplfr2vr_s(d);
acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
}
*s = hsum_float_8(acc) + summs;
@@ -10094,7 +10049,6 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m256i dot = __lasx_xvsub_w(__lasx_xvadd_w(p16_0, p16_1), __lasx_xvadd_w(s16_0, s16_1));
acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(dot), acc);
}
*s = hsum_float_8(acc);
@@ -10745,7 +10699,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
}
acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
@@ -11507,7 +11460,6 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
}
accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
}
*s = 0.125f * hsum_float_8(accumf);
@@ -11857,7 +11809,6 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m256i sum = __lasx_xvadd_w(lasx_madd_h(sc1, dot1), lasx_madd_h(sc2, dot2));
accumf = __lasx_vfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sum), accumf);
}
*s = 0.125f * hsum_float_8(accumf);
@@ -12450,7 +12401,6 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
}
accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
}
*s = 0.125f * hsum_float_8(accumf);
@@ -12738,7 +12688,6 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
}
accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
}
*s = 0.25f * hsum_float_8(accumf);
@@ -13167,7 +13116,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
}
accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
}
*s = hsum_float_8(accumf);
@@ -13462,7 +13410,6 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2);
q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3);
__m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0);
q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1);
q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2);
@@ -13496,7 +13443,6 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum);
accum1 += d * sumi1;
}
*s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
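Most of the K-quant and IQ hunks above end with the same step: the block's integer dot product sumi is converted to float and folded into the accumulator, scaled by the block's combined scale d. A scalar sketch of that accumulation (names are illustrative):

#include <stdint.h>

/* Per-block accumulation mirrored by the __lasx_xvfmadd_s calls above:
 * each block contributes d * sumi to the running float total. */
static float accumulate_blocks_scalar(const float *d, const int32_t *sumi, int nb) {
    float acc = 0.0f;
    for (int i = 0; i < nb; ++i) {
        acc += d[i] * (float)sumi[i];
    }
    return acc;
}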