fix bad merging
This commit is contained in:
parent
610b3ac3cd
commit
e5aeb423a5
1 changed files with 167 additions and 1 deletions
|
@ -5986,7 +5986,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||||
|
|
||||||
uint32_t utmp[4];
|
uint32_t utmp[4];
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#ifdef __ARM_FEATURE_SVE
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
|
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
||||||
|
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
|
||||||
|
|
||||||
|
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
||||||
|
|
||||||
|
memcpy(utmp, x[i].scales, K_SCALE_SIZE);
|
||||||
|
|
||||||
|
uint32x2_t mins8 = { 0 };
|
||||||
|
mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
|
||||||
|
mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
|
||||||
|
|
||||||
|
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
||||||
|
utmp[0] &= kmask1;
|
||||||
|
|
||||||
|
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
|
||||||
|
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
||||||
|
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
||||||
|
sumf -= dmin * vaddvq_s32(prod);
|
||||||
|
|
||||||
|
const uint8_t * scales = (const uint8_t *)utmp;
|
||||||
|
|
||||||
|
const uint8_t * restrict q4 = x[i].qs;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
|
const int vector_length = ggml_cpu_get_sve_cnt()*8;
|
||||||
|
const svuint8_t m4b = svdup_n_u8(0xf);
|
||||||
|
const svint32_t mzero = svdup_n_s32(0);
|
||||||
|
svint32_t sumi1 = svdup_n_s32(0);
|
||||||
|
svint32_t sumi1_1 = svdup_n_s32(0);
|
||||||
|
svint32_t sumi1_2 = svdup_n_s32(0);
|
||||||
|
svint32_t sumi2 = svdup_n_s32(0);
|
||||||
|
svint32_t sumi2_1 = svdup_n_s32(0);
|
||||||
|
svint32_t sumi2_2 = svdup_n_s32(0);
|
||||||
|
switch (vector_length) {
|
||||||
|
case 128:
|
||||||
|
{
|
||||||
|
for (int j = 0; j < QK_K/64; ++j) {
|
||||||
|
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
|
||||||
|
svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
||||||
|
sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
||||||
|
q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
|
||||||
|
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
||||||
|
sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
||||||
|
|
||||||
|
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
|
||||||
|
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
||||||
|
sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
||||||
|
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
|
||||||
|
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
||||||
|
sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
||||||
|
q4 += 32;
|
||||||
|
}
|
||||||
|
sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
|
||||||
|
sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
|
||||||
|
sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
|
||||||
|
} break;
|
||||||
|
case 256:
|
||||||
|
case 512:
|
||||||
|
{
|
||||||
|
for (int j = 0; j < QK_K/64; ++j) {
|
||||||
|
const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
|
||||||
|
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
|
||||||
|
svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
|
||||||
|
sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
||||||
|
|
||||||
|
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
|
||||||
|
q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
|
||||||
|
sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
||||||
|
}
|
||||||
|
sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
assert(false && "Unsupported vector length");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*s = sumf;
|
||||||
|
#elif defined __ARM_NEON
|
||||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||||
const int32x4_t mzero = vdupq_n_s32(0);
|
const int32x4_t mzero = vdupq_n_s32(0);
|
||||||
|
|
||||||
|
@ -7756,6 +7837,91 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
||||||
}
|
}
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
|
|
||||||
|
#elif defined __riscv_v_intrinsic
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
|
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||||
|
|
||||||
|
const uint8_t * restrict q6 = x[i].ql;
|
||||||
|
const uint8_t * restrict qh = x[i].qh;
|
||||||
|
const int8_t * restrict q8 = y[i].qs;
|
||||||
|
|
||||||
|
const int8_t * restrict scale = x[i].scales;
|
||||||
|
|
||||||
|
size_t vl;
|
||||||
|
|
||||||
|
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
|
||||||
|
|
||||||
|
int sum_t = 0;
|
||||||
|
int is = 0;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK_K/128; ++j) {
|
||||||
|
|
||||||
|
vl = 32;
|
||||||
|
|
||||||
|
// load qh
|
||||||
|
vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
|
||||||
|
|
||||||
|
// load Q6
|
||||||
|
vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
|
||||||
|
vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
|
||||||
|
|
||||||
|
vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
|
||||||
|
vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
|
||||||
|
vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
|
||||||
|
vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
|
||||||
|
|
||||||
|
vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
|
||||||
|
vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
|
||||||
|
vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
|
||||||
|
vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
|
||||||
|
|
||||||
|
vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
|
||||||
|
vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
|
||||||
|
vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
|
||||||
|
vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
|
||||||
|
|
||||||
|
vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
|
||||||
|
vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
|
||||||
|
vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
|
||||||
|
vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
|
||||||
|
|
||||||
|
// load Q8 and take product
|
||||||
|
vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
|
||||||
|
vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
|
||||||
|
vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
|
||||||
|
vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
|
||||||
|
|
||||||
|
vl = 16;
|
||||||
|
|
||||||
|
vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
|
||||||
|
vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
|
||||||
|
vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
|
||||||
|
vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
|
||||||
|
vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
|
||||||
|
vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
|
||||||
|
vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
|
||||||
|
vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
|
||||||
|
|
||||||
|
vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
|
||||||
|
vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
|
||||||
|
vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
|
||||||
|
vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
|
||||||
|
|
||||||
|
sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
|
||||||
|
|
||||||
|
q6 += 64; qh += 32; q8 += 128; is=8;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
sumf += d * sum_t;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = sumf;
|
||||||
|
|
||||||
#elif defined(__POWER9_VECTOR__)
|
#elif defined(__POWER9_VECTOR__)
|
||||||
const vector signed char lowMask = vec_splats((signed char)0xF);
|
const vector signed char lowMask = vec_splats((signed char)0xF);
|
||||||
const vector int v0 = vec_splats((int32_t)0);
|
const vector int v0 = vec_splats((int32_t)0);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue