[ggml-aarch64] use intrinsics in q4_0_4_4 gemv
This commit is contained in:
parent
a9e8a9a030
commit
102299e30c
1 changed files with 42 additions and 54 deletions
|
@ -620,62 +620,50 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
|
|||
|
||||
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||
if (ggml_cpu_has_neon()) {
|
||||
const void * b_ptr = vx;
|
||||
const void * a_ptr = vy;
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
float * res_ptr = s;
|
||||
|
||||
__asm__ __volatile__(
|
||||
"movi v31.16b, #0x4\n"
|
||||
"movi v30.16b, #0xf0\n"
|
||||
"add %x[b_ptr], %x[b_ptr], #0x8\n"
|
||||
"1:" // Column loop
|
||||
"add x22, %x[a_ptr], #0x2\n"
|
||||
"movi v29.16b, #0x0\n"
|
||||
"mov x21, %x[nb]\n"
|
||||
"2:" // Block loop
|
||||
"ldr q28, [%x[b_ptr], #0x0]\n"
|
||||
"ldr q27, [x22, #0x0]\n"
|
||||
"movi v26.4s, #0x0\n"
|
||||
"sub x20, x22, #0x2\n"
|
||||
"ldr q25, [x22, #0x10]\n"
|
||||
"ldr q24, [%x[b_ptr], #0x10]\n"
|
||||
"sub x21, x21, #0x1\n"
|
||||
"add x22, x22, #0x22\n"
|
||||
"ldr q23, [%x[b_ptr], #0x20]\n"
|
||||
"ldr q22, [%x[b_ptr], #0x30]\n"
|
||||
"ld1r { v21.8h }, [x20]\n"
|
||||
"ldr q20, [%x[b_ptr], #-0x8]\n"
|
||||
"sshl v16.16b, v28.16b, v31.16b\n"
|
||||
"and v28.16b, v28.16b, v30.16b\n"
|
||||
"sshl v19.16b, v24.16b, v31.16b\n"
|
||||
"and v24.16b, v24.16b, v30.16b\n"
|
||||
"add %x[b_ptr], %x[b_ptr], #0x48\n"
|
||||
"sshl v18.16b, v23.16b, v31.16b\n"
|
||||
"and v23.16b, v23.16b, v30.16b\n"
|
||||
".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
|
||||
"sshl v17.16b, v22.16b, v31.16b\n"
|
||||
"and v22.16b, v22.16b, v30.16b\n"
|
||||
"fcvtl v21.4s, v21.4h\n"
|
||||
"fcvtl v16.4s, v20.4h\n"
|
||||
".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
|
||||
"fmul v16.4s, v16.4s, v21.4s\n"
|
||||
".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
|
||||
".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
|
||||
".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
|
||||
".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
|
||||
".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
|
||||
".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
|
||||
"scvtf v26.4s, v26.4s, #0x4\n"
|
||||
"fmla v29.4s, v26.4s, v16.4s\n"
|
||||
"cbnz x21, 2b\n"
|
||||
"sub %x[nc], %x[nc], #0x4\n"
|
||||
"str q29, [%x[res_ptr], #0x0]\n"
|
||||
"add %x[res_ptr], %x[res_ptr], #0x10\n"
|
||||
"cbnz %x[nc], 1b\n"
|
||||
: [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
|
||||
: [a_ptr] "r" (a_ptr), [nb] "r" (nb)
|
||||
: "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
|
||||
);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
|
||||
|
||||
float32x4_t sumf = vdupq_n_f32(0);
|
||||
for (int l = 0; l < nb; l++) {
|
||||
uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
|
||||
uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
|
||||
uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
|
||||
uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
|
||||
|
||||
int8x16_t b_0_hi = vreinterpretq_s8_u8(b_0 & 0xF0);
|
||||
int8x16_t b_0_lo = vreinterpretq_s8_u8(b_0 << 4);
|
||||
int8x16_t b_1_hi = vreinterpretq_s8_u8(b_1 & 0xF0);
|
||||
int8x16_t b_1_lo = vreinterpretq_s8_u8(b_1 << 4);
|
||||
int8x16_t b_2_hi = vreinterpretq_s8_u8(b_2 & 0xF0);
|
||||
int8x16_t b_2_lo = vreinterpretq_s8_u8(b_2 << 4);
|
||||
int8x16_t b_3_hi = vreinterpretq_s8_u8(b_3 & 0xF0);
|
||||
int8x16_t b_3_lo = vreinterpretq_s8_u8(b_3 << 4);
|
||||
|
||||
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
|
||||
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
|
||||
|
||||
int32x4_t sumi = vdupq_n_s32(0);
|
||||
sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
|
||||
sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
|
||||
sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
|
||||
sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
|
||||
sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
|
||||
sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
|
||||
sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
|
||||
sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
|
||||
|
||||
float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
|
||||
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
|
||||
float32x4_t d = a_d * b_d;
|
||||
|
||||
sumf = vmlaq_f32(sumf, d, vcvtq_n_f32_s32(sumi, 4));
|
||||
}
|
||||
|
||||
vst1q_f32(res_ptr + x * 4, sumf);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue