Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp PR #5780 suggestions

Dibakar Gope 2024-04-25 03:57:15 +00:00 committed by Dibakar Gope
parent 81215ff43a
commit 6c8d8266b1
2 changed files with 39 additions and 27 deletions
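
Context for the diffs below: quantize_q4_0_aarch64 repacks standard Q4_0 blocks into row-interleaved groups (block_q4_0x4 or block_q4_0x8) so the AArch64 GEMV/GEMM kernels can load the weights of 4 or 8 output rows with a single contiguous read. A minimal sketch of the layout this implies; block_q4_0 matches ggml's definition, while the x4/x8 structs are an assumption inferred from how the diff sizes and copies them:

    #include <stdint.h>

    #define QK4_0 32                     // elements per Q4_0 block

    typedef uint16_t ggml_half;          // fp16 bit pattern, as in ggml

    typedef struct {
        ggml_half d;                     // per-block scale
        uint8_t   qs[QK4_0 / 2];         // 32 4-bit quants, two per byte
    } block_q4_0;

    // Assumed interleaved variants: one block packs the same block index
    // from 4 (or 8) consecutive weight rows -- all scales first, then the
    // quant bytes interleaved in runs of blocklen_per_row bytes.
    typedef struct {
        ggml_half d[4];                  // scales of the 4 interleaved rows
        uint8_t   qs[4 * QK4_0 / 2];     // interleaved quants of the 4 rows
    } block_q4_0x4;

    typedef struct {
        ggml_half d[8];
        uint8_t   qs[8 * QK4_0 / 2];
    } block_q4_0x8;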


@@ -3309,41 +3309,37 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
 size_t quantize_q4_0_aarch64(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        //quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row);
-        //return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
         int nrows_interleaved, blocklen_per_row;
-        typedef block_q4_0x8 block_q4_0xn;
-        typedef block_q4_0xn (*make_block_q4_0xn_t)(const block_q4_0 *, unsigned int, unsigned int);
-        make_block_q4_0xn_t make_block_q4_0xn = make_block_q4_0x8;
-        if (ggml_cpu_has_sve() && (svcntw() == 8)) {
+#if defined(__ARM_FEATURE_SVE)
+        if (svcntw() == 8) {
             nrows_interleaved = 8;
             blocklen_per_row = 8;
-            typedef block_q4_0x8 block_q4_0xn;
-            make_block_q4_0xn = make_block_q4_0x8;
         }
-        else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
-            nrows_interleaved = 4;
-            blocklen_per_row = 8;
-            typedef block_q4_0x4 block_q4_0xn;
-            make_block_q4_0xn = make_block_q4_0x4;
-        }
-        else if (ggml_cpu_has_neon()) {
-            nrows_interleaved = 4;
-            blocklen_per_row = 4;
-            typedef block_q4_0x4 block_q4_0xn;
-            make_block_q4_0xn = make_block_q4_0x4;
-        }
-        else {
-            assert(false);
-        }
+#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+        nrows_interleaved = 4;
+        blocklen_per_row = 8;
+#elif defined(__ARM_NEON)
+        nrows_interleaved = 4;
+        blocklen_per_row = 4;
+#endif
         assert(n_per_row % QK4_0 == 0);
         const int nb = n_per_row / QK4_0;
-        block_q4_0xn * out_ptr_B = (block_q4_0xn *) malloc(sizeof(block_q4_0xn) * nb);
-        block_q4_0xn * out_ptr_B_start = out_ptr_B;
+        void * out_ptr_B, * out_ptr_B_start;
+        if (nrows_interleaved == 8) {
+            out_ptr_B = (block_q4_0x8 *) malloc(sizeof(block_q4_0x8) * nb);
+            out_ptr_B_start = out_ptr_B;
+        }
+        else if (nrows_interleaved == 4) {
+            out_ptr_B = (block_q4_0x4 *) malloc(sizeof(block_q4_0x4) * nb);
+            out_ptr_B_start = out_ptr_B;
+        }
         for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
             const block_q4_0 * in_ptrs[nrows_interleaved];
@@ -3354,18 +3350,26 @@ size_t quantize_q4_0_aarch64(const float * restrict src, void * restrict dst, in
             }
             for (int64_t x = 0; x < nb; x++) {
-                *out_ptr_B = make_block_q4_0xn(in_ptrs, blocklen_per_row, 0x88);
-                out_ptr_B++;
+                if (nrows_interleaved == 8) {
+                    *(block_q4_0x8 *) out_ptr_B = make_block_q4_0x8(in_ptrs, blocklen_per_row, 0x88);
+                    out_ptr_B = (block_q4_0x8 *) out_ptr_B + 1;
+                }
+                else if (nrows_interleaved == 4) {
+                    *(block_q4_0x4 *) out_ptr_B = make_block_q4_0x4(in_ptrs, blocklen_per_row, 0x88);
+                    out_ptr_B = (block_q4_0x4 *) out_ptr_B + 1;
+                }
                 for (int i = 0; i < nrows_interleaved; i++) {
                     in_ptrs[i]++;
                 }
             }
             out_ptr_B = out_ptr_B_start;
-            memcpy ((block_q4_0 *) dst + b / QK4_0, out_ptr_B_start, sizeof(block_q4_0xn) * nb);
+            if (nrows_interleaved == 8) memcpy ((block_q4_0 *) dst + b / QK4_0, out_ptr_B_start, sizeof(block_q4_0x8) * nb);
+            else if (nrows_interleaved == 4) memcpy ((block_q4_0 *) dst + b / QK4_0, out_ptr_B_start, sizeof(block_q4_0x4) * nb);
         }
         if (out_ptr_B_start) free(out_ptr_B_start);
-        return (nrow * n_per_row / QK4_0 * sizeof(block_q4_0));
+        return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
     char * qrow = (char *)dst;
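
Why the refactor above: in C, a typedef declared inside an if/else branch is scoped to that block, so the branch-local "typedef block_q4_0x4 block_q4_0xn;" lines never changed the block_q4_0xn type used to declare out_ptr_B, and a single make_block_q4_0xn function pointer cannot hold functions returning different struct types (its declared parameter type did not match the in_ptrs array argument either). The new version picks nrows_interleaved/blocklen_per_row at compile time and dispatches through a void pointer with explicit casts. A hedged reconstruction of what make_block_q4_0x4 plausibly does, inferred from its call sites (in_ptrs, blocklen_per_row, 0x88) and the sketch structs earlier -- an assumption, not the PR's exact code:

    // Gather the block at the current position of each of 4 row pointers,
    // copying scales verbatim and interleaving quant bytes in runs of
    // block_len. XOR with 0x88 flips the top bit of both nibbles in a
    // byte, i.e. toggles each 4-bit quant between offset-binary and
    // two's-complement encodings (hypothetical reading of the mask).
    static block_q4_0x4 make_block_q4_0x4(const block_q4_0 * in[4],
                                          unsigned int block_len,
                                          unsigned int xor_mask) {
        block_q4_0x4 out;
        for (int i = 0; i < 4; i++) {
            out.d[i] = in[i]->d;
        }
        for (int i = 0; i < 4 * QK4_0 / 2; i++) {
            int src_id     = (i % (4 * block_len)) / block_len;   // source row
            int src_offset = (i / (4 * block_len)) * block_len    // run index
                           + (i % block_len);                     // byte in run
            out.qs[i] = in[src_id]->qs[src_offset] ^ xor_mask;
        }
        return out;
    }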
@@ -15179,6 +15183,10 @@ void ggml_gemv_q4_0_q8_0_blocked8_sve(const int n, int output_channels, int inpu
 void ggml_gemv_q4_0_q8_0_aarch64_sve256(size_t depth, size_t output_channels, size_t height, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
 #if defined(__ARM_FEATURE_SVE)
+    if (svcntw() != 8) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemv_q4_0_q8_0_aarch64_neon(depth, output_channels, height, s, vx, vy, ith, nth);
+        return;
+    }
     int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
     int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
     size_t width = xend - x0;
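
svcntw() returns the number of 32-bit lanes in an SVE vector, so svcntw() == 8 holds exactly on 256-bit implementations; on any other vector length the _sve256 kernel now delegates to the NEON+i8mm kernel instead of silently assuming 256-bit registers. The roundup(..., (int64_t)8) calls carve output_channels into per-thread ranges aligned to the 8-row interleave factor, so no interleaved group straddles two threads. A sketch of the helper, assuming it rounds up to a multiple, consistent with its use here:

    // Assumed definition: smallest multiple of n that is >= x (n > 0).
    static inline int64_t roundup(int64_t x, int64_t n) {
        return ((x + n - 1) / n) * n;
    }

    // Thread ith of nth then processes output columns
    // [roundup(ith * OC / nth, 8), roundup((ith + 1) * OC / nth, 8)),
    // always a whole number of 8-row interleaved groups.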
@@ -15657,6 +15665,10 @@ void ggml_gemm_q4_0_q8_0(const int n, int rows, int output_channels, int input_w
 void ggml_gemm_q4_0_q8_0_aarch64_sve256(size_t depth, size_t output_channels, size_t height, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
 #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (svcntw() != 8) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemm_q4_0_q8_0_aarch64_neon(depth, output_channels, height, s, vx, vy, ith, nth);
+        return;
+    }
     int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
     int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
     size_t width = xend - x0;
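
The GEMM guard adds __ARM_FEATURE_MATMUL_INT8 to the SVE check, presumably because the multi-row kernel is built on the int8 matrix-multiply (i8mm) instructions, while the GEMV path needs only dot products. Either way, both kernels compute the same arithmetic as ggml's scalar q4_0 x q8_0 reference, just for 4 or 8 interleaved weight rows per iteration. For orientation, a scalar single-block reference using standard ggml q4_0/q8_0 semantics; block_q8_0 and the fp16 conversion are sketched here, not taken from the diff:

    #define QK8_0 32

    typedef struct {
        ggml_half d;                 // per-block scale
        int8_t    qs[QK8_0];         // 32 signed 8-bit quants
    } block_q8_0;

    // Stand-in for ggml's GGML_FP16_TO_FP32 conversion.
    extern float fp16_to_fp32(ggml_half h);

    // One q4_0 x q8_0 block: low nibbles pair with y->qs[0..15], high
    // nibbles with y->qs[16..31]; 4-bit quants are stored offset by 8.
    static float vec_dot_q4_0_q8_0_one_block(const block_q4_0 * x,
                                             const block_q8_0 * y) {
        int sumi = 0;
        for (int j = 0; j < QK4_0 / 2; j++) {
            const int v0 = (x->qs[j] & 0x0F) - 8;
            const int v1 = (x->qs[j] >>   4) - 8;
            sumi += v0 * y->qs[j] + v1 * y->qs[j + QK4_0 / 2];
        }
        return fp16_to_fp32(x->d) * fp16_to_fp32(y->d) * (float) sumi;
    }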


@@ -12377,7 +12377,7 @@ UseGgmlGemm2:;
         for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) {
             gemv(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
         }
-        }
+    }
     else if ((ggml_n_dims(src0) == 2) && (ne11 >= 4) && (type == GGML_TYPE_Q4_0_AARCH64)) {
         // use batch-sized 4 GEMM kernel
         for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) {
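
This last hunk is the dispatch side in ggml's mul_mat compute path: for 2-D GGML_TYPE_Q4_0_AARCH64 weights, the ne11 rows of src1 are handed to the 8-row GEMM kernel in batches of 8, then to the 4-row kernel in batches of 4, and any remainder goes through gemv() one row at a time; the loop shown starts exactly at the first leftover row. A worked example of that bound, assuming rows_processed is the count already consumed by the 8-row kernel, i.e. (ne11 / 8) * 8, consistent with the loop's start expression:

    // Hypothetical illustration of the row batching implied above.
    static void partition_rows_example(void) {
        int ne11 = 14;                               // activation rows in src1
        int rows_processed = (ne11 / 8) * 8;         // 8: one 8-row GEMM batch
        int rows4 = (ne11 - rows_processed) / 4 * 4; // 4: one 4-row GEMM batch
        int first_gemv_row = rows_processed + rows4; // 12: matches the loop start
        (void) first_gemv_row;
        // rows [0, 8)   -> 8-row GEMM kernel
        // rows [8, 12)  -> 4-row GEMM kernel
        // rows [12, 14) -> gemv(), one row per call
    }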