Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions

This commit is contained in:
Dibakar Gope 2024-04-25 03:57:15 +00:00 committed by Dibakar Gope
parent 81215ff43a
commit 6c8d8266b1
2 changed files with 39 additions and 27 deletions

View file

@ -3309,41 +3309,37 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
size_t quantize_q4_0_aarch64(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t quantize_q4_0_aarch64(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
if (!quant_weights) { if (!quant_weights) {
//quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row);
//return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
int nrows_interleaved, blocklen_per_row; int nrows_interleaved, blocklen_per_row;
typedef block_q4_0x8 block_q4_0xn;
typedef block_q4_0xn (*make_block_q4_0xn_t)(const block_q4_0 *, unsigned int, unsigned int);
make_block_q4_0xn_t make_block_q4_0xn = make_block_q4_0x8;
if (ggml_cpu_has_sve() && (svcntw() == 8)) { #if defined(__ARM_FEATURE_SVE)
if (svcntw() == 8) {
nrows_interleaved = 8; nrows_interleaved = 8;
blocklen_per_row = 8; blocklen_per_row = 8;
typedef block_q4_0x8 block_q4_0xn;
make_block_q4_0xn = make_block_q4_0x8;
} }
else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
nrows_interleaved = 4; nrows_interleaved = 4;
blocklen_per_row = 8; blocklen_per_row = 8;
typedef block_q4_0x4 block_q4_0xn;
make_block_q4_0xn = make_block_q4_0x4;
}
else if (ggml_cpu_has_neon()) {
nrows_interleaved = 4;
blocklen_per_row = 4;
typedef block_q4_0x4 block_q4_0xn;
make_block_q4_0xn = make_block_q4_0x4;
}
else {
assert(false);
} }
#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
nrows_interleaved = 4;
blocklen_per_row = 8;
#elif defined(__ARM_NEON)
nrows_interleaved = 4;
blocklen_per_row = 4;
#endif
assert(n_per_row % QK4_0 == 0); assert(n_per_row % QK4_0 == 0);
const int nb = n_per_row / QK4_0; const int nb = n_per_row / QK4_0;
block_q4_0xn * out_ptr_B = (block_q4_0xn *) malloc(sizeof(block_q4_0xn) * nb); void * out_ptr_B, * out_ptr_B_start;
block_q4_0xn * out_ptr_B_start = out_ptr_B; if (nrows_interleaved == 8) {
out_ptr_B = (block_q4_0x8 *) malloc(sizeof(block_q4_0x8) * nb);
out_ptr_B_start = out_ptr_B;
}
else if (nrows_interleaved == 4) {
out_ptr_B = (block_q4_0x4 *) malloc(sizeof(block_q4_0x4) * nb);
out_ptr_B_start = out_ptr_B;
}
for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
const block_q4_0 * in_ptrs[nrows_interleaved]; const block_q4_0 * in_ptrs[nrows_interleaved];
@ -3354,18 +3350,26 @@ size_t quantize_q4_0_aarch64(const float * restrict src, void * restrict dst, in
} }
for (int64_t x = 0; x < nb; x++) { for (int64_t x = 0; x < nb; x++) {
*out_ptr_B = make_block_q4_0xn(in_ptrs, blocklen_per_row, 0x88); if (nrows_interleaved == 8) {
out_ptr_B++; *(block_q4_0x8 *) out_ptr_B = make_block_q4_0x8(in_ptrs, blocklen_per_row, 0x88);
out_ptr_B = (block_q4_0x8 *) out_ptr_B + 1;
}
else if (nrows_interleaved == 4) {
*(block_q4_0x4 *) out_ptr_B = make_block_q4_0x4(in_ptrs, blocklen_per_row, 0x88);
out_ptr_B = (block_q4_0x4 *) out_ptr_B + 1;
}
for (int i = 0; i < nrows_interleaved; i++) { for (int i = 0; i < nrows_interleaved; i++) {
in_ptrs[i]++; in_ptrs[i]++;
} }
} }
out_ptr_B = out_ptr_B_start; out_ptr_B = out_ptr_B_start;
memcpy ((block_q4_0 *) dst + b / QK4_0, out_ptr_B_start, sizeof(block_q4_0xn) * nb); if (nrows_interleaved == 8) memcpy ((block_q4_0 *) dst + b / QK4_0, out_ptr_B_start, sizeof(block_q4_0x8) * nb);
else if (nrows_interleaved == 4) memcpy ((block_q4_0 *) dst + b / QK4_0, out_ptr_B_start, sizeof(block_q4_0x4) * nb);
} }
if (out_ptr_B_start) free(out_ptr_B_start); if (out_ptr_B_start) free(out_ptr_B_start);
return (nrow * n_per_row / QK4_0 * sizeof(block_q4_0));
return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
} }
size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
char * qrow = (char *)dst; char * qrow = (char *)dst;
@ -15179,6 +15183,10 @@ void ggml_gemv_q4_0_q8_0_blocked8_sve(const int n, int output_channels, int inpu
void ggml_gemv_q4_0_q8_0_aarch64_sve256(size_t depth, size_t output_channels, size_t height, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { void ggml_gemv_q4_0_q8_0_aarch64_sve256(size_t depth, size_t output_channels, size_t height, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE) #if defined(__ARM_FEATURE_SVE)
if (svcntw() != 8) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemv_q4_0_q8_0_aarch64_neon(depth, output_channels, height, s, vx, vy, ith, nth);
return;
}
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8); int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8); int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
size_t width = xend - x0; size_t width = xend - x0;
@ -15657,6 +15665,10 @@ void ggml_gemm_q4_0_q8_0(const int n, int rows, int output_channels, int input_w
void ggml_gemm_q4_0_q8_0_aarch64_sve256(size_t depth, size_t output_channels, size_t height, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) { void ggml_gemm_q4_0_q8_0_aarch64_sve256(size_t depth, size_t output_channels, size_t height, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (svcntw() != 8) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemm_q4_0_q8_0_aarch64_neon(depth, output_channels, height, s, vx, vy, ith, nth);
return;
}
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8); int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8); int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
size_t width = xend - x0; size_t width = xend - x0;