Do 2 rounds of 4 instead of 4 rounds of 2, and properly offset unaligned reads across a 64-byte boundary.
This commit is contained in:
parent
7925fb1f64
commit
bd22e9d28a
1 changed files with 61 additions and 4 deletions
|
@ -207,12 +207,30 @@ void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint
|
||||||
"vprefetchenta\t(%[DST])\n\t"
|
"vprefetchenta\t(%[DST])\n\t"
|
||||||
"mov\t%[DST],\t%%r8\n\t" // Load the address of the head of our destination list.
|
"mov\t%[DST],\t%%r8\n\t" // Load the address of the head of our destination list.
|
||||||
"mov\t%[SRC4],\t%%r9\n\t" // Load the address of the head of our 4-bit list into r9, for vloadunpackld.
|
"mov\t%[SRC4],\t%%r9\n\t" // Load the address of the head of our 4-bit list into r9, for vloadunpackld.
|
||||||
"mov\t%[SRC4],\t%%r10\n\t" // Load the address of the head of our 4-bit list into r10-r11, for vloadunpackhd.
|
"mov\t%[SRC4],\t%%r10\n\t" // Load the address of the head of our 4-bit list into r10-r13, for vloadunpackhd.
|
||||||
"mov\t%[SRC4],\t%%r11\n\t"
|
"mov\t%[SRC4],\t%%r11\n\t"
|
||||||
"mov\t%[SRC4],\t%%r12\n\t"
|
"mov\t%[SRC4],\t%%r12\n\t"
|
||||||
"mov\t%[SRC4],\t%%r13\n\t"
|
"mov\t%[SRC4],\t%%r13\n\t"
|
||||||
"mov\t%[OFFSET],\t%%r14\n\t"
|
"mov\t%[OFFSET],\t%%r14\n\t"
|
||||||
"mov\t$0,%%ecx\n\t" // Initialize our counter.
|
"mov\t$0,%%ecx\n\t" // Initialize our counter.
|
||||||
|
"cmp\t$32,%%r14\n\t" // Examine OFFSET, and decide which (if any) of the vloadunpackhd invocations needs to be increased by 64.
|
||||||
|
"jl\t20f\n\t"
|
||||||
|
"cmp\t$48,%%r14\n\t"
|
||||||
|
"jl\t21f\n\t"
|
||||||
|
"add\t$64,%%r10\n\t" // Greater than 47.
|
||||||
|
"jmp\t24f\n\t"
|
||||||
|
"21:\n\t"
|
||||||
|
"add\t$64,%%r11\n\t" // Between 48 and 31.
|
||||||
|
"jmp\t24f\n\t"
|
||||||
|
"20:\n\t" // Less than 32...
|
||||||
|
"cmp\t$16,%%r14\n\t"
|
||||||
|
"jz\t24f\n\t" // Zero.
|
||||||
|
"jl\t23f\n\t"
|
||||||
|
"add\t$64,%%r12\n\t" // Between 32 and 15.
|
||||||
|
"jmp\t24f\n\t"
|
||||||
|
"23:\n\t"
|
||||||
|
"add\t$64,%%r13\n\t" // Between 16 and zero.
|
||||||
|
"24:\n\t"
|
||||||
"vpbroadcastd\t%[MASK]%{uint8%},\t%%zmm0\n\t" // Load our mask.
|
"vpbroadcastd\t%[MASK]%{uint8%},\t%%zmm0\n\t" // Load our mask.
|
||||||
"vpbroadcastd\t%[BIT5]%{uint8},\t%%zmm1\n\t" // Load the bit we want to add (conditionally).
|
"vpbroadcastd\t%[BIT5]%{uint8},\t%%zmm1\n\t" // Load the bit we want to add (conditionally).
|
||||||
"vpbroadcastd\t%[M]%{uint8%},\t%%zmm2\n\t" // Select which bit we want to test for. Start with bit 1.
|
"vpbroadcastd\t%[M]%{uint8%},\t%%zmm2\n\t" // Select which bit we want to test for. Start with bit 1.
|
||||||
|
@ -239,8 +257,6 @@ void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint
|
||||||
"vmovdqa32\t\t%%zmm8%{uint8%},\t16(%%r8)\n\t" // Save our result.
|
"vmovdqa32\t\t%%zmm8%{uint8%},\t16(%%r8)\n\t" // Save our result.
|
||||||
|
|
||||||
"add\t$32,\t%%r8\n\t"
|
"add\t$32,\t%%r8\n\t"
|
||||||
"cmp\t$4,\t%%ecx\n\t"
|
|
||||||
|
|
||||||
"vpslld\t$1,\t%%zmm2,\t%%zmm2\n\t" // Select the next bit to test for.
|
"vpslld\t$1,\t%%zmm2,\t%%zmm2\n\t" // Select the next bit to test for.
|
||||||
|
|
||||||
"vptestmd\t%%zmm3,\t%%zmm2,\t%%k1\n\t" // Test to see if our selected bit is set.
|
"vptestmd\t%%zmm3,\t%%zmm2,\t%%k1\n\t" // Test to see if our selected bit is set.
|
||||||
|
@ -254,6 +270,47 @@ void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint
|
||||||
"vmovdqa32\t\t%%zmm8%{uint8%},\t16(%%r8)\n\t" // Save our result.
|
"vmovdqa32\t\t%%zmm8%{uint8%},\t16(%%r8)\n\t" // Save our result.
|
||||||
"vprefetchenta\t32(%%r8)\n\t"
|
"vprefetchenta\t32(%%r8)\n\t"
|
||||||
|
|
||||||
|
"vprefetch0\t32(%%r9)\n\t"
|
||||||
|
"vprefetch1\t96(%%r9)\n\t"
|
||||||
|
"add\t$32,\t%%r8\n\t"
|
||||||
|
"add\t$32,\t%%r9\n\t"
|
||||||
|
"add\t$32,\t%%r10\n\t"
|
||||||
|
"add\t$32,\t%%r11\n\t"
|
||||||
|
"add\t$32,\t%%r12\n\t"
|
||||||
|
"add\t$32,\t%%r13\n\t"
|
||||||
|
"vpslld\t$1,\t%%zmm2,\t%%zmm2\n\t" // Select the next bit to test for.
|
||||||
|
|
||||||
|
"vptestmd\t%%zmm3,\t%%zmm2,\t%%k1\n\t" // Test to see if our selected bit is set.
|
||||||
|
"vptestmd\t%%zmm4,\t%%zmm2,\t%%k2\n\t" // Test to see if our selected bit is set.
|
||||||
|
|
||||||
|
"vloadunpackld\t\t(%%r9)%{uint8%},\t%%zmm9\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||||
|
"vloadunpackhd\t\t(%%r12)%{uint8%},\t%%zmm9\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||||
|
"vpandd\t%%zmm0,\t%%zmm9,\t%%zmm10\n\t" // Apply a mask, storing the first set of four bits into a vector.
|
||||||
|
"vpord\t%%zmm1,%%zmm10,%%zmm10%{%%k1%}\n\t" // Turn on bit 5 for all values that passed the prior test.
|
||||||
|
"vmovdqa32\t\t%%zmm10%{uint8%},\t(%%r8)\n\t" // Save our result.
|
||||||
|
|
||||||
|
"vloadunpackld\t\t16(%%r9)%{uint8%},\t%%zmm11\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||||
|
"vloadunpackhd\t\t16(%%r13)%{uint8%},\t%%zmm11\n\t" // Load our odd 4 bit sequences. note that it loads two 4 bit sequences into each zmm value.
|
||||||
|
"vprefetch1\t32(%%r9)\n\t" // Pull the next set of 4 bit sequences into the L2 cache.
|
||||||
|
"vpandd\t%%zmm0,\t%%zmm11,\t%%zmm12\n\t" // Apply a mask, storing the next set of four bits into a vector.
|
||||||
|
"vpord\t%%zmm1,%%zmm12,%%zmm12%{%%k2%}\n\t" // Turn on bit 5 for all values that passed the prior test.
|
||||||
|
"vmovdqa32\t\t%%zmm12%{uint8%},\t16(%%r8)\n\t" // Save our result.
|
||||||
|
|
||||||
|
"add\t$32,\t%%r8\n\t"
|
||||||
|
"cmp\t$2,\t%%ecx\n\t"
|
||||||
|
"vpslld\t$1,\t%%zmm2,\t%%zmm2\n\t" // Select the next bit to test for.
|
||||||
|
|
||||||
|
"vptestmd\t%%zmm3,\t%%zmm2,\t%%k1\n\t" // Test to see if our selected bit is set.
|
||||||
|
"vptestmd\t%%zmm4,\t%%zmm2,\t%%k2\n\t" // Test to see if our selected bit is set.
|
||||||
|
|
||||||
|
"vpsrld\t$4,\t%%zmm9,\t%%zmm10\n\t" // Load our even 4 bit sequence.
|
||||||
|
"vpsrld\t$4,\t%%zmm11,\t%%zmm12\n\t" // Load our next even 4 bit sequence.
|
||||||
|
"vpord\t%%zmm1,%%zmm10,%%zmm10%{%%k1%}\n\t" // Turn on bit 5 for all values that passed the prior test.
|
||||||
|
"vpord\t%%zmm1,%%zmm12,%%zmm12%{%%k2%}\n\t" // Turn on bit 5 for all values that passed the prior test.
|
||||||
|
"vmovdqa32\t\t%%zmm10%{uint8%},\t(%%r8)\n\t" // Save our result.
|
||||||
|
"vmovdqa32\t\t%%zmm12%{uint8%},\t16(%%r8)\n\t" // Save our result.
|
||||||
|
"vprefetchenta\t32(%%r8)\n\t"
|
||||||
|
|
||||||
"je\t2f\n\t"
|
"je\t2f\n\t"
|
||||||
|
|
||||||
"vprefetch0\t32(%%r9)\n\t"
|
"vprefetch0\t32(%%r9)\n\t"
|
||||||
|
@ -274,7 +331,7 @@ void GGML_5bit_Unpack_Unaligned (const uint8x16_t * q4, const uint8_t * q1, uint
|
||||||
[MASK] "m" (lowmask),
|
[MASK] "m" (lowmask),
|
||||||
[M] "m" (m),
|
[M] "m" (m),
|
||||||
[BIT5] "m" (bit5)
|
[BIT5] "m" (bit5)
|
||||||
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "cc", "ecx", "k1", "k2", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "memory");
|
: "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "cc", "ecx", "k1", "k2", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "memory");
|
||||||
}
|
}
|
||||||
|
|
||||||
// A function for getting the dot product of two vectors, one of 5 bit resolution, and one of 8.
|
// A function for getting the dot product of two vectors, one of 5 bit resolution, and one of 8.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue