From dda250f637bd3a3edf18c63fa16fc2d10a0efbb9 Mon Sep 17 00:00:00 2001 From: Julia Longtin Date: Fri, 10 May 2024 17:03:41 +0000 Subject: [PATCH] move sub earlier, and move the compare of iterations to outside, and at the end of the loop. --- ggml-phi-knc.c | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml-phi-knc.c b/ggml-phi-knc.c index 91289bd04..5e400849a 100644 --- a/ggml-phi-knc.c +++ b/ggml-phi-knc.c @@ -52,18 +52,19 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x "cmp\t$0,%[CLR]\n\t" // Should we clear the sum before we start? "jz\t4f\n\t" "vbroadcastss\t%[Z]%{uint8%},\t%%zmm0\n\t" // If so, use an upscaling operator to clear our sum. - "jmp\t1f\n\t" + "jmp\t5f\n\t" "4:\n\t" "vprefetch0\t(%[RES])\n\t" "vmovaps\t\t(%[RES]),\t%%zmm0\n\t" // Otherwise, load our inital state from sum.. "vprefetchnta\t(%%r10)\n\t" "vprefetchnta\t(%%r12)\n\t" - "1:\n\t" + "5:\n\t" "cmp\t$3,\t%%r8\n\t" // Compare iterations to three. "jnae\t6f\n\t" // If there are not three iterations left, jump to label 6. + "1:\n\t" + "sub\t$3,\t%%r8\n\t" // Decrement iterations "vmovaps\t\t(%%r10),\t%%zmm1\n\t" // Load two vectors. "vmovaps\t\t(%%r12),\t%%zmm2\n\t" - "sub\t$3,\t%%r8\n\t" // Decrement iterations "vprefetchnta\t192(%%r10)\n\t" // prefetch the next float32x16_t block (192 bytes ahead) "vprefetchnta\t192(%%r12)\n\t" "vmovaps\t\t64(%%r10),\t%%zmm3\n\t" // Load two vectors. @@ -81,7 +82,8 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x "vfmadd231ps\t%%zmm1,\t%%zmm2,\t%%zmm0\n\t" // Perform a fused multiply add "vfmadd231ps\t%%zmm3,\t%%zmm4,\t%%zmm0\n\t" // Perform a fused multiply add "vfmadd231ps\t%%zmm5,\t%%zmm6,\t%%zmm0\n\t" // Perform a fused multiply add - "jmp\t1b\n\t" // Jump back to the start of the loop + "cmp\t$3,\t%%r8\n\t" // Compare iterations to three. + "jnae\t6f\n\t" // If there are not three iterations left, jump to label 6. "6:\n\t" // We know we are near the tail. handle 2, 1, and 0 cases. "cmp\t$0,\t%%r8\n\t" // Compare iterations to zero "je\t2f\n\t" // Jump to label 2 if zero (end of loop)