Arm AArch64: add optimized GEMV and GEMM asm kernels for q4_0_q8_0 quantization and refactor code to address llama.cpp pr#5780 suggestions
This commit is contained in:
parent
81215ff43a
commit
6c8d8266b1
2 changed files with 39 additions and 27 deletions
|
@ -3309,41 +3309,37 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
|
|||
|
||||
size_t quantize_q4_0_aarch64(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
if (!quant_weights) {
|
||||
//quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row);
|
||||
//return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
|
||||
|
||||
int nrows_interleaved, blocklen_per_row;
|
||||
typedef block_q4_0x8 block_q4_0xn;
|
||||
typedef block_q4_0xn (*make_block_q4_0xn_t)(const block_q4_0 *, unsigned int, unsigned int);
|
||||
make_block_q4_0xn_t make_block_q4_0xn = make_block_q4_0x8;
|
||||
|
||||
if (ggml_cpu_has_sve() && (svcntw() == 8)) {
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
if (svcntw() == 8) {
|
||||
nrows_interleaved = 8;
|
||||
blocklen_per_row = 8;
|
||||
typedef block_q4_0x8 block_q4_0xn;
|
||||
make_block_q4_0xn = make_block_q4_0x8;
|
||||
}
|
||||
else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
nrows_interleaved = 4;
|
||||
blocklen_per_row = 8;
|
||||
typedef block_q4_0x4 block_q4_0xn;
|
||||
make_block_q4_0xn = make_block_q4_0x4;
|
||||
}
|
||||
else if (ggml_cpu_has_neon()) {
|
||||
nrows_interleaved = 4;
|
||||
blocklen_per_row = 4;
|
||||
typedef block_q4_0x4 block_q4_0xn;
|
||||
make_block_q4_0xn = make_block_q4_0x4;
|
||||
}
|
||||
else {
|
||||
assert(false);
|
||||
}
|
||||
#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
nrows_interleaved = 4;
|
||||
blocklen_per_row = 8;
|
||||
#elif defined(__ARM_NEON)
|
||||
nrows_interleaved = 4;
|
||||
blocklen_per_row = 4;
|
||||
#endif
|
||||
|
||||
assert(n_per_row % QK4_0 == 0);
|
||||
const int nb = n_per_row / QK4_0;
|
||||
|
||||
block_q4_0xn * out_ptr_B = (block_q4_0xn *) malloc(sizeof(block_q4_0xn) * nb);
|
||||
block_q4_0xn * out_ptr_B_start = out_ptr_B;
|
||||
void * out_ptr_B, * out_ptr_B_start;
|
||||
if (nrows_interleaved == 8) {
|
||||
out_ptr_B = (block_q4_0x8 *) malloc(sizeof(block_q4_0x8) * nb);
|
||||
out_ptr_B_start = out_ptr_B;
|
||||
}
|
||||
else if (nrows_interleaved == 4) {
|
||||
out_ptr_B = (block_q4_0x4 *) malloc(sizeof(block_q4_0x4) * nb);
|
||||
out_ptr_B_start = out_ptr_B;
|
||||
}
|
||||
|
||||
for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
|
||||
const block_q4_0 * in_ptrs[nrows_interleaved];
|
||||
|
@ -3354,18 +3350,26 @@ size_t quantize_q4_0_aarch64(const float * restrict src, void * restrict dst, in
|
|||
}
|
||||
|
||||
for (int64_t x = 0; x < nb; x++) {
|
||||
*out_ptr_B = make_block_q4_0xn(in_ptrs, blocklen_per_row, 0x88);
|
||||
out_ptr_B++;
|
||||
if (nrows_interleaved == 8) {
|
||||
*(block_q4_0x8 *) out_ptr_B = make_block_q4_0x8(in_ptrs, blocklen_per_row, 0x88);
|
||||
out_ptr_B = (block_q4_0x8 *) out_ptr_B + 1;
|
||||
}
|
||||
else if (nrows_interleaved == 4) {
|
||||
*(block_q4_0x4 *) out_ptr_B = make_block_q4_0x4(in_ptrs, blocklen_per_row, 0x88);
|
||||
out_ptr_B = (block_q4_0x4 *) out_ptr_B + 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
in_ptrs[i]++;
|
||||
}
|
||||
}
|
||||
out_ptr_B = out_ptr_B_start;
|
||||
memcpy ((block_q4_0 *) dst + b / QK4_0, out_ptr_B_start, sizeof(block_q4_0xn) * nb);
|
||||
if (nrows_interleaved == 8) memcpy ((block_q4_0 *) dst + b / QK4_0, out_ptr_B_start, sizeof(block_q4_0x8) * nb);
|
||||
else if (nrows_interleaved == 4) memcpy ((block_q4_0 *) dst + b / QK4_0, out_ptr_B_start, sizeof(block_q4_0x4) * nb);
|
||||
}
|
||||
if (out_ptr_B_start) free(out_ptr_B_start);
|
||||
return (nrow * n_per_row / QK4_0 * sizeof(block_q4_0));
|
||||
|
||||
return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
|
||||
}
|
||||
size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
|
||||
char * qrow = (char *)dst;
|
||||
|
@ -15179,6 +15183,10 @@ void ggml_gemv_q4_0_q8_0_blocked8_sve(const int n, int output_channels, int inpu
|
|||
|
||||
void ggml_gemv_q4_0_q8_0_aarch64_sve256(size_t depth, size_t output_channels, size_t height, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
if (svcntw() != 8) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemv_q4_0_q8_0_aarch64_neon(depth, output_channels, height, s, vx, vy, ith, nth);
|
||||
return;
|
||||
}
|
||||
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
|
||||
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
|
||||
size_t width = xend - x0;
|
||||
|
@ -15657,6 +15665,10 @@ void ggml_gemm_q4_0_q8_0(const int n, int rows, int output_channels, int input_w
|
|||
|
||||
void ggml_gemm_q4_0_q8_0_aarch64_sve256(size_t depth, size_t output_channels, size_t height, float * restrict s, const void * restrict vx, const void * restrict vy, int ith, int nth) {
|
||||
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (svcntw() != 8) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) ggml_gemm_q4_0_q8_0_aarch64_neon(depth, output_channels, height, s, vx, vy, ith, nth);
|
||||
return;
|
||||
}
|
||||
int64_t x0 = roundup((ith * output_channels) / nth, (int64_t)8);
|
||||
int64_t xend = roundup(((ith + 1) * output_channels) / nth, (int64_t)8);
|
||||
size_t width = xend - x0;
|
||||
|
|
|
@ -12377,7 +12377,7 @@ UseGgmlGemm2:;
|
|||
for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) {
|
||||
gemv(ne00, ne01, 1, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), ith, nth);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 4) && (type == GGML_TYPE_Q4_0_AARCH64)) {
|
||||
// use batch-sized 4 GEMM kernel
|
||||
for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue