From 4dbdb6c82f8f3790160cb7ddfd1eb9cc2622b498 Mon Sep 17 00:00:00 2001 From: vithulep Date: Tue, 3 Sep 2024 11:27:22 +0530 Subject: [PATCH 1/4] Implemented vector length agnostic SVE using switch case for 512-bit, 256-bit, 128-bit vector lengths --- ggml/src/ggml-quants.c | 256 +++++++++++++++++++++++++++++++++++------ 1 file changed, 223 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 48b90f01b..8b8440cbc 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3818,14 +3818,20 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r float sumf = 0; #if defined(__ARM_FEATURE_SVE) - if (ggml_sve_cnt_b == QK8_0) { - const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + assert(nb % 2 == 0); // TODO: handle odd nb + const int vector_length = ggml_sve_cnt_b*8; - for (; ib + 1 < nb; ib += 2) { + // VLA Implementation using switch case + switch(vector_length) + { + case 128: + // predicate for activating higher lanes for 4 float32 elements + const svbool_t pg =svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { const block_q4_0 * restrict x0 = &x[ib + 0]; const block_q4_0 * restrict x1 = &x[ib + 1]; const block_q8_0 * restrict y0 = &y[ib + 0]; @@ -3836,24 +3842,113 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(),qx0r, 0x0F)); + const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(),qx0r, 0x04)); + const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(),qx1r, 0x0F)); + const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); + const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); + const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); + const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - + const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs+16); + const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); + const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs+16); // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + + sumv0 = svmla_n_f32_x(pg, sumv0, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx0ls, qy0l),svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(pg, sumv1, svcvt_f32_s32_x(pg, 
svadd_x(pg,svdot_s32(svdup_n_s32(0), qx1ls, qy1l),svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + + break; + + case 256: + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ptrueh_256 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t ptruel_256 = svnot_b_z(svptrue_b8(), ptrueh_256); + + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel_256, svand_n_u8_m(ptrueh_256, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel_256, svand_n_u8_m(ptrueh_256, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + + break; + + case 512: + // predicate for activating higher lanes for 32 int8 elements + const svbool_t ptrue = svptrue_pat_b8(SV_VL32); + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes + const svbool_t ptruel = svnot_b_z(ptrue, ptrueh); + + for (; ib < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(ptrue, x0->qs); + const svuint8_t qx1r = svld1rq_u8(ptrue, x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(ptrue, qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(ptrue, qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(ptrue, y0->qs); + const svint8_t qy1 = svld1_s8(ptrue, y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(ptrue, sumv0, svcvt_f32_s32_x(ptrue, svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ptrue, sumv1, svcvt_f32_s32_x(ptrue, svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + sumf = svaddv_f32(ptrue, svadd_f32_x(ptrue, sumv0, sumv1)); + break; + + default: + break; + } + #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -5303,29 
+5398,124 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r float sumf = 0; #if defined(__ARM_FEATURE_SVE) - if (ggml_sve_cnt_b == QK8_0) { - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + assert(nb % 2 == 0); // TODO: handle odd nb + const int vector_length = ggml_sve_cnt_b*8; - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + //VLA Implemenation for SVE + switch(vector_length) + { + case 128: + // predicate for activating lanes for 16 Int8 elements + svbool_t pg1 =svptrue_pat_b8(SV_VL16); + svbool_t pg =svptrue_pat_b32(SV_VL4); + for (; ib + 1 < nb; ib += 2) { - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0_0 = svld1_s8(pg1, x0->qs); + const svint8_t qx0_1 = svld1_s8(pg1, x0->qs+16); + const svint8_t qx1_0 = svld1_s8(pg1, x1->qs); + const svint8_t qx1_1 = svld1_s8(pg1, x1->qs+16); + + // load y + const svint8_t qy0_0 = svld1_s8(pg1, y0->qs); + const svint8_t qy0_1 = svld1_s8(pg1, y0->qs+16); + const svint8_t qy1_0 = svld1_s8(pg1, y1->qs); + const svint8_t qy1_1 = svld1_s8(pg1, y1->qs+16); + + sumv0 = svmla_n_f32_x(pg, sumv0, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(pg, sumv1, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + + } + + sumf = svaddv_f32(pg, svadd_f32_x(pg, sumv0, sumv1)); + break; + + case 256: + //printf("sve256"); + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + + } + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + break; + + case 512: + // predicate for activating high 256 bit 
+ const svbool_t ptrueh = svptrue_pat_b8(SV_VL32); + // predicate for activating low 256 bit + const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); + + // predicate for activating high lanes for 8 float32 elements + svbool_t asd = svptrue_pat_b32(SV_VL8); + // predicate for activating low lanes for 8 float32 elements + svbool_t dsa = svnot_b_z(svptrue_b32(), asd); + + svfloat32_t sumv00 = svdup_n_f32(0.0f); + + for (; ib+1 < nb; ib += 2) { + + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits + // and add them to make one 64 element vector + // load x + const svint8_t qx_32 = svld1_s8(ptrueh,x0->qs); + svint8_t qx_64 = svld1_s8(ptruel,x0->qs+2); + qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); + + // load y + const svint8_t qy_32 = svld1_s8(ptrueh,y0->qs); + svint8_t qy_64 = svld1_s8(ptruel,y0->qs+2); + qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); + + // scale creation + float32_t deq1= GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d); + float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d); + + // duplicate deq1 in first half of vector and deq2 in second half of vector + svfloat32_t temp = svdup_f32_m(svdup_f32_z(asd, deq1), dsa,deq2); + + + svfloat32_t sumvt = svdup_n_f32(0.0f); + + sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); + + sumv00 = svmla_f32_m(svptrue_b32(),sumv00,sumvt,temp); + + } + + sumf = svaddv_f32(svptrue_b32(), sumv00); + break; + + default: + break; - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); } #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); From 6a6cfd6c6f34747142f155192df42e11cebec026 Mon Sep 17 00:00:00 2001 From: vithulep Date: Tue, 3 Sep 2024 12:17:44 +0530 Subject: [PATCH 2/4] Implemented vector length agnostic SVE using switch case for 512-bit, 256-bit, 128-bit vector lengths --- ggml/src/ggml-quants.c | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 8b8440cbc..5ebdf96d1 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3821,7 +3821,6 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r svfloat32_t sumv0 = svdup_n_f32(0.0f); svfloat32_t sumv1 = svdup_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb const int vector_length = ggml_sve_cnt_b*8; // VLA Implementation using switch case @@ -5402,7 +5401,6 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r svfloat32_t sumv0 = svdup_n_f32(0.0f); svfloat32_t sumv1 = svdup_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb const int vector_length = ggml_sve_cnt_b*8; //VLA Implemenation for SVE From a9a9f6669241a19ac4ac3ee8ce1a1214ab35e9a5 Mon Sep 17 00:00:00 2001 From: vithulep Date: Tue, 3 Sep 2024 14:10:39 +0530 Subject: [PATCH 3/4] Removed WhiteSpaces --- ggml/src/ggml-quants.c | 74 +++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 5ebdf96d1..81c814f9b 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3823,13 +3823,13 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r svfloat32_t sumv1 = svdup_n_f32(0.0f); const int vector_length = ggml_sve_cnt_b*8; - // VLA 
Implementation using switch case + // VLA Implementation using switch case switch(vector_length) - { + { case 128: // predicate for activating higher lanes for 4 float32 elements const svbool_t pg =svptrue_pat_b32(SV_VL4); - + for (; ib + 1 < nb; ib += 2) { const block_q4_0 * restrict x0 = &x[ib + 0]; const block_q4_0 * restrict x1 = &x[ib + 1]; @@ -3866,15 +3866,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - break; + break; - case 256: + case 256: // predicate for activating higher lanes for 16 int8 elements const svbool_t ptrueh_256 = svptrue_pat_b8(SV_VL16); // predicate for activating lower lanes for 16 int8 elements const svbool_t ptruel_256 = svnot_b_z(svptrue_b8(), ptrueh_256); - + for (; ib + 1 < nb; ib += 2) { const block_q4_0 * restrict x0 = &x[ib + 0]; const block_q4_0 * restrict x1 = &x[ib + 1]; @@ -3904,14 +3904,14 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - break; + break; case 512: // predicate for activating higher lanes for 32 int8 elements const svbool_t ptrue = svptrue_pat_b8(SV_VL32); - // predicate for activating higher lanes for 16 int8 elements + // predicate for activating higher lanes for 16 int8 elements const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes + // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes const svbool_t ptruel = svnot_b_z(ptrue, ptrueh); for (; ib < nb; ib += 2) { @@ -3942,9 +3942,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = svaddv_f32(ptrue, svadd_f32_x(ptrue, sumv0, sumv1)); break; - - default: - break; + + default: + break; } @@ -5403,20 +5403,20 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r const int vector_length = ggml_sve_cnt_b*8; - //VLA Implemenation for SVE + //VLA Implemenation for SVE switch(vector_length) - { + { case 128: // predicate for activating lanes for 16 Int8 elements svbool_t pg1 =svptrue_pat_b8(SV_VL16); svbool_t pg =svptrue_pat_b32(SV_VL4); - for (; ib + 1 < nb; ib += 2) { + for (; ib + 1 < nb; ib += 2) { const block_q8_0 * restrict x0 = &x[ib + 0]; const block_q8_0 * restrict x1 = &x[ib + 1]; const block_q8_0 * restrict y0 = &y[ib + 0]; const block_q8_0 * restrict y1 = &y[ib + 1]; - + // load x const svint8_t qx0_0 = svld1_s8(pg1, x0->qs); const svint8_t qx0_1 = svld1_s8(pg1, x0->qs+16); @@ -5434,11 +5434,11 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } - sumf = svaddv_f32(pg, svadd_f32_x(pg, sumv0, sumv1)); - break; + sumf = svaddv_f32(pg, svadd_f32_x(pg, sumv0, sumv1)); + break; case 256: - //printf("sve256"); + //printf("sve256"); for (; ib + 1 < nb; ib += 2) { const block_q8_0 * restrict x0 = &x[ib + 0]; const block_q8_0 * restrict x1 = &x[ib + 1]; @@ -5452,24 +5452,24 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r // load y const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), 
svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - - } - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - break; + + } + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + break; case 512: - // predicate for activating high 256 bit + // predicate for activating high 256 bit const svbool_t ptrueh = svptrue_pat_b8(SV_VL32); - // predicate for activating low 256 bit + // predicate for activating low 256 bit const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); - + // predicate for activating high lanes for 8 float32 elements svbool_t asd = svptrue_pat_b32(SV_VL8); // predicate for activating low lanes for 8 float32 elements - svbool_t dsa = svnot_b_z(svptrue_b32(), asd); + svbool_t dsa = svnot_b_z(svptrue_b32(), asd); svfloat32_t sumv00 = svdup_n_f32(0.0f); @@ -5480,9 +5480,9 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r const block_q8_0 * restrict y0 = &y[ib + 0]; const block_q8_0 * restrict y1 = &y[ib + 1]; - //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits + //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits // and add them to make one 64 element vector - // load x + // load x const svint8_t qx_32 = svld1_s8(ptrueh,x0->qs); svint8_t qx_64 = svld1_s8(ptruel,x0->qs+2); qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); @@ -5491,11 +5491,11 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r const svint8_t qy_32 = svld1_s8(ptrueh,y0->qs); svint8_t qy_64 = svld1_s8(ptruel,y0->qs+2); qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); - - // scale creation + + // scale creation float32_t deq1= GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d); float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d); - + // duplicate deq1 in first half of vector and deq2 in second half of vector svfloat32_t temp = svdup_f32_m(svdup_f32_z(asd, deq1), dsa,deq2); @@ -5503,16 +5503,16 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r svfloat32_t sumvt = svdup_n_f32(0.0f); sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); - + sumv00 = svmla_f32_m(svptrue_b32(),sumv00,sumvt,temp); } - sumf = svaddv_f32(svptrue_b32(), sumv00); + sumf = svaddv_f32(svptrue_b32(), sumv00); break; - default: - break; + default: + break; } #elif defined(__ARM_NEON) From cfbf33a7052d268266d53a55b8c4c6aca9cecde4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 9 Sep 2024 12:50:35 +0300 Subject: [PATCH 4/4] ggml : style changes + fix 512-bit nb loop check - fix local scope in switch cases - consistent predicate names - empty lines when necessary - opening braces, spaces - const-correctness - add asserts --- ggml/src/ggml-quants.c | 374 +++++++++++++++++++++-------------------- 1 file changed, 191 insertions(+), 183 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 81c814f9b..059347bc7 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3818,134 +3818,139 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r float sumf = 0; #if defined(__ARM_FEATURE_SVE) - svfloat32_t sumv0 = svdup_n_f32(0.0f); svfloat32_t sumv1 = svdup_n_f32(0.0f); + const int vector_length = ggml_sve_cnt_b*8; // VLA Implementation using switch case - switch(vector_length) - { + switch (vector_length) { case 128: - // predicate for activating higher lanes for 4 float32 
elements - const svbool_t pg =svptrue_pat_b32(SV_VL4); + { + // predicate for activating higher lanes for 4 float32 elements + const svbool_t ph4 = svptrue_pat_b32(SV_VL4); - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - // 4-bit -> 8-bit - const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(),qx0r, 0x0F)); - const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(),qx0r, 0x04)); - const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(),qx1r, 0x0F)); - const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); + // 4-bit -> 8-bit + const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F)); + const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04)); + const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F)); + const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); - // sub 8 - const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); - const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); - const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); - const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); + // sub 8 + const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); + const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); + const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); + const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); - // load y - const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs+16); - const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); - const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs+16); - // dot product + // load y + const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16); + const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); + const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16); - sumv0 = svmla_n_f32_x(pg, sumv0, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx0ls, qy0l),svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(pg, sumv1, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx1ls, qy1l),svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - - break; + // dot product + sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx0ls, qy0l), + svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx1ls, qy1l), + svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + sumf 
= svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; case 256: - // predicate for activating higher lanes for 16 int8 elements - const svbool_t ptrueh_256 = svptrue_pat_b8(SV_VL16); - // predicate for activating lower lanes for 16 int8 elements - const svbool_t ptruel_256 = svnot_b_z(svptrue_b8(), ptrueh_256); + { + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel_256, svand_n_u8_m(ptrueh_256, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel_256, svand_n_u8_m(ptrueh_256, qx1r, 0x0F), 0x04)); + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - - // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - - break; + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; case 512: - // predicate for activating higher lanes for 32 int8 elements - const svbool_t ptrue = svptrue_pat_b8(SV_VL32); - // predicate for activating higher lanes for 16 int8 elements - const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes - const svbool_t ptruel = svnot_b_z(ptrue, ptrueh); + { + // 
predicate for activating higher lanes for 32 int8 elements + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); - for (; ib < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes + const svbool_t pl16 = svnot_b_z(ph32, ph16); - // load x - const svuint8_t qx0r = svld1rq_u8(ptrue, x0->qs); - const svuint8_t qx1r = svld1rq_u8(ptrue, x1->qs); + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + // load x + const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs); + const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs); - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(ptrue, qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(ptrue, qx1, 8); + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); - // load y - const svint8_t qy0 = svld1_s8(ptrue, y0->qs); - const svint8_t qy1 = svld1_s8(ptrue, y1->qs); + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8); - // dot product - sumv0 = svmla_n_f32_x(ptrue, sumv0, svcvt_f32_s32_x(ptrue, svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(ptrue, sumv1, svcvt_f32_s32_x(ptrue, svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - sumf = svaddv_f32(ptrue, svadd_f32_x(ptrue, sumv0, sumv1)); - break; + // load y + const svint8_t qy0 = svld1_s8(ph32, y0->qs); + const svint8_t qy1 = svld1_s8(ph32, y1->qs); + // dot product + sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1)); + } break; default: - break; - + assert(false && "Unsupported vector length"); + break; } #elif defined(__ARM_NEON) @@ -5397,123 +5402,126 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r float sumf = 0; #if defined(__ARM_FEATURE_SVE) - svfloat32_t sumv0 = svdup_n_f32(0.0f); svfloat32_t sumv1 = svdup_n_f32(0.0f); const int vector_length = ggml_sve_cnt_b*8; //VLA Implemenation for SVE - switch(vector_length) - { + switch (vector_length) { case 128: - // predicate for activating lanes for 16 Int8 elements - svbool_t pg1 =svptrue_pat_b8(SV_VL16); - svbool_t pg =svptrue_pat_b32(SV_VL4); - for (; ib + 1 < nb; ib += 2) { + { + // predicate for activating lanes for 16 Int8 elements + const svbool_t ph16 = svptrue_pat_b8 (SV_VL16); + const svbool_t pl16 = svptrue_pat_b32(SV_VL4); - const 
block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - // load x - const svint8_t qx0_0 = svld1_s8(pg1, x0->qs); - const svint8_t qx0_1 = svld1_s8(pg1, x0->qs+16); - const svint8_t qx1_0 = svld1_s8(pg1, x1->qs); - const svint8_t qx1_1 = svld1_s8(pg1, x1->qs+16); + // load x + const svint8_t qx0_0 = svld1_s8(ph16, x0->qs); + const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16); + const svint8_t qx1_0 = svld1_s8(ph16, x1->qs); + const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16); - // load y - const svint8_t qy0_0 = svld1_s8(pg1, y0->qs); - const svint8_t qy0_1 = svld1_s8(pg1, y0->qs+16); - const svint8_t qy1_0 = svld1_s8(pg1, y1->qs); - const svint8_t qy1_1 = svld1_s8(pg1, y1->qs+16); + // load y + const svint8_t qy0_0 = svld1_s8(ph16, y0->qs); + const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16); + const svint8_t qy1_0 = svld1_s8(ph16, y1->qs); + const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16); - sumv0 = svmla_n_f32_x(pg, sumv0, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(pg, sumv1, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - - } - - sumf = svaddv_f32(pg, svadd_f32_x(pg, sumv0, sumv1)); - break; + sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx0_0, qy0_0), + svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx1_0, qy1_0), + svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1)); + } break; case 256: - //printf("sve256"); - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; + { + //printf("sve256"); + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - - } - 
sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - break; + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; case 512: - // predicate for activating high 256 bit - const svbool_t ptrueh = svptrue_pat_b8(SV_VL32); - // predicate for activating low 256 bit - const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); + { + // predicate for activating high 256 bit + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + // predicate for activating low 256 bit + const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32); - // predicate for activating high lanes for 8 float32 elements - svbool_t asd = svptrue_pat_b32(SV_VL8); - // predicate for activating low lanes for 8 float32 elements - svbool_t dsa = svnot_b_z(svptrue_b32(), asd); + // predicate for activating high lanes for 8 float32 elements + const svbool_t ph8 = svptrue_pat_b32(SV_VL8); + // predicate for activating low lanes for 8 float32 elements + const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8); - svfloat32_t sumv00 = svdup_n_f32(0.0f); + svfloat32_t sumv00 = svdup_n_f32(0.0f); - for (; ib+1 < nb; ib += 2) { + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; + //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits + // and add them to make one 64 element vector + // load x + const svint8_t qx_32 = svld1_s8(ph32, x0->qs); + svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2); - //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits - // and add them to make one 64 element vector - // load x - const svint8_t qx_32 = svld1_s8(ptrueh,x0->qs); - svint8_t qx_64 = svld1_s8(ptruel,x0->qs+2); - qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); + qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); - // load y - const svint8_t qy_32 = svld1_s8(ptrueh,y0->qs); - svint8_t qy_64 = svld1_s8(ptruel,y0->qs+2); - qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); + // load y + const svint8_t qy_32 = svld1_s8(ph32, y0->qs); + svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2); - // scale creation - float32_t deq1= GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d); - float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d); + qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); - // duplicate deq1 in first half of vector and deq2 in second half of vector - svfloat32_t temp = svdup_f32_m(svdup_f32_z(asd, deq1), dsa,deq2); + // scale creation + const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d); + const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d); + // duplicate deq1 in first half of vector and deq2 in second half of vector + const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2); - svfloat32_t sumvt = svdup_n_f32(0.0f); + svfloat32_t sumvt = svdup_n_f32(0.0f); - sumvt = 
svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); + sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); - sumv00 = svmla_f32_m(svptrue_b32(),sumv00,sumvt,temp); + sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp); + } - } - - sumf = svaddv_f32(svptrue_b32(), sumv00); - break; - - default: + sumf = svaddv_f32(svptrue_b32(), sumv00); break; - + } + default: + assert(false && "Unsupported vector length"); + break; } #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f);
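
---

For reference, every SVE path introduced above (128-, 256- and 512-bit) computes the same block dot product as ggml's scalar q4_0 x q8_0 kernel: unpack the two 4-bit nibbles of each byte into the low and high half of the block, subtract the +8 bias, accumulate the int8 products, and scale by the product of the two fp16 block scales. The sketch below is a minimal scalar reference, not code from this patch: the struct layouts mirror ggml's block_q4_0/block_q8_0 (QK4_0 == QK8_0 == 32), and fp16_to_fp32() is a hypothetical stand-in for GGML_FP16_TO_FP32.

```c
// Scalar reference for the q4_0 x q8_0 dot product that the SVE paths vectorize.
// Hypothetical, self-contained definitions: the real ggml types live in ggml-common.h,
// and fp16_to_fp32() stands in for GGML_FP16_TO_FP32.
#include <stdint.h>

#define QK 32                                       // QK4_0 == QK8_0 == 32

typedef struct { uint16_t d; uint8_t qs[QK / 2]; } block_q4_0; // fp16 scale + 32 packed 4-bit values
typedef struct { uint16_t d; int8_t  qs[QK];     } block_q8_0; // fp16 scale + 32 int8 values

extern float fp16_to_fp32(uint16_t h);              // assumed fp16 -> fp32 helper

static float vec_dot_q4_0_q8_0_ref(int n, const block_q4_0 * x, const block_q8_0 * y) {
    const int nb = n / QK;                           // number of blocks
    float sumf = 0.0f;
    for (int ib = 0; ib < nb; ++ib) {
        int32_t sumi = 0;
        for (int j = 0; j < QK / 2; ++j) {
            // low nibble -> element j, high nibble -> element j + QK/2; both stored with a +8 bias
            const int32_t v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int32_t v1 = (x[ib].qs[j] >>   4) - 8;
            sumi += v0 * y[ib].qs[j] + v1 * y[ib].qs[j + QK / 2];
        }
        sumf += fp16_to_fp32(x[ib].d) * fp16_to_fp32(y[ib].d) * (float) sumi;
    }
    return sumf;
}
```

The runtime dispatch in the patch keys on `ggml_sve_cnt_b * 8`, i.e. the SVE register width in bits (equivalently `svcntb() * 8` on a native build), so the same binary picks the 128-, 256- or 512-bit kernel at run time. Note also that the final commit in the series ("fix 512-bit nb loop check") tightens the 512-bit loop condition from `ib < nb` to `ib + 1 < nb`, so an odd block count no longer reads one block past the end.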