diff --git a/Makefile b/Makefile
index b11efd961..d59111cef 100644
--- a/Makefile
+++ b/Makefile
@@ -874,9 +874,8 @@ ggml/src/ggml-cuda/%.o: \
 	$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
 endif # GGML_HIPBLAS
 
-ifdef GGML_RUNTIME_REPACK
-	MK_CPPFLAGS += -DGGML_USE_RUNTIME_REPACK
-	MK_CFLAGS   += -DGGML_USE_RUNTIME_REPACK
+ifndef GGML_NO_CPU_AARCH64
+	MK_CPPFLAGS += -DGGML_USE_CPU_AARCH64
 endif
 
 ifdef GGML_METAL
@@ -888,7 +887,7 @@ ifdef GGML_METAL_NDEBUG
 endif
 
 ifdef GGML_METAL_EMBED_LIBRARY
 	MK_CPPFLAGS += -DGGML_METAL_EMBED_LIBRARY
-	OBJ_GGML	+= ggml/src/ggml-metal-embed.o
+	OBJ_GGML    += ggml/src/ggml-metal-embed.o
 endif
 endif # GGML_METAL
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 6732f1880..8977d9197 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -92,7 +92,7 @@ else()
 endif()
 
 option(GGML_CPU_HBM     "ggml: use memkind for CPU HBM" OFF)
-option(GGML_RUNTIME_REPACK "ggml: use runtime weight quantization to enable optimized GEMM/GEMV kernels for AARCH64 cpu" OFF)
+option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
 
 option(GGML_AVX         "ggml: enable AVX" ${INS_ENB})
 option(GGML_AVX2        "ggml: enable AVX2" ${INS_ENB})
diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h
index 50d9cfd59..97b94f8bb 100644
--- a/ggml/include/ggml-cpu.h
+++ b/ggml/include/ggml-cpu.h
@@ -145,9 +145,10 @@ extern "C" {
     GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
 #endif
 
-#ifdef GGML_USE_RUNTIME_REPACK
     GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
-#endif
+    GGML_API bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft);
+
+
 #ifdef __cplusplus
 }
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index b7aa6de40..471ac4cf6 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -880,10 +880,10 @@ if (GGML_CPU_HBM)
     target_link_libraries(ggml PUBLIC memkind)
 endif()
 
-if (GGML_RUNTIME_REPACK)
-    message(STATUS "Using runtime weight quantization to enable optimized GEMM/GEMV kernels for AARCH64 cpu")
+if (GGML_CPU_AARCH64)
+    message(STATUS "Using runtime weight conversion of Q4_0 to Q4_0_x_x to enable optimized GEMM/GEMV kernels")
 
-    add_compile_definitions(GGML_USE_RUNTIME_REPACK)
+    add_compile_definitions(GGML_USE_CPU_AARCH64)
 endif()
 
 if (GGML_CANN)
diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c
index 78ba8a0a4..a0a7c5ca3 100644
--- a/ggml/src/ggml-aarch64.c
+++ b/ggml/src/ggml-aarch64.c
@@ -3477,13 +3477,12 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
-#ifdef GGML_USE_RUNTIME_REPACK
-static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * data, size_t data_size) {
+static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
     GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
 
-    block_q4_0x4 *dst = (block_q4_0x4 *)t->data;
-    const block_q4_0 *src = (const block_q4_0 *)data;
+    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
+    const block_q4_0 * src = (const block_q4_0 *)data;
     block_q4_0 dst_tmp[4];
     int nrow = t->ne[1]; // Number of rows
     int nrows_interleaved = 4;
@@ -3509,12 +3508,12 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
     GGML_UNUSED(data_size);
 }
 
-static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, const void * data, size_t data_size) {
+static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
     GGML_ASSERT(interleave_block == 8);
 
-    block_q4_0x8 *dst = (block_q4_0x8*)t->data;
-    const block_q4_0 *src = (const block_q4_0*) data;
+    block_q4_0x8 * dst = (block_q4_0x8*)t->data;
+    const block_q4_0 * src = (const block_q4_0*) data;
     block_q4_0 dst_tmp[8];
     int nrow = t->ne[1]; // Number of rows
     int nrows_interleaved = 8;
@@ -3541,42 +3540,47 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block,
 }
 
 // Prepare for optimized kernels if applicable
-int ggml_prepare_optimal_kernel(struct ggml_tensor * cur, const void * data, size_t data_size) {
-    GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);
-#if defined(__ARM_ARCH)
-    if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
-        return repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
-    }
-    else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
-        return repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
-    }
-    else if (ggml_cpu_has_neon()) {
-        return repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
-    }
-#endif
-    return -1;
+void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
+    int ret = -1;
 
-    GGML_UNUSED(cur);
-    GGML_UNUSED(data);
-    GGML_UNUSED(data_size);
+    if (cur->type == repack_type) {
+        memcpy(cur->data, data, data_size);
+        return;
+    }
+
+    GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);
+
+    switch (repack_type) {
+        case GGML_TYPE_Q4_0_8_8:
+            ret = repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
+            break;
+        case GGML_TYPE_Q4_0_4_8:
+            ret = repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
+            break;
+        case GGML_TYPE_Q4_0_4_4:
+            ret = repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
+            break;
+        default:
+            GGML_ABORT("Unsupported type");
+    }
+    if (ret == -1) {
+        memcpy(cur->data, data, data_size);
+    }
 }
 
-enum ggml_type ggml_get_optimal_type(const struct ggml_tensor * cur) {
-#if defined(__ARM_ARCH)
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
     if (cur->type == GGML_TYPE_Q4_0) {
-        if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
+        // TODO: enable for AVX2 - currently disabled due to bad gemv performance
+        if (/* ggml_cpu_has_avx2() || */ (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
             return GGML_TYPE_Q4_0_8_8;
         }
-        else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
             return GGML_TYPE_Q4_0_4_8;
         }
-        else if (ggml_cpu_has_neon()) {
+        if (ggml_cpu_has_neon()) {
             return GGML_TYPE_Q4_0_4_4;
         }
     }
-#endif
-    return cur->type;
 
-    GGML_UNUSED(cur);
+    return cur->type;
 }
-#endif
diff --git a/ggml/src/ggml-aarch64.h b/ggml/src/ggml-aarch64.h
index 74eddb060..cf3d4771c 100644
--- a/ggml/src/ggml-aarch64.h
+++ b/ggml/src/ggml-aarch64.h
@@ -33,10 +33,8 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
-#ifdef GGML_USE_RUNTIME_REPACK
-int ggml_prepare_optimal_kernel(struct ggml_tensor * cur, const void * data, size_t data_size);
-enum ggml_type ggml_get_optimal_type(const struct ggml_tensor * cur);
-#endif
+void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
+enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);
 
 #ifdef __cplusplus
 }
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 5f8cb543c..573175bf2 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -2239,24 +2239,32 @@ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
 }
 #endif
 
-#ifdef GGML_USE_RUNTIME_REPACK
-// buffer type AARCH64
+#ifdef __GNUC__
+    #pragma GCC diagnostic push
+    #pragma GCC diagnostic ignored "-Wpedantic"
+#endif
+
 #include "ggml-aarch64.h"
 
-static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    bool quantize = tensor->type == GGML_TYPE_Q4_0 &&
-                    tensor->op == GGML_OP_NONE &&
-                    strcmp(tensor->name, "token_embd.weight") != 0;
+#ifdef __GNUC__
+    #pragma GCC diagnostic pop
+#endif
 
-    if (quantize) {
-        GGML_ASSERT(offset == 0);
-        if (ggml_prepare_optimal_kernel(tensor, data, size) == 0) {
-            return;
-        }
-    }
-    memcpy((char *)tensor->data + offset, data, size);
+static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *)ggml_aarch64_get_optimal_repack_type(tensor); // NOLINT
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    enum ggml_type repack_type = (enum ggml_type)(intptr_t)tensor->extra;
+
+    ggml_aarch64_repack_tensor(tensor, repack_type, data, size);
 
     GGML_UNUSED(buffer);
 }
@@ -2264,11 +2272,11 @@ static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buf
 static const struct ggml_backend_buffer_i ggml_backend_cpu_aarch64_buffer_i = {
     /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
-    /* .init_tensor     = */ NULL, // no initialization required
+    /* .init_tensor     = */ ggml_backend_cpu_aarch64_buffer_init_tensor,
     /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_cpu_aarch64_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,
+    /* .get_tensor      = */ NULL,
+    /* .cpy_tensor      = */ NULL,
     /* .clear           = */ ggml_backend_cpu_buffer_clear,
     /* .reset           = */ NULL,
 };
@@ -2298,33 +2306,37 @@ ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
             /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
             /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
             /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
+            /* .is_host          = */ NULL,
         },
         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
-        /* .context = */ NULL,
+        /* .context = */ NULL,
     };
 
     return &ggml_backend_cpu_buffer_type_aarch64;
 }
-#endif
+
+bool ggml_backend_cpu_buft_is_aarch64(ggml_backend_buffer_type_t buft) {
+    return buft == ggml_backend_cpu_aarch64_buffer_type();
+}
 
 static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
-    static ggml_backend_buffer_type_t bufts[3];
-    int index = 0;
+    static std::vector<ggml_backend_buffer_type_t> bufts = []() {
+        std::vector<ggml_backend_buffer_type_t> bufts;
 
 #ifdef GGML_USE_CPU_HBM
-    bufts[index++] = ggml_backend_cpu_hbm_buffer_type();
+        bufts.push_back(ggml_backend_cpu_hbm_buffer_type());
 #endif
 
-#ifdef GGML_USE_RUNTIME_REPACK
-    if (ggml_cpu_has_neon() || ggml_cpu_has_matmul_int8() || ggml_cpu_has_sve()) {
-        bufts[index++] = ggml_backend_cpu_aarch64_buffer_type();
-    }
+#ifdef GGML_USE_CPU_AARCH64
+        bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
 #endif
 
-    bufts[index] = NULL; // Terminate the list
+        bufts.push_back(NULL);
 
-    return bufts;
+        return bufts;
+    }();
+
+    return bufts.data();
 
     GGML_UNUSED(device);
 }
@@ -2635,15 +2647,21 @@ static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_b
 }
 
 static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-#ifdef GGML_USE_RUNTIME_REPACK
-    const struct ggml_tensor *tensor = op->src[0];
-    if (tensor && tensor->buffer && (strcmp(tensor->buffer->buft->iface.get_name(tensor->buffer->buft),"CPU_AARCH64") == 0)) {
-        if (op->op == GGML_OP_MUL_MAT && tensor->type == GGML_TYPE_Q4_0) {
-            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_get_type_traits_cpu(tensor->type)->vec_dot_type;
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    if (src0 && src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
+        if (op->op != GGML_OP_MUL_MAT || src0->type != GGML_TYPE_Q4_0 || ggml_aarch64_get_optimal_repack_type(src0) == GGML_TYPE_Q4_0) {
+            return false;
         }
-        return false;
     }
-#endif
+
+    for (int i = 1; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && op->src[i]->buffer && ggml_backend_cpu_buft_is_aarch64(op->src[i]->buffer->buft)) {
+            return false;
+        }
+    }
+
     switch (op->op) {
         case GGML_OP_CPY:
             return
@@ -2652,13 +2670,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_S &&
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
-            return op->src[1]->type == GGML_TYPE_F32;// FIXME || op->src[1]->type == ggml_get_type_traits(op->src[0]->type)->vec_dot_type;
+            return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
         case GGML_OP_ROPE_BACK:
            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
        case GGML_OP_IM2COL_BACK:
-            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+            return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
        case GGML_OP_OUT_PROD:
-            return (op->src[0]->type == GGML_TYPE_F32 || ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
        default:
            return true;
     }
@@ -2667,7 +2685,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
 }
 
 static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
+    return ggml_backend_buft_is_host(buft) || ggml_backend_cpu_buft_is_aarch64(buft);
 
     GGML_UNUSED(dev);
 }
@@ -2721,7 +2739,7 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch
     if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
         return (void *)ggml_backend_cpu_set_n_threads;
     }
-    if (strcmp(name, "ggml_backend_cpu_get_extra_bufts") == 0) {
+    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
         return (void *)ggml_backend_cpu_get_extra_bufts;
     }
@@ -2738,6 +2756,9 @@ static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
 };
 
 ggml_backend_reg_t ggml_backend_cpu_reg(void) {
+    // init CPU feature detection
+    ggml_cpu_init();
+
     static struct ggml_backend_reg ggml_backend_cpu_reg = {
         /* .iface   = */ ggml_backend_cpu_reg_i,
         /* .context = */ NULL,
diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c
index 40ce0e5c4..3d12c941f 100644
--- a/ggml/src/ggml-cpu.c
+++ b/ggml/src/ggml-cpu.c
@@ -7325,6 +7325,7 @@ static void ggml_compute_forward_group_norm(
 static void ggml_compute_forward_mul_mat_one_chunk(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst,
+    const enum ggml_type type,
     const int64_t num_rows_per_vec_dot,
     const int64_t ir0_start,
     const int64_t ir0_end,
@@ -7336,8 +7337,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const enum ggml_type type = src0->type;
-
     const bool src1_cont = ggml_is_contiguous(src1);
 
     ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
@@ -7427,11 +7426,9 @@ static void ggml_compute_forward_mul_mat(
 
     enum ggml_type type = src0->type;
 
-#ifdef GGML_USE_RUNTIME_REPACK
-    if (strcmp(src0->buffer->buft->iface.get_name(src0->buffer->buft),"CPU_AARCH64") == 0) {
-        type = ggml_get_optimal_type(src0);
+    if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
+        type = (enum ggml_type)(intptr_t)src0->extra;
     }
-#endif
 
     enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
     ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float;
@@ -7470,15 +7467,15 @@ static void ggml_compute_forward_mul_mat(
     if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
-                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
+                                     (const char *)data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
                                      (const char *)src1->data + i12*nb12 + i13*nb13,
                                      nb11/ggml_type_size(src1->type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     src0->type,
+                                     type,
                                      src1->type,
                                      dst->type))
                     goto UseGgmlGemm1;
@@ -7531,15 +7528,15 @@ UseGgmlGemm1:;
 
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
-                                     nb01/ggml_type_size(src0->type),
+                                     nb01/ggml_type_size(type),
                                      (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
                                      ith, nth,
-                                     src0->type,
+                                     type,
                                      vec_dot_type,
                                      dst->type))
                     goto UseGgmlGemm2;
@@ -7624,7 +7621,7 @@ UseGgmlGemm2:;
         const int64_t ir1_start = dr1 * ith1;
         const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
 
-        ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+        ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
         if (nth >= nchunk0 * nchunk1) {
             break;
diff --git a/src/llama.cpp b/src/llama.cpp
index 034441e1f..08e8d84da 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7189,7 +7189,7 @@ static llama_model::buft_list_t make_cpu_buft_list(llama_model & model) {
     auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
     auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
     auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
-        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_cpu_get_extra_bufts");
+        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
     if (ggml_backend_dev_get_extra_bufts_fn) {
         ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
         while (extra_bufts && *extra_bufts) {
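
Usage sketch (illustrative, not part of the patch): after this change a consumer discovers the repacked AArch64 buffer type through the renamed "ggml_backend_dev_get_extra_bufts" proc address, mirroring the make_cpu_buft_list() hunk above. The snippet only uses symbols that appear in this diff plus the existing ggml_backend_buft_name() helper; the standalone main() and the printed output are assumptions for demonstration.

// Illustrative only -- not part of the patch.
// Enumerates the CPU backend's extra buffer types and flags the AArch64
// repack type via ggml_backend_cpu_buft_is_aarch64().
#include <cstdio>

#include "ggml-backend.h"
#include "ggml-cpu.h"

int main() {
    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    ggml_backend_reg_t cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);

    // same proc-address lookup that make_cpu_buft_list() performs after this change
    auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");

    if (get_extra_bufts_fn) {
        // the returned list is NULL-terminated, matching bufts.push_back(NULL) above
        ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
        while (extra_bufts && *extra_bufts) {
            printf("extra buft: %s (aarch64 repack: %s)\n",
                   ggml_backend_buft_name(*extra_bufts),
                   ggml_backend_cpu_buft_is_aarch64(*extra_bufts) ? "yes" : "no");
            extra_bufts++;
        }
    }
    return 0;
}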