diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c
index 1a0b1da1f..0a8f9b197 100644
--- a/ggml/src/ggml-aarch64.c
+++ b/ggml/src/ggml-aarch64.c
@@ -667,14 +667,31 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
         float * res_ptr = s;
 
         for (int x = 0; x < nc / ncols_interleaved; x++) {
+            // %x[nc] : loop control
+
             const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
 
             float32x4_t sumf = vdupq_n_f32(0);
+            // v29 = sumf
+
             for (int l = 0; l < nb; l++) {
+                // x21 : loop control
+
+                // x22 = a_ptr[l].qs
+                // %x[b_ptr] = b_ptr[l].qs
+
+                int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+                int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+                // (v27, v25) = (a_0, a_1)
+
                 uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
                 uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
                 uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
                 uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+                // (v28, v24, v23, v22) = (b_0, b_1, b_2, b_3)
+
+                float16x4_t b_d_half = vld1_f16((const float16_t *)b_ptr[l].d);
+                // v20 = b_d_half
 
                 int8x16_t b_0_hi = vreinterpretq_s8_u8(b_0 & 0xF0);
                 int8x16_t b_0_lo = vreinterpretq_s8_u8(b_0 << 4);
@@ -684,11 +701,13 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
                 int8x16_t b_2_lo = vreinterpretq_s8_u8(b_2 << 4);
                 int8x16_t b_3_hi = vreinterpretq_s8_u8(b_3 & 0xF0);
                 int8x16_t b_3_lo = vreinterpretq_s8_u8(b_3 << 4);
-
-                int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
-                int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+                // (v16, v28) = (b_0_lo, b_0_hi)
+                // (v19, v24) = (b_1_lo, b_1_hi)
+                // (v18, v23) = (b_2_lo, b_2_hi)
+                // (v17, v22) = (b_3_lo, b_3_hi)
 
                 int32x4_t sumi = vdupq_n_s32(0);
+                // v26 = sumi
                 sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
                 sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
                 sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
@@ -697,15 +716,21 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
                 sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
                 sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
                 sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
-
+
                 float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
-                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                // v21 = a_d
+
+                float32x4_t b_d = vcvt_f32_f16(b_d_half);
+                // v16 = b_d
+
                 float32x4_t d = a_d * b_d;
+                // v16 = d
 
                 sumf = vmlaq_f32(sumf, d, vcvtq_n_f32_s32(sumi, 4));
             }
 
             vst1q_f32(res_ptr + x * 4, sumf);
+            // %x[res_ptr] = res_ptr + x * 4
         }
         return;
     }
@@ -1174,7 +1199,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
                 sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
                 sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
                 sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
-
+
                 float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
                 float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
                 float32x4_t d = a_d * b_d;
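A note on the arithmetic every q4_0 kernel in this patch relies on: `b << 4` re-expresses the low nibble of a packed byte as a sign-extended 4-bit value times 16, and `b & 0xF0` does the same for the high nibble, so the int8 dot products can consume the unpacked bytes directly. The closing vcvtq_n_f32_s32(sumi, 4) is a fixed-point convert with 4 fractional bits: it divides the accumulated sum by 2^4 during the int-to-float conversion, cancelling that factor of 16. A minimal scalar sketch of the identity (plain C, illustrative values only, not part of the patch):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        for (int byte = 0; byte < 256; byte += 0x11) {
            int8_t lo = (int8_t)(byte << 4);   // low nibble, sign-extended, times 16
            int8_t hi = (int8_t)(byte & 0xF0); // high nibble, sign-extended, times 16
            printf("0x%02X -> lo %4d (%3d * 16)   hi %4d (%3d * 16)\n",
                   (unsigned)byte, lo, lo / 16, hi, hi / 16);
        }
        return 0;
    }

Both halves land in [-128, 112] in steps of 16, which is why a single fixed-point convert at the end is enough to undo the scaling.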
@@ -1236,7 +1261,97 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
     if (ggml_cpu_has_neon()) {
-        for (int y = 0; y < nr / 4; y++) {
+#define UNROLL_FACTOR 4
+        int y = 0;
+        for (; y + UNROLL_FACTOR <= nr / 4; y += UNROLL_FACTOR) {
+            const block_q8_0x4 * a_ptr[UNROLL_FACTOR];
+            for (int z = 0; z < UNROLL_FACTOR; z++) {
+                a_ptr[z] = (const block_q8_0x4 *) vy + ((y + z) * nb);
+            }
+
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+                float32x4_t sumf[UNROLL_FACTOR][4];
+                for (int z = 0; z < UNROLL_FACTOR; z++) {
+                    for (int m = 0; m < 4; m++) {
+                        sumf[z][m] = vdupq_n_f32(0);
+                    }
+                }
+                // (v15, v19, v18, v14) = sumf[0][0, 1, 2, 3]
+                // (v11, v13, v23, v16) = sumf[1][0, 1, 2, 3]
+                // (v27, v7,  v0,  v4 ) = sumf[2][0, 1, 2, 3]
+                // (v5,  v21, v8,  v1 ) = sumf[3][0, 1, 2, 3]
+
+                for (int l = 0; l < nb; l++) {
+                    // x24 : loop control
+
+                    // x28 = b_ptr[l].qs
+                    // (x25, x23, x22, x21) = a_ptr[0, 1, 2, 3][l].qs
+
+                    int8x16_t b_hi[4], b_lo[4];
+                    for (int k = 0; k < 4; k++) {
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                        b_hi[k] = vreinterpretq_s8_u8(b & 0xF0);
+                        b_lo[k] = vreinterpretq_s8_u8(b << 4);
+                    }
+                    // (v12, v3)  = (b_lo[0], b_hi[0])
+                    // (v31, v22) = (b_lo[1], b_hi[1])
+                    // (v6,  v27) = (b_lo[2], b_hi[2])
+                    // (v28, v30) = (b_lo[3], b_hi[3])
+
+                    float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                    // v17 = b_d
+
+                    // unroll in ASM
+                    for (int z = 0; z < UNROLL_FACTOR; z++) {
+                        int32x4_t sumi[4];
+                        for (int m = 0; m < 4; m++) {
+                            sumi[m] = vdupq_n_s32(0);
+                        }
+                        // (v10, v29, v9,  v20) = sumi[0, 1, 2, 3] (z = 0)
+                        // (v9,  v29, v20, v2 ) = sumi[0, 1, 2, 3] (z = 1)
+                        // (v20, v10, v26, v2 ) = sumi[0, 1, 2, 3] (z = 2)
+                        // (v26, v10, v2,  v19) = sumi[0, 1, 2, 3] (z = 3)
+
+                        for (int k = 0; k < 4; k++) {
+                            int8x16_t a0 = vld1q_s8(a_ptr[z][l].qs + 16 * k + 0);
+                            sumi[0] = vdotq_laneq_s32(sumi[0], b_lo[k], a0, 0);
+                            sumi[1] = vdotq_laneq_s32(sumi[1], b_lo[k], a0, 1);
+                            sumi[2] = vdotq_laneq_s32(sumi[2], b_lo[k], a0, 2);
+                            sumi[3] = vdotq_laneq_s32(sumi[3], b_lo[k], a0, 3);
+                        }
+                        for (int k = 0; k < 4; k++) {
+                            int8x16_t a1 = vld1q_s8(a_ptr[z][l].qs + 16 * k + 64);
+                            sumi[0] = vdotq_laneq_s32(sumi[0], b_hi[k], a1, 0);
+                            sumi[1] = vdotq_laneq_s32(sumi[1], b_hi[k], a1, 1);
+                            sumi[2] = vdotq_laneq_s32(sumi[2], b_hi[k], a1, 2);
+                            sumi[3] = vdotq_laneq_s32(sumi[3], b_hi[k], a1, 3);
+                        }
+
+                        float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[z][l].d));
+                        // (v2, v26, v29, v20) = a_d (z = 0, 1, 2, 3)
+
+                        sumf[z][0] = vmlaq_f32(sumf[z][0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_n_f32_s32(sumi[0], 4));
+                        sumf[z][1] = vmlaq_f32(sumf[z][1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_n_f32_s32(sumi[1], 4));
+                        sumf[z][2] = vmlaq_f32(sumf[z][2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_n_f32_s32(sumi[2], 4));
+                        sumf[z][3] = vmlaq_f32(sumf[z][3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_n_f32_s32(sumi[3], 4));
+                    }
+
+                }
+
+                for (int z = 0; z < UNROLL_FACTOR; z++) {
+                    for (int m = 0; m < 4; m++) {
+                        vst1q_f32(s + ((y + z) * 4 + m) * bs + x * 4, sumf[z][m]);
+                    }
+                }
+            }
+        }
+#undef UNROLL_FACTOR
+
+        for (; y < nr / 4; y++) {
+            // x10 : loop control
+
             const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
 
             for (int x = 0; x < nc / ncols_interleaved; x++) {
                 const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
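The payoff of the restructuring above is memory traffic: each block of b (four 16-byte qs loads plus the d scales) is now fetched once per four rows of a instead of once per row, at the cost of four times the accumulator registers. Reduced to its skeleton (bodies elided; process_rows is an illustrative name, not part of the patch):

    #define UNROLL_FACTOR 4
    static void process_rows(int nrows) { // nrows plays the role of nr / 4
        int y = 0;
        // main loop: rows y .. y + 3 share every load from the b block
        for (; y + UNROLL_FACTOR <= nrows; y += UNROLL_FACTOR) {
            /* unrolled four-row body */
        }
        // tail: the original one-row body covers the nrows % UNROLL_FACTOR leftovers
        for (; y < nrows; y++) {
            /* single-row body */
        }
    }
    #undef UNROLL_FACTOR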
@@ -1245,32 +1360,68 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
                 for (int m = 0; m < 4; m++) {
                     sumf[m] = vdupq_n_f32(0);
                 }
-
+                // (v15, v19, v18, v14) = sumf[0, 1, 2, 3]
+
                 for (int l = 0; l < nb; l++) {
+                    // x21 : loop control
+
+                    // x25 = a_ptr[l].qs
+                    // x24 = b_ptr[l].qs
+
+                    int8x16_t a_0[4], a_1[4];
+                    a_0[0] = vld1q_s8(a_ptr[l].qs + 0);
+                    a_0[1] = vld1q_s8(a_ptr[l].qs + 16);
+                    a_0[2] = vld1q_s8(a_ptr[l].qs + 32);
+                    a_0[3] = vld1q_s8(a_ptr[l].qs + 48);
+                    a_1[0] = vld1q_s8(a_ptr[l].qs + 64);
+                    a_1[1] = vld1q_s8(a_ptr[l].qs + 80);
+                    a_1[2] = vld1q_s8(a_ptr[l].qs + 96);
+                    a_1[3] = vld1q_s8(a_ptr[l].qs + 112);
+                    // (v5,  v26) = (a_0[0], a_1[0])
+                    // (v2,  v25) = (a_0[1], a_1[1])
+                    // (v31, v24) = (a_0[2], a_1[2])
+                    // (v27, v16) = (a_0[3], a_1[3])
+
+                    uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+                    uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+                    uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+                    uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+                    // (v7, v3, v13, v28) = (b_0, b_1, b_2, b_3)
+
+                    int8x16_t b_lo[4], b_hi[4];
+                    b_hi[0] = vreinterpretq_s8_u8(b_0 & 0xF0);
+                    b_lo[0] = vreinterpretq_s8_u8(b_0 << 4);
+                    b_hi[1] = vreinterpretq_s8_u8(b_1 & 0xF0);
+                    b_lo[1] = vreinterpretq_s8_u8(b_1 << 4);
+                    b_hi[2] = vreinterpretq_s8_u8(b_2 & 0xF0);
+                    b_lo[2] = vreinterpretq_s8_u8(b_2 << 4);
+                    b_hi[3] = vreinterpretq_s8_u8(b_3 & 0xF0);
+                    b_lo[3] = vreinterpretq_s8_u8(b_3 << 4);
+                    // (v20, v7)  = (b_lo[0], b_hi[0])
+                    // (v17, v3)  = (b_lo[1], b_hi[1])
+                    // (v22, v13) = (b_lo[2], b_hi[2])
+                    // (v9,  v28) = (b_lo[3], b_hi[3])
+
                     float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                    // v12 = a_d
                     float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                    // v21 = b_d
 
                     int32x4_t sumi_0 = vdupq_n_s32(0);
                     int32x4_t sumi_1 = vdupq_n_s32(0);
                     int32x4_t sumi_2 = vdupq_n_s32(0);
                     int32x4_t sumi_3 = vdupq_n_s32(0);
+                    // (v4, v1, v0, v30) = (sumi_0, sumi_1, sumi_2, sumi_3)
 
                     for (int k = 0; k < 4; k++) {
-                        int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
-                        int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
-
-                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
-                        int8x16_t b_hi = vreinterpretq_s8_u8(b & 0xF0);
-                        int8x16_t b_lo = vreinterpretq_s8_u8(b << 4);
-
-                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
-                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
-                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
-                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
-                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
-                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
-                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
-                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo[k], a_0[k], 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo[k], a_0[k], 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo[k], a_0[k], 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo[k], a_0[k], 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi[k], a_1[k], 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi[k], a_1[k], 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi[k], a_1[k], 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi[k], a_1[k], 3);
                     }
 
                     sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_n_f32_s32(sumi_0, 4));
@@ -1279,6 +1430,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
                     sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_n_f32_s32(sumi_3, 4));
                 }
 
+                // NOTE: the asm version has additional code to handle the case where `nr` is not a multiple of 4
                 for (int m = 0; m < 4; m++) {
                     vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
                 }
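Both the unrolled and the tail loop lean on vdotq_laneq_s32, which performs four 4-way int8 dot products per instruction: lane j of the accumulator gains dot(b[4j..4j+3], a[4*lane..4*lane+3]), so one instruction pairs all four interleaved columns of b with one 4-byte group of a. A self-contained check, assuming a dot-product-capable core (build with something like -march=armv8.2-a+dotprod; the values are illustrative):

    #include <arm_neon.h>
    #include <stdio.h>

    int main(void) {
        int8_t w[16], act[16];
        for (int i = 0; i < 16; i++) { w[i] = 2; act[i] = (int8_t)(i + 1); }

        int8x16_t b_lo = vld1q_s8(w);   // stands in for one unpacked weight vector
        int8x16_t a_0  = vld1q_s8(act); // stands in for one activation vector

        int32x4_t sumi = vdupq_n_s32(0);
        // lane 0 selects act[0..3] = (1, 2, 3, 4); every output lane j adds
        // dot(w[4j..4j+3], (1, 2, 3, 4)) = 2 * (1 + 2 + 3 + 4) = 20
        sumi = vdotq_laneq_s32(sumi, b_lo, a_0, 0);

        int32_t out[4];
        vst1q_s32(out, sumi);
        printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); // prints: 20 20 20 20
        return 0;
    }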
@@ -3230,7 +3382,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
                 for (int m = 0; m < 4; m++) {
                     sumf[m] = vdupq_n_f32(0);
                 }
-
+
                 for (int l = 0; l < nb; l++) {
                     float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
                     float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
@@ -3244,7 +3396,7 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void
                         int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
                         int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
 
-                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
                         int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
                         int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
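For contrast with the q4_0 shift trick: the iq4_nl kernels dequantize through a 16-entry codebook, with vqtbl1q_s8 mapping each 4-bit index to a non-linear int8 value, so b >> 4 and b & 0xF select table entries rather than sign-extended magnitudes. A minimal sketch (the table is kvalues_iq4nl as defined in ggml-common.h; vshrq_n_u8/vandq_u8 stand in for the vector-extension operators used in the kernel):

    #include <arm_neon.h>
    #include <stdio.h>

    int main(void) {
        static const int8_t kvalues_iq4nl[16] = {
            -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
        };
        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);

        uint8_t packed[16];
        for (int i = 0; i < 16; i++) {
            packed[i] = (uint8_t)((i << 4) | (15 - i)); // high nibble i, low nibble 15 - i
        }
        const uint8x16_t b = vld1q_u8(packed);

        // same unpack as the kernel: a table lookup per nibble
        int8x16_t b_hi = vqtbl1q_s8(kvalues, vshrq_n_u8(b, 4));
        int8x16_t b_lo = vqtbl1q_s8(kvalues, vandq_u8(b, vdupq_n_u8(0xF)));

        int8_t hi[16], lo[16];
        vst1q_s8(hi, b_hi);
        vst1q_s8(lo, b_lo);
        for (int i = 0; i < 16; i++) {
            printf("idx %2d -> %4d    idx %2d -> %4d\n", i, hi[i], 15 - i, lo[i]);
        }
        return 0;
    }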