[ggml-aarch64] impl the same logic as the ASM version in q4_0_4_4 gemm/gemv

Shupei Fan 2024-11-10 21:33:40 +08:00
parent 32e0862a7e
commit c7a54d1f2b


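For reference, the arithmetic that both the ASM and the intrinsics paths compute can be sketched in scalar C. The sketch below is a simplified, hypothetical reference for a single q4 block dotted with a single q8 block, ignoring the x4 column interleaving; the function name is invented here. It assumes each packed byte holds two signed (two's complement) nibbles, which is what the "& 0xF0" / "<< 4" unpacking in the kernels implies: both extractions leave the 4-bit value scaled by 16 inside a signed byte, and the kernels' vcvtq_n_f32_s32(sumi, 4) divides by 2^4 at the end to undo that scale.

#include <stdint.h>

// Scalar sketch (not the committed kernel): one 32-element q4 block dotted
// with one 32-element q8 block, using per-block scales d4 and d8.
static inline float dot_q4_q8_block(const uint8_t *q4,  // 16 bytes = 32 packed nibbles
                                    const int8_t  *q8,  // 32 int8 activations
                                    float d4, float d8) {
    int32_t sumi = 0;
    for (int i = 0; i < 16; i++) {
        // (int8_t)(q4[i] << 4)   == 16 * signed low nibble
        // (int8_t)(q4[i] & 0xF0) == 16 * signed high nibble
        int8_t lo = (int8_t)(q4[i] << 4);
        int8_t hi = (int8_t)(q4[i] & 0xF0);
        sumi += lo * q8[i] + hi * q8[i + 16];
    }
    // The NEON kernels' vcvtq_n_f32_s32(sumi, 4) performs this division by 16.
    return d4 * d8 * (float)sumi / 16.0f;
}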
@@ -667,14 +667,31 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
float * res_ptr = s;
for (int x = 0; x < nc / ncols_interleaved; x++) {
// %x[nc] : loop control
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
float32x4_t sumf = vdupq_n_f32(0);
// v29 = sumf
for (int l = 0; l < nb; l++) {
// x21 : loop control
// x22 = a_ptr[l].qs
// %x[b_ptr] = b_ptr[l].qs
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
// (v27, v25) = (a_0, a_1)
uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
// (v28, v24, v23, v22) = (b_0, b_1, b_2, b_3)
float16x4_t b_d_half = vld1_f16((const float16_t *)b_ptr[l].d);
// v20 = b_d_half
int8x16_t b_0_hi = vreinterpretq_s8_u8(b_0 & 0xF0);
int8x16_t b_0_lo = vreinterpretq_s8_u8(b_0 << 4);
@@ -684,11 +701,13 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
int8x16_t b_2_lo = vreinterpretq_s8_u8(b_2 << 4);
int8x16_t b_3_hi = vreinterpretq_s8_u8(b_3 & 0xF0);
int8x16_t b_3_lo = vreinterpretq_s8_u8(b_3 << 4);
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
// (v16, v28) = (b_0_lo, b_0_hi)
// (v19, v24) = (b_1_lo, b_1_hi)
// (v18, v23) = (b_2_lo, b_2_hi)
// (v17, v22) = (b_3_lo, b_3_hi)
int32x4_t sumi = vdupq_n_s32(0);
// v26 = sumi
sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
@@ -697,15 +716,21 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
// v21 = a_d
float32x4_t b_d = vcvt_f32_f16(b_d_half);
// v16 = b_d
float32x4_t d = a_d * b_d;
// v16 = d
sumf = vmlaq_f32(sumf, d, vcvtq_n_f32_s32(sumi, 4));
}
vst1q_f32(res_ptr + x * 4, sumf);
// %x[res_ptr] = res_ptr + x * 4
}
return;
}
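The lane-indexed dot products above are what map four interleaved columns onto a single instruction. A scalar model of vdotq_laneq_s32(acc, b, a, lane) as it is used in these kernels (the helper name below is invented; the semantics are the ARM intrinsic's): the 16 bytes of b are treated as four groups of 4, each group is dotted with the 4 activation bytes selected by lane from a, and the four sums are accumulated into the four int32 lanes of acc.

#include <stdint.h>

// Scalar model of vdotq_laneq_s32(acc, b, a, lane) as used above:
// b holds 4 interleaved columns x 4 bytes, a holds 4 groups of 4 activation
// bytes, and lane selects which activation group all 4 columns are dotted with.
static void sdot_laneq_model(int32_t acc[4], const int8_t b[16], const int8_t a[16], int lane) {
    for (int col = 0; col < 4; col++) {
        int32_t d = 0;
        for (int j = 0; j < 4; j++) {
            d += (int32_t) b[4 * col + j] * (int32_t) a[4 * lane + j];
        }
        acc[col] += d;
    }
}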
@@ -1174,7 +1199,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
float32x4_t d = a_d * b_d;
@@ -1236,7 +1261,97 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
if (ggml_cpu_has_neon()) {
for (int y = 0; y < nr / 4; y++) {
#define UNROLL_FACTOR 4
int y = 0;
for (; y + UNROLL_FACTOR <= nr / 4; y += UNROLL_FACTOR) {
const block_q8_0x4 * a_ptr[UNROLL_FACTOR];
for (int z = 0; z < UNROLL_FACTOR; z++) {
a_ptr[z] = (const block_q8_0x4 *) vy + ((y + z) * nb);
}
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
float32x4_t sumf[UNROLL_FACTOR][4];
for (int z = 0; z < UNROLL_FACTOR; z++) {
for (int m = 0; m < 4; m++) {
sumf[z][m] = vdupq_n_f32(0);
}
}
// (v15, v19, v18, v14) = sumf[0][0, 1, 2, 3]
// (v11, v13, v23, v16) = sumf[1][0, 1, 2, 3]
// (v27, v7, v0, v4 ) = sumf[2][0, 1, 2, 3]
// (v5, v21, v8, v1 ) = sumf[3][0, 1, 2, 3]
for (int l = 0; l < nb; l++) {
// x24 : loop control
// x28 = b_ptr[l].qs
// (x25, x23, x22, x21) = a_ptr[0, 1, 2, 3][l].qs
int8x16_t b_hi[4], b_lo[4];
for (int k = 0; k < 4; k++) {
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
b_hi[k] = vreinterpretq_s8_u8(b & 0xF0);
b_lo[k] = vreinterpretq_s8_u8(b << 4);
}
// (v12, v3) = (b_lo[0], b_hi[0])
// (v31, v22) = (b_lo[1], b_hi[1])
// (v6, v27) = (b_lo[2], b_hi[2])
// (v28, v30) = (b_lo[3], b_hi[3])
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
// v17 = b_d
// this z loop is fully unrolled in the ASM version
for (int z = 0; z < UNROLL_FACTOR; z++) {
int32x4_t sumi[4];
for (int m = 0; m < 4; m++) {
sumi[m] = vdupq_n_s32(0);
}
// (v10, v29, v9, v20) = sumi[0, 1, 2, 3] (z = 0)
// (v9, v29, v20, v2) = sumi[0, 1, 2, 3] (z = 1)
// (v20, v10, v26, v2) = sumi[0, 1, 2, 3] (z = 2)
// (v26, v10, v2, v19) = sumi[0, 1, 2, 3] (z = 3)
for (int k = 0; k < 4; k++) {
int8x16_t a0 = vld1q_s8(a_ptr[z][l].qs + 16 * k + 0);
sumi[0] = vdotq_laneq_s32(sumi[0], b_lo[k], a0, 0);
sumi[1] = vdotq_laneq_s32(sumi[1], b_lo[k], a0, 1);
sumi[2] = vdotq_laneq_s32(sumi[2], b_lo[k], a0, 2);
sumi[3] = vdotq_laneq_s32(sumi[3], b_lo[k], a0, 3);
}
for (int k = 0; k < 4; k++) {
int8x16_t a1 = vld1q_s8(a_ptr[z][l].qs + 16 * k + 64);
sumi[0] = vdotq_laneq_s32(sumi[0], b_hi[k], a1, 0);
sumi[1] = vdotq_laneq_s32(sumi[1], b_hi[k], a1, 1);
sumi[2] = vdotq_laneq_s32(sumi[2], b_hi[k], a1, 2);
sumi[3] = vdotq_laneq_s32(sumi[3], b_hi[k], a1, 3);
}
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[z][l].d));
// (v2, v26, v29, v20) = a_d (z = 0, 1, 2, 3)
sumf[z][0] = vmlaq_f32(sumf[z][0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_n_f32_s32(sumi[0], 4));
sumf[z][1] = vmlaq_f32(sumf[z][1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_n_f32_s32(sumi[1], 4));
sumf[z][2] = vmlaq_f32(sumf[z][2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_n_f32_s32(sumi[2], 4));
sumf[z][3] = vmlaq_f32(sumf[z][3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_n_f32_s32(sumi[3], 4));
}
}
for (int z = 0; z < UNROLL_FACTOR; z++) {
for (int m = 0; m < 4; m++) {
vst1q_f32(s + ((y + z) * 4 + m) * bs + x * 4, sumf[z][m]);
}
}
}
}
#undef UNROLL_FACTOR
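// Tail loop: handles the row groups that remain when nr / 4 is not a multiple of UNROLL_FACTOR.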
for (; y < nr / 4; y++) {
// x10 : loop control
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
@@ -1245,32 +1360,68 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
for (int m = 0; m < 4; m++) {
sumf[m] = vdupq_n_f32(0);
}
// (v15, v19, v18, v14) = sumf[0, 1, 2, 3]
for (int l = 0; l < nb; l++) {
// x21 : loop control
// x25 = a_ptr[l].qs
// x24 = b_ptr[l].qs
int8x16_t a_0[4], a_1[4];
a_0[0] = vld1q_s8(a_ptr[l].qs + 0);
a_0[1] = vld1q_s8(a_ptr[l].qs + 16);
a_0[2] = vld1q_s8(a_ptr[l].qs + 32);
a_0[3] = vld1q_s8(a_ptr[l].qs + 48);
a_1[0] = vld1q_s8(a_ptr[l].qs + 64);
a_1[1] = vld1q_s8(a_ptr[l].qs + 80);
a_1[2] = vld1q_s8(a_ptr[l].qs + 96);
a_1[3] = vld1q_s8(a_ptr[l].qs + 112);
// (v5, v26) = (a_0[0], a_1[0])
// (v2, v25) = (a_0[1], a_1[1])
// (v31, v24) = (a_0[2], a_1[2])
// (v27, v16) = (a_0[3], a_1[3])
uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
// (v7, v3, v13, v28) = (b_0, b_1, b_2, b_3)
int8x16_t b_lo[4], b_hi[4];
b_hi[0] = vreinterpretq_s8_u8(b_0 & 0xF0);
b_lo[0] = vreinterpretq_s8_u8(b_0 << 4);
b_hi[1] = vreinterpretq_s8_u8(b_1 & 0xF0);
b_lo[1] = vreinterpretq_s8_u8(b_1 << 4);
b_hi[2] = vreinterpretq_s8_u8(b_2 & 0xF0);
b_lo[2] = vreinterpretq_s8_u8(b_2 << 4);
b_hi[3] = vreinterpretq_s8_u8(b_3 & 0xF0);
b_lo[3] = vreinterpretq_s8_u8(b_3 << 4);
// (v20, v7) = (b_lo[0], b_hi[0])
// (v17, v3) = (b_lo[1], b_hi[1])
// (v22, v13) = (b_lo[2], b_hi[2])
// (v9, v28) = (b_lo[3], b_hi[3])
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
// v12 = a_d
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
// v21 = b_d
int32x4_t sumi_0 = vdupq_n_s32(0);
int32x4_t sumi_1 = vdupq_n_s32(0);
int32x4_t sumi_2 = vdupq_n_s32(0);
int32x4_t sumi_3 = vdupq_n_s32(0);
// (v4, v1, v0, v30) = (sumi_0, sumi_1, sumi_2, sumi_3)
for (int k = 0; k < 4; k++) {
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
int8x16_t b_hi = vreinterpretq_s8_u8(b & 0xF0);
int8x16_t b_lo = vreinterpretq_s8_u8(b << 4);
sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
sumi_0 = vdotq_laneq_s32(sumi_0, b_lo[k], a_0[k], 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo[k], a_0[k], 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo[k], a_0[k], 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo[k], a_0[k], 3);
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi[k], a_1[k], 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi[k], a_1[k], 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi[k], a_1[k], 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi[k], a_1[k], 3);
}
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_n_f32_s32(sumi_0, 4));
@@ -1279,6 +1430,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_n_f32_s32(sumi_3, 4));
}
// NOTE: the asm version has additional code to handle the case where `nr` is not a multiple of 4
for (int m = 0; m < 4; m++) {
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
}
@@ -3230,7 +3382,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
for (int m = 0; m < 4; m++) {
sumf[m] = vdupq_n_f32(0);
}
for (int l = 0; l < nb; l++) {
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
@@ -3244,7 +3396,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
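Unlike the q4_0 kernels above, the iq4_nl path cannot use the shift-and-mask unpacking, because its 4-bit codes index a non-linear codebook; the vqtbl1q_s8(kvalues, ...) calls are a 16-entry table lookup per nibble. A scalar sketch of that mapping, with kvalues standing for the iq4_nl codebook (defined elsewhere in ggml) and the helper name invented here:

#include <stdint.h>

// Scalar model of the nibble -> codebook mapping done by vqtbl1q_s8 above.
static inline void unpack_iq4_nl_model(int8_t out_lo[16], int8_t out_hi[16],
                                       const uint8_t q[16], const int8_t kvalues[16]) {
    for (int i = 0; i < 16; i++) {
        out_lo[i] = kvalues[q[i] & 0x0F];  // low nibble selects a codebook entry
        out_hi[i] = kvalues[q[i] >> 4];    // high nibble selects a codebook entry
    }
}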