diff --git a/ggml-phi-knc-dot_q5_K_q8_K.c b/ggml-phi-knc-dot_q5_K_q8_K.c index a6072f665..4f2ff837c 100644 --- a/ggml-phi-knc-dot_q5_K_q8_K.c +++ b/ggml-phi-knc-dot_q5_K_q8_K.c @@ -204,8 +204,10 @@ void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint "vprefetch0\t(%[SRC1])\n\t" // Issue our memory requests first thing. "vprefetch0\t(%[SRC4])\n\t" "vprefetchenta\t(%[DST])\n\t" - "mov\t%[SRC4],\t%%r9\n\t" // Load the address of the head of our 4-bit list. "mov\t%[DST],\t%%r8\n\t" // Load the address of the head of our destination list. + "mov\t%[SRC4],\t%%r9\n\t" // Load the address of the head of our 4-bit list into r9, for vloadunpackld. + "mov\t%[SRC4],\t%%r10\n\t" // Load the address of the head of our 4-bit list into r10-r11, for vloadunpackhd. + "mov\t%[SRC4],\t%%r11\n\t" "mov\t$0,%%ecx\n\t" // Initialize our counter. "vpbroadcastd\t%[MASK]%{uint8%},\t%%zmm0\n\t" // Load our mask. "vpbroadcastd\t%[BIT5]%{uint8},\t%%zmm1\n\t" // Load the bit we want to add (conditionally). @@ -220,13 +222,13 @@ void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint "vptestmd\t%%zmm4,\t%%zmm2,\t%%k2\n\t" // Test to see if our selected bit is set. "vloadunpackld\t\t(%%r9)%{uint8%},\t%%zmm5\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value. - "vloadunpackhd\t\t16(%%r9)%{uint8%},\t%%zmm5\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value. + "vloadunpackhd\t\t16(%%r10)%{uint8%},\t%%zmm5\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value. "vpandd\t%%zmm0,\t%%zmm5,\t%%zmm6\n\t" // Apply a mask, storing the first set of four bits into a vector. "vpord\t%%zmm1,%%zmm6,%%zmm6%{%%k1%}\n\t" // Turn on bit 5 for all values that passed the prior test. "vmovdqa32\t\t%%zmm6%{uint8%},\t(%%r8)\n\t" // Save our result. "vloadunpackld\t\t16(%%r9)%{uint8%},\t%%zmm7\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value. - "vloadunpackhd\t\t32(%%r9)%{uint8%},\t%%zmm7\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value. + "vloadunpackhd\t\t32(%%r11)%{uint8%},\t%%zmm7\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value. "vprefetch1\t32(%%r9)\n\t" // Pull the next set of 4 bit sequences into the L2 cache. "vpandd\t%%zmm0,\t%%zmm7,\t%%zmm8\n\t" // Apply a mask, storing the next set of four bits into a vector. "vpord\t%%zmm1,%%zmm8,%%zmm8%{%%k2%}\n\t" // Turn on bit 5 for all values that passed the prior test. @@ -237,8 +239,8 @@ void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint "vpslld\t$1,\t%%zmm2,\t%%zmm2\n\t" // Select the next bit to test for. - "vptestmd\t%%zmm3,\t%%zmm2,\t%%k1\n\t" // Perform our test. - "vptestmd\t%%zmm4,\t%%zmm2,\t%%k2\n\t" // Perform our test. + "vptestmd\t%%zmm3,\t%%zmm2,\t%%k1\n\t" // Test to see if our selected bit is set. + "vptestmd\t%%zmm4,\t%%zmm2,\t%%k2\n\t" // Test to see if our selected bit is set. "vpsrld\t$4,\t%%zmm5,\t%%zmm6\n\t" // Load our even 4 bit sequence. "vpsrld\t$4,\t%%zmm7,\t%%zmm8\n\t" // Load our next even 4 bit sequence. "vpord\t%%zmm1,%%zmm6,%%zmm6%{%%k1%}\n\t" // Turn on bit 5 for all values that passed the prior test. @@ -252,8 +254,10 @@ void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint "vprefetch0\t32(%%r9)\n\t" "vprefetch1\t96(%%r9)\n\t" "vpslld\t$1,\t%%zmm2,\t%%zmm2\n\t" // Select the next bit to test for. - "add\t$32,\t%%r9\n\t" "add\t$32,\t%%r8\n\t" + "add\t$32,\t%%r9\n\t" + "add\t$32,\t%%r10\n\t" + "add\t$32,\t%%r11\n\t" "jmp\t1b\n\t" "2:" : [DST] "+r" (dst) @@ -262,7 +266,7 @@ void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint [MASK] "m" (lowmask), [M] "m" (m), [BIT5] "m" (bit5) - : "zmm0", "zmm1", "zmm2", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "cc", "ecx", "k1", "k2", "r8", "r9", "memory"); + : "zmm0", "zmm1", "zmm2", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "cc", "ecx", "k1", "k2", "r8", "r9", "r10", "r11", "memory"); } // A function for getting the dot product of two vectors, one of 5 bit resolution, and one of 8.