ggml : add AArch64 optimized GEMV and GEMM Q4 kernels (#5780)

* Arm AArch64: optimized GEMV and GEMM kernels for q4_0_q8_0, and q8_0_q8_0 quantization

* Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions

* Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions

* Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions

* Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions

* Arm AArch64: add copyright claim only to ggml-aarch64.cpp and ggml-aarch64.h files

* Arm AArch64: minor code refactoring for rebase

* Arm AArch64: minor code refactoring for resolving a build issue with cmake

* Arm AArch64: minor code refactoring to split the Q4_0_AARC64 type into three separate types: Q4_0_4_4, Q4_0_4_8, and Q4_0_8_8

* Arm AArch64: minor code change for resolving a build issue with server-windows

* retrigger checks

* Arm AArch64: minor code changes for rebase

* Arm AArch64: minor changes to skip the pr#7433 vec_dot code for arm cpus with SVE VL not equal to 256 bits

* Arm AArch64: remove stale LLAMA_QKK_64 from CMakeLists.txt and delete build.zig

* Arm AArch64: add reference scalar gemm and gemv, and avoid dynamic memory allocations during quantization for Q4_0_4_4, Q4_0_4_8, and Q4_0_8_8

* Arm AArch64: add multithreaded quantization support for the new types: Q4_0_4_4, Q4_0_4_8, and Q4_0_8_8

* Arm AArch64: minor code refactoring

* Arm AArch64: simplify logic for calling gemm and gemv functions in ggml_compute_forward_mul_mat

* Arm AArch64: minimize changes in ggml_compute_forward_mul_mat

* Arm AArch64: minor code refactoring, and add reference scalar code to quantize routines for new quant types

* Arm AArch64: minor code refactoring

* Arm AArch64: minor code refactoring

* Arm AArch64: minor code refactoring

* rebase on the latest master commit 3fd62a6 and adapt to the new directory structure

* Arm AArch64: remove a redundant comment

* Arm AArch64: add pragma in ggml-aarch64.c to turn -Woverlength-strings warning off

* Arm AArch64: use __aarch64__ check to guard 64-bit neon kernels

* Arm AArch64: update docs/build.md README to include compile time flags for buiilding the Q4_0_4_4 quant type
This commit is contained in:
Dibakar Gope 2024-07-10 07:14:51 -05:00 committed by GitHub
parent 83321c6958
commit 0f1a39f343
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 2534 additions and 53 deletions

View file

@ -3814,43 +3814,47 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}
#endif
#if defined(__ARM_FEATURE_SVE)
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
if (svcntb() == QK8_0) {
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);
assert(nb % 2 == 0); // TODO: handle odd nb
assert(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q4_0 * restrict x0 = &x[i + 0];
const block_q4_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
for (int i = 0; i < nb; i += 2) {
const block_q4_0 * restrict x0 = &x[i + 0];
const block_q4_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
// load x
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
// load x
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
// 4-bit -> 8-bit
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
// 4-bit -> 8-bit
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
// sub 8
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
// sub 8
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
// load y
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
// load y
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
// dot product
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
// dot product
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
}
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
return;
}
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
#elif defined(__ARM_NEON)
#endif
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -5422,31 +5426,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}
#endif
#if defined(__ARM_FEATURE_SVE)
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);
if (svcntb() == QK8_0) {
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);
assert(nb % 2 == 0); // TODO: handle odd nb
assert(nb % 2 == 0); // TODO: handle odd nb
for (int i = 0; i < nb; i += 2) {
const block_q8_0 * restrict x0 = &x[i + 0];
const block_q8_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
for (int i = 0; i < nb; i += 2) {
const block_q8_0 * restrict x0 = &x[i + 0];
const block_q8_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];
// load x
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
// load x
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
// load y
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
// load y
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
}
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
return;
}
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
#elif defined(__ARM_NEON)
#endif
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -14760,6 +14768,16 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
} \
}
#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
for (size_t j = 0; j < (nr); ++j) { \
if (!validate_fp16(q[i].d[j], i)) { \
return false; \
} \
} \
}
bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
if (type < 0 || type >= GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid type %d\n", __func__, type);
@ -14977,6 +14995,16 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
{
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
} break;
case GGML_TYPE_Q4_0_8_8:
{
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
} break;
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32: