ggml : optimize and build warning fix for LoongArch (#11709)

* ggml : optimize convert f32<->f16 for loongarch_asx

* ggml : optimize loongarch_asx extend i16,i8,u8 to i32,i16

* ggml : Fix warnings when run cpu CI locally on LoongArch
This commit is contained in:
Jinyang He 2025-02-07 15:38:31 +08:00 committed by GitHub
parent 855cd0734a
commit 225bbbfa39
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 22 additions and 57 deletions

View file

@ -501,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
}
static __m256i lasx_extu8_16(__m128i a) {
__m128i zero = __lsx_vldi(0);
__m128i vlo = __lsx_vilvl_b(zero, a);
__m128i vhi = __lsx_vilvh_b(zero, a);
return lasx_set_q(vhi, vlo);
return __lasx_vext2xv_hu_bu(____m256i(a));
}
static __m256i lasx_ext8_16(__m128i a) {
__m128i sign = __lsx_vslti_b(a, 0);
__m128i vlo = __lsx_vilvl_b(sign, a);
__m128i vhi = __lsx_vilvh_b(sign, a);
return lasx_set_q(vhi, vlo);
return __lasx_vext2xv_h_b(____m256i(a));
}
static __m256i lasx_ext16_32(__m128i a) {
__m256i tmp1;
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
return tmp1;
return __lasx_vext2xv_w_h(____m256i(a));
}
static __m128i lasx_extracti128( __m256i a, int pos) {
@ -592,12 +577,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = lasx_extractf128(x, 1);
ft_union tmp;
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
tmp.i = __lsx_vpickve2gr_w(res, 0);
return tmp.f;
return ((v4f32)res)[0];
}
// horizontally add 8 int32_t
@ -939,7 +922,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
#elif defined(__loongarch_asx)
for (int i = 0; i < nb; i++) {
ft_union fi;
__m256 v0 = (__m256)__lasx_xvld( x , 0);
__m256 v1 = (__m256)__lasx_xvld( x , 32);
__m256 v2 = (__m256)__lasx_xvld( x , 64);
@ -957,8 +939,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
__m128 tmp = max4;
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
const float max_scalar = fi.f;
const float max_scalar = ((v4f32)max4)[0];
// Quantize these floats
const float d = max_scalar / 127.f;
@ -1263,7 +1244,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#elif defined(__loongarch_asx)
for (int i = 0; i < nb; i++) {
ft_union ft;
__m256 v0 = (__m256)__lasx_xvld( x , 0 );
__m256 v1 = (__m256)__lasx_xvld( x , 32 );
__m256 v2 = (__m256)__lasx_xvld( x , 64 );
@ -1281,8 +1261,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
__m128 tmp = max4;
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
const float max_scalar = ft.f;
const float max_scalar = ((v4f32)max4)[0];
// Quantize these floats
const float d = max_scalar / 127.f;
@ -6154,9 +6133,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
ft_union fi;
fi.i = __lsx_vpickve2gr_w(acc_m, 0);
*s = hsum_float_8(acc) + fi.f ;
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
#else
const uint8_t * scales = (const uint8_t*)&utmp[0];