diff --git a/ggml-phi-knc.c b/ggml-phi-knc.c index 5f9eb70c4..9b53d876d 100644 --- a/ggml-phi-knc.c +++ b/ggml-phi-knc.c @@ -31,7 +31,7 @@ inline static void GGML_F32x16_VEC_ZERO(float32x16_t *target) : "zmm8", "memory"); } -// Multiply each item in mvec1 with the corresponding item in mvec2, adding the result to the corresponding item in sum. optionally clear the sum before starting. +// Multiply each item in mvec1 with the corresponding item in mvec2, adding the result to the corresponding item in sum. Optionally clear the sum before starting. inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x16_t *mvec2, float32x16_t *sumvec, size_t iterations, int clear) { uint8_t zero = 0; @@ -59,8 +59,8 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x "vprefetchnta\t(%%r10)\n\t" "vprefetchnta\t(%%r12)\n\t" "5:\n\t" - "cmp\t$3,\t%%r8\n\t" // Compare iterations to three. - "jnae\t6f\n\t" // If there are not three iterations left, jump to label 6. + "cmp\t$4,\t%%r8\n\t" // Compare iterations to four. + "jnae\t6f\n\t" // If there are not four iterations left, jump to label 6. "1:\n\t" "sub\t$3,\t%%r8\n\t" // Decrement iterations "vmovaps\t\t(%%r10),\t%%zmm1\n\t" // Load two vectors. @@ -79,12 +79,15 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x "vprefetch1\t576(%%r12)\n\t" "vprefetch1\t704(%%r10)\n\t" "vprefetch1\t704(%%r12)\n\t" - "add\t$192,\t%%r10\n\t" // Move to the next float32x16_t block (192 bytes ahead) - "add\t$192,\t%%r12\n\t" "vfmadd231ps\t%%zmm5,\t%%zmm6,\t%%zmm0\n\t" // Perform a fused multiply add - "cmp\t$3,\t%%r8\n\t" // Compare iteration count to three. - "jge\t1b\n\t" // If there three or more iterations left, loop. - "6:\n\t" // We know we are near the tail. handle 2, 1, and 0 cases. + "vmovaps\t\t192(%%r10),\t%%zmm7\n\t" // Load two vectors. + "vmovaps\t\t192(%%r12),\t%%zmm8\n\t" + "vfmadd231ps\t%%zmm7,\t%%zmm8,\t%%zmm0\n\t" // Perform a fused multiply add + "add\t$256,\t%%r10\n\t" // Move to the next 4xfloat32x16_t block (256 bytes ahead) + "add\t$256,\t%%r12\n\t" + "cmp\t$4,\t%%r8\n\t" // Compare iteration count to four. + "jge\t1b\n\t" // If there are four or more iterations left, loop. + "6:\n\t" // We know we are near the tail. handle 3, 2, 1, and 0 cases. "cmp\t$0,\t%%r8\n\t" // Compare iterations to zero "jz\t2f\n\t" // Jump to label 2 if zero (end of loop) "cmp\t$1,\t%%r8\n\t" // Compare iterations to one @@ -92,10 +95,14 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x "vmovaps\t\t(%%r12),\t%%zmm2\n\t" "vfmadd231ps\t%%zmm1,\t%%zmm2,\t%%zmm0\n\t" // Perform a fused multiply add "je\t2f\n\t" // Jump to label 2 if one (end of loop) - // No compare. we must be two. + "cmp\t$2,\t%%r8\n\t" // Compare iterations to two "vmovaps\t\t64(%%r10),\t%%zmm3\n\t" // Load two vectors. "vmovaps\t\t64(%%r12),\t%%zmm4\n\t" "vfmadd231ps\t%%zmm3,\t%%zmm4,\t%%zmm0\n\t" // Perform a fused multiply add + // No compare. we must be three. + "vmovaps\t\t64(%%r10),\t%%zmm5\n\t" // Load two vectors. + "vmovaps\t\t64(%%r12),\t%%zmm6\n\t" + "vfmadd231ps\t%%zmm5,\t%%zmm6,\t%%zmm0\n\t" // Perform a fused multiply add "2:\n\t" // Label for loop end "vmovnraps\t\t%%zmm0,\t(%[RES])\n\t" // Save our results. : [RES] "+r" (sumvec) @@ -104,7 +111,7 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x [VEC2] "r" (mvec2), [CLR] "r" (clear), [Z] "m" (zero) - : "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "cc", "memory", "r8", "r10", "r12"); + : "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "cc", "memory", "r8", "r10", "r12"); } // Multiply each item in mvec1 with the corresponding item in mvec2, adding the result to the corresponding item in sum. uses masks to handle just the last run-through.