change from handling three iterations per loop to four.
This commit is contained in:
parent
a82ada7dcd
commit
3156e639bf
1 changed files with 17 additions and 10 deletions
|
@ -31,7 +31,7 @@ inline static void GGML_F32x16_VEC_ZERO(float32x16_t *target)
|
||||||
: "zmm8", "memory");
|
: "zmm8", "memory");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Multiply each item in mvec1 with the corresponding item in mvec2, adding the result to the corresponding item in sum. optionally clear the sum before starting.
|
// Multiply each item in mvec1 with the corresponding item in mvec2, adding the result to the corresponding item in sum. Optionally clear the sum before starting.
|
||||||
inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x16_t *mvec2, float32x16_t *sumvec, size_t iterations, int clear)
|
inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x16_t *mvec2, float32x16_t *sumvec, size_t iterations, int clear)
|
||||||
{
|
{
|
||||||
uint8_t zero = 0;
|
uint8_t zero = 0;
|
||||||
|
@ -59,8 +59,8 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x
|
||||||
"vprefetchnta\t(%%r10)\n\t"
|
"vprefetchnta\t(%%r10)\n\t"
|
||||||
"vprefetchnta\t(%%r12)\n\t"
|
"vprefetchnta\t(%%r12)\n\t"
|
||||||
"5:\n\t"
|
"5:\n\t"
|
||||||
"cmp\t$3,\t%%r8\n\t" // Compare iterations to three.
|
"cmp\t$4,\t%%r8\n\t" // Compare iterations to four.
|
||||||
"jnae\t6f\n\t" // If there are not three iterations left, jump to label 6.
|
"jnae\t6f\n\t" // If there are not four iterations left, jump to label 6.
|
||||||
"1:\n\t"
|
"1:\n\t"
|
||||||
"sub\t$3,\t%%r8\n\t" // Decrement iterations
|
"sub\t$3,\t%%r8\n\t" // Decrement iterations
|
||||||
"vmovaps\t\t(%%r10),\t%%zmm1\n\t" // Load two vectors.
|
"vmovaps\t\t(%%r10),\t%%zmm1\n\t" // Load two vectors.
|
||||||
|
@ -79,12 +79,15 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x
|
||||||
"vprefetch1\t576(%%r12)\n\t"
|
"vprefetch1\t576(%%r12)\n\t"
|
||||||
"vprefetch1\t704(%%r10)\n\t"
|
"vprefetch1\t704(%%r10)\n\t"
|
||||||
"vprefetch1\t704(%%r12)\n\t"
|
"vprefetch1\t704(%%r12)\n\t"
|
||||||
"add\t$192,\t%%r10\n\t" // Move to the next float32x16_t block (192 bytes ahead)
|
|
||||||
"add\t$192,\t%%r12\n\t"
|
|
||||||
"vfmadd231ps\t%%zmm5,\t%%zmm6,\t%%zmm0\n\t" // Perform a fused multiply add
|
"vfmadd231ps\t%%zmm5,\t%%zmm6,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
"cmp\t$3,\t%%r8\n\t" // Compare iteration count to three.
|
"vmovaps\t\t192(%%r10),\t%%zmm7\n\t" // Load two vectors.
|
||||||
"jge\t1b\n\t" // If there three or more iterations left, loop.
|
"vmovaps\t\t192(%%r12),\t%%zmm8\n\t"
|
||||||
"6:\n\t" // We know we are near the tail. handle 2, 1, and 0 cases.
|
"vfmadd231ps\t%%zmm7,\t%%zmm8,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
|
"add\t$256,\t%%r10\n\t" // Move to the next 4xfloat32x16_t block (256 bytes ahead)
|
||||||
|
"add\t$256,\t%%r12\n\t"
|
||||||
|
"cmp\t$4,\t%%r8\n\t" // Compare iteration count to four.
|
||||||
|
"jge\t1b\n\t" // If there are four or more iterations left, loop.
|
||||||
|
"6:\n\t" // We know we are near the tail. handle 3, 2, 1, and 0 cases.
|
||||||
"cmp\t$0,\t%%r8\n\t" // Compare iterations to zero
|
"cmp\t$0,\t%%r8\n\t" // Compare iterations to zero
|
||||||
"jz\t2f\n\t" // Jump to label 2 if zero (end of loop)
|
"jz\t2f\n\t" // Jump to label 2 if zero (end of loop)
|
||||||
"cmp\t$1,\t%%r8\n\t" // Compare iterations to one
|
"cmp\t$1,\t%%r8\n\t" // Compare iterations to one
|
||||||
|
@ -92,10 +95,14 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x
|
||||||
"vmovaps\t\t(%%r12),\t%%zmm2\n\t"
|
"vmovaps\t\t(%%r12),\t%%zmm2\n\t"
|
||||||
"vfmadd231ps\t%%zmm1,\t%%zmm2,\t%%zmm0\n\t" // Perform a fused multiply add
|
"vfmadd231ps\t%%zmm1,\t%%zmm2,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
"je\t2f\n\t" // Jump to label 2 if one (end of loop)
|
"je\t2f\n\t" // Jump to label 2 if one (end of loop)
|
||||||
// No compare. we must be two.
|
"cmp\t$2,\t%%r8\n\t" // Compare iterations to two
|
||||||
"vmovaps\t\t64(%%r10),\t%%zmm3\n\t" // Load two vectors.
|
"vmovaps\t\t64(%%r10),\t%%zmm3\n\t" // Load two vectors.
|
||||||
"vmovaps\t\t64(%%r12),\t%%zmm4\n\t"
|
"vmovaps\t\t64(%%r12),\t%%zmm4\n\t"
|
||||||
"vfmadd231ps\t%%zmm3,\t%%zmm4,\t%%zmm0\n\t" // Perform a fused multiply add
|
"vfmadd231ps\t%%zmm3,\t%%zmm4,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
|
// No compare. we must be three.
|
||||||
|
"vmovaps\t\t64(%%r10),\t%%zmm5\n\t" // Load two vectors.
|
||||||
|
"vmovaps\t\t64(%%r12),\t%%zmm6\n\t"
|
||||||
|
"vfmadd231ps\t%%zmm5,\t%%zmm6,\t%%zmm0\n\t" // Perform a fused multiply add
|
||||||
"2:\n\t" // Label for loop end
|
"2:\n\t" // Label for loop end
|
||||||
"vmovnraps\t\t%%zmm0,\t(%[RES])\n\t" // Save our results.
|
"vmovnraps\t\t%%zmm0,\t(%[RES])\n\t" // Save our results.
|
||||||
: [RES] "+r" (sumvec)
|
: [RES] "+r" (sumvec)
|
||||||
|
@ -104,7 +111,7 @@ inline static void GGML_F32x16_VEC_FMA(const float32x16_t *mvec1, const float32x
|
||||||
[VEC2] "r" (mvec2),
|
[VEC2] "r" (mvec2),
|
||||||
[CLR] "r" (clear),
|
[CLR] "r" (clear),
|
||||||
[Z] "m" (zero)
|
[Z] "m" (zero)
|
||||||
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "cc", "memory", "r8", "r10", "r12");
|
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "cc", "memory", "r8", "r10", "r12");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Multiply each item in mvec1 with the corresponding item in mvec2, adding the result to the corresponding item in sum. uses masks to handle just the last run-through.
|
// Multiply each item in mvec1 with the corresponding item in mvec2, adding the result to the corresponding item in sum. uses masks to handle just the last run-through.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue