From 4dbdb6c82f8f3790160cb7ddfd1eb9cc2622b498 Mon Sep 17 00:00:00 2001 From: vithulep Date: Tue, 3 Sep 2024 11:27:22 +0530 Subject: [PATCH] Implemented vector length agnostic SVE using switch case for 512-bit, 256-bit, 128-bit vector lengths --- ggml/src/ggml-quants.c | 256 +++++++++++++++++++++++++++++++++++------ 1 file changed, 223 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 48b90f01b..8b8440cbc 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3818,14 +3818,20 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r float sumf = 0; #if defined(__ARM_FEATURE_SVE) - if (ggml_sve_cnt_b == QK8_0) { - const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + assert(nb % 2 == 0); // TODO: handle odd nb + const int vector_length = ggml_sve_cnt_b*8; - for (; ib + 1 < nb; ib += 2) { + // VLA Implementation using switch case + switch(vector_length) + { + case 128: + // predicate for activating higher lanes for 4 float32 elements + const svbool_t pg =svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { const block_q4_0 * restrict x0 = &x[ib + 0]; const block_q4_0 * restrict x1 = &x[ib + 1]; const block_q8_0 * restrict y0 = &y[ib + 0]; @@ -3836,24 +3842,113 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(),qx0r, 0x0F)); + const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(),qx0r, 0x04)); + const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(),qx1r, 0x0F)); + const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); + const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); + const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); + const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - + const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs+16); + const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); + const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs+16); // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + + sumv0 = svmla_n_f32_x(pg, sumv0, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx0ls, qy0l),svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(pg, sumv1, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx1ls, qy1l),svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + + break; + + case 256: + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ptrueh_256 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t ptruel_256 = svnot_b_z(svptrue_b8(), ptrueh_256); + + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel_256, svand_n_u8_m(ptrueh_256, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel_256, svand_n_u8_m(ptrueh_256, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + + break; + + case 512: + // predicate for activating higher lanes for 32 int8 elements + const svbool_t ptrue = svptrue_pat_b8(SV_VL32); + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes + const svbool_t ptruel = svnot_b_z(ptrue, ptrueh); + + for (; ib < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(ptrue, x0->qs); + const svuint8_t qx1r = svld1rq_u8(ptrue, x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(ptrue, qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(ptrue, qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(ptrue, y0->qs); + const svint8_t qy1 = svld1_s8(ptrue, y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(ptrue, sumv0, svcvt_f32_s32_x(ptrue, svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ptrue, sumv1, svcvt_f32_s32_x(ptrue, svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + sumf = svaddv_f32(ptrue, svadd_f32_x(ptrue, sumv0, sumv1)); + break; + + default: + break; + } + #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -5303,29 +5398,124 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r float sumf = 0; #if defined(__ARM_FEATURE_SVE) - if (ggml_sve_cnt_b == QK8_0) { - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + assert(nb % 2 == 0); // TODO: handle odd nb + const int vector_length = ggml_sve_cnt_b*8; - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + //VLA Implemenation for SVE + switch(vector_length) + { + case 128: + // predicate for activating lanes for 16 Int8 elements + svbool_t pg1 =svptrue_pat_b8(SV_VL16); + svbool_t pg =svptrue_pat_b32(SV_VL4); + for (; ib + 1 < nb; ib += 2) { - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0_0 = svld1_s8(pg1, x0->qs); + const svint8_t qx0_1 = svld1_s8(pg1, x0->qs+16); + const svint8_t qx1_0 = svld1_s8(pg1, x1->qs); + const svint8_t qx1_1 = svld1_s8(pg1, x1->qs+16); + + // load y + const svint8_t qy0_0 = svld1_s8(pg1, y0->qs); + const svint8_t qy0_1 = svld1_s8(pg1, y0->qs+16); + const svint8_t qy1_0 = svld1_s8(pg1, y1->qs); + const svint8_t qy1_1 = svld1_s8(pg1, y1->qs+16); + + sumv0 = svmla_n_f32_x(pg, sumv0, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(pg, sumv1, svcvt_f32_s32_x(pg, svadd_x(pg,svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + + } + + sumf = svaddv_f32(pg, svadd_f32_x(pg, sumv0, sumv1)); + break; + + case 256: + //printf("sve256"); + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + + } + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + break; + + case 512: + // predicate for activating high 256 bit + const svbool_t ptrueh = svptrue_pat_b8(SV_VL32); + // predicate for activating low 256 bit + const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); + + // predicate for activating high lanes for 8 float32 elements + svbool_t asd = svptrue_pat_b32(SV_VL8); + // predicate for activating low lanes for 8 float32 elements + svbool_t dsa = svnot_b_z(svptrue_b32(), asd); + + svfloat32_t sumv00 = svdup_n_f32(0.0f); + + for (; ib+1 < nb; ib += 2) { + + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits + // and add them to make one 64 element vector + // load x + const svint8_t qx_32 = svld1_s8(ptrueh,x0->qs); + svint8_t qx_64 = svld1_s8(ptruel,x0->qs+2); + qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); + + // load y + const svint8_t qy_32 = svld1_s8(ptrueh,y0->qs); + svint8_t qy_64 = svld1_s8(ptruel,y0->qs+2); + qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); + + // scale creation + float32_t deq1= GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d); + float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d); + + // duplicate deq1 in first half of vector and deq2 in second half of vector + svfloat32_t temp = svdup_f32_m(svdup_f32_z(asd, deq1), dsa,deq2); + + + svfloat32_t sumvt = svdup_n_f32(0.0f); + + sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); + + sumv00 = svmla_f32_m(svptrue_b32(),sumv00,sumvt,temp); + + } + + sumf = svaddv_f32(svptrue_b32(), sumv00); + break; + + default: + break; - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); } #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f);