Arm AArch64: minor changes to skip the pr#7433 vec_dot code for Arm CPUs with SVE VL not equal to 256 bits
parent e2c1c47fa8
commit 79b6cdfe69
2 changed files with 55 additions and 48 deletions
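For context (editor's sketch, not part of the commit): svcntb() returns the SVE vector length in bytes at run time, and QK8_0 is 32 in ggml, so the guard `if (svcntb() == QK8_0)` added below only takes the SVE path on 256-bit implementations. A minimal stand-alone probe, assuming a compiler with SVE enabled (e.g. -march=armv8-a+sve):

#include <arm_sve.h>   // SVE ACLE intrinsics; requires __ARM_FEATURE_SVE
#include <stdio.h>

int main(void) {
    // svcntb() counts the 8-bit lanes, i.e. the vector length in bytes.
    const unsigned vl_bytes = (unsigned) svcntb();
    printf("SVE vector length: %u bytes (%u bits)\n", vl_bytes, 8 * vl_bytes);
    // 32 bytes == 256 bits == QK8_0, the width the kernels below assume.
    printf("pr#7433 vec_dot kernels %s\n", vl_bytes == 32 ? "apply" : "are skipped");
    return 0;
}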
@@ -3814,43 +3814,47 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 #endif
 #if defined(__ARM_FEATURE_SVE)
-    const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
-    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
-
-    svfloat32_t sumv0 = svdup_n_f32(0.0f);
-    svfloat32_t sumv1 = svdup_n_f32(0.0f);
-
-    assert(nb % 2 == 0); // TODO: handle odd nb
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
-
-        // load x
-        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
-        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
-
-        // 4-bit -> 8-bit
-        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
-        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
-
-        // sub 8
-        const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
-        const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
-
-        // load y
-        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
-        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
-
-        // dot product
-        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-    }
-
-    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
-#elif defined(__ARM_NEON)
+    if (svcntb() == QK8_0) {
+        const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+        const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+        svfloat32_t sumv0 = svdup_n_f32(0.0f);
+        svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+        assert(nb % 2 == 0); // TODO: handle odd nb
+
+        for (int i = 0; i < nb; i += 2) {
+            const block_q4_0 * restrict x0 = &x[i + 0];
+            const block_q4_0 * restrict x1 = &x[i + 1];
+            const block_q8_0 * restrict y0 = &y[i + 0];
+            const block_q8_0 * restrict y1 = &y[i + 1];
+
+            // load x
+            const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+            const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+            // 4-bit -> 8-bit
+            const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+            const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+            // sub 8
+            const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+            const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+            // load y
+            const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+            const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+            // dot product
+            sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+            sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+        }
+
+        *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+        return;
+    }
+#endif
+#if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
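The predicated svand/svlsr pair in the hunk above unpacks q4_0 nibbles in place: svld1rq_u8 replicates the 16 quantized bytes across both 128-bit halves of the 256-bit register, ptrueh keeps the low nibbles (& 0x0F) in the first half, ptruel shifts out the high nibbles (>> 4) in the second half, and the subsequent subtract re-centers the values around zero. A scalar equivalent (editor's sketch, assuming ggml's q4_0 layout where byte j packs elements j and j+16; the helper name is mine):

#include <stdint.h>

// Scalar sketch of the 4-bit -> 8-bit unpack done by the predicated SVE ops above.
static void unpack_q4_0_sketch(const uint8_t qs[16], int8_t out[32]) {
    for (int j = 0; j < 16; ++j) {
        out[j]      = (int8_t) (qs[j] & 0x0F) - 8; // low nibbles  -> lanes 0..15  (ptrueh half)
        out[j + 16] = (int8_t) (qs[j] >> 4)   - 8; // high nibbles -> lanes 16..31 (ptruel half)
    }
}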
@@ -5422,31 +5426,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 #endif
 #if defined(__ARM_FEATURE_SVE)
-    svfloat32_t sumv0 = svdup_n_f32(0.0f);
-    svfloat32_t sumv1 = svdup_n_f32(0.0f);
-
-    assert(nb % 2 == 0); // TODO: handle odd nb
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q8_0 * restrict x0 = &x[i + 0];
-        const block_q8_0 * restrict x1 = &x[i + 1];
-        const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
-
-        // load x
-        const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
-        const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
-
-        // load y
-        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
-        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
-
-        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-    }
-
-    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
-#elif defined(__ARM_NEON)
+    if (svcntb() == QK8_0) {
+        svfloat32_t sumv0 = svdup_n_f32(0.0f);
+        svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+        assert(nb % 2 == 0); // TODO: handle odd nb
+
+        for (int i = 0; i < nb; i += 2) {
+            const block_q8_0 * restrict x0 = &x[i + 0];
+            const block_q8_0 * restrict x1 = &x[i + 1];
+            const block_q8_0 * restrict y0 = &y[i + 0];
+            const block_q8_0 * restrict y1 = &y[i + 1];
+
+            // load x
+            const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+            const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+            // load y
+            const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+            const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+            sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+            sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+        }
+
+        *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+        return;
+    }
+#endif
+#if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
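For reference, what each svdot/svmla pair above accumulates per block (editor's scalar sketch; QK8_0 == 32, helper name is mine): a 32-element int8 dot product, converted to float and scaled by the two fp16 block scales.

#include <stdint.h>

// Scalar reference for one block of the SVE q8_0 x q8_0 dot product above.
// xd/yd stand in for GGML_FP16_TO_FP32(x->d) and GGML_FP16_TO_FP32(y->d).
static float block_dot_q8_0_sketch(const int8_t xqs[32], float xd,
                                   const int8_t yqs[32], float yd) {
    int32_t sumi = 0;
    for (int j = 0; j < 32; ++j) {
        sumi += (int32_t) xqs[j] * (int32_t) yqs[j];
    }
    return (float) sumi * xd * yd;
}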
@@ -21901,7 +21901,6 @@ int ggml_cpu_has_neon(void) {
 int ggml_cpu_has_sve(void) {
 #if defined(__ARM_FEATURE_SVE)
     // TODO: Currently, SVE 256 bit is only supported.
-    GGML_ASSERT(svcntb() == QK8_0);
     return 1;
 #else
     return 0;
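Removing the GGML_ASSERT means an SVE-enabled build no longer aborts on hardware whose vector length is not 256 bits; combined with the `#elif` to `#endif`/`#if` change in the vec_dot hunks, the NEON code is now compiled alongside the SVE code and serves as the run-time fallback. The resulting dispatch shape, schematically (editor's sketch of the pattern, not literal code from the commit):

#if defined(__ARM_FEATURE_SVE)
    if (svcntb() == QK8_0) {
        // 256-bit SVE kernel
        return;              // done; the NEON path below is skipped
    }
    // VL != 256 bits: fall through to NEON
#endif
#if defined(__ARM_NEON)
    // NEON kernel: runs when SVE is absent or its VL is not 256 bits
#endif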