From 69aaa3d78be65b7a924122a5fe2ddaf8ad626157 Mon Sep 17 00:00:00 2001 From: abhilash1910 Date: Tue, 26 Mar 2024 19:39:48 -0700 Subject: [PATCH] revert logic --- ggml-sycl.cpp | 70 ++++++++++++++++++++++++--------------------------- 1 file changed, 33 insertions(+), 37 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index fd9250ade..061f99b27 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -15356,16 +15356,6 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } - bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; - bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && (ggml_is_quantized(src0->type) - || src0->type == GGML_TYPE_F16) && src1->type == GGML_TYPE_F32 - && dst->type == GGML_TYPE_F32; - bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && ggml_sycl_supports_mmq(src0->type); - #ifdef SYCL_USE_XMX const bool use_xmx = true; #else @@ -15380,11 +15370,6 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); -#ifdef GGML_SYCL_FORCE_DMMV - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; -#endif // GGML_SYCL_FORCE_DMMV - - if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // KQ single-batch // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n"); @@ -15397,30 +15382,41 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 // KQ + KQV multi-batch // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n"); ggml_sycl_mul_mat_batched_sycl(src0, src1, dst); - } else if (use_dequantize_mul_mat_vec){ - // use ggml_sycl_op_dequantize_mul_mat_vec - //GGML_SYCL_DEBUG(""ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n""); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); - } else if (use_mul_mat_vec_q){ - // use ggml_sycl_op_mul_mat_vec_q - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); - } else if (use_mul_mat_q){ - - if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) { - use_mul_mat_q = false; - } - - if (use_mul_mat_q) { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true); - } else { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); - } - } else if (src0->type == GGML_TYPE_F32){ + }else if (src0->type == GGML_TYPE_F32) { // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n"); ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); + } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { + // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n"); + if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) { +#ifdef GGML_SYCL_FORCE_DMMV + const bool use_mul_mat_vec_q = false; +#else + const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type); +#endif // GGML_SYCL_FORCE_DMMV + + if (use_mul_mat_vec_q) { + // NOTE: this kernel does not support ggml_nrows(src1) > 1 + // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n"); + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); + } else { + // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n"); + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); + } + } else { + bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type); + + if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) { + use_mul_mat_q = false; + } + + if (use_mul_mat_q) { + // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n"); + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true); + } else { + // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n"); + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); + } + } } else { GGML_ASSERT(false); }