revert logic

This commit is contained in:
abhilash1910 2024-03-26 19:39:48 -07:00
parent e9377baf7a
commit 69aaa3d78b

View file

@ -15356,16 +15356,6 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
}
}
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && (ggml_is_quantized(src0->type)
|| src0->type == GGML_TYPE_F16) && src1->type == GGML_TYPE_F32
&& dst->type == GGML_TYPE_F32;
bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& ggml_sycl_supports_mmq(src0->type);
#ifdef SYCL_USE_XMX
const bool use_xmx = true;
#else
@ -15380,11 +15370,6 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
#ifdef GGML_SYCL_FORCE_DMMV
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
#endif // GGML_SYCL_FORCE_DMMV
if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
// KQ single-batch
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
@ -15397,15 +15382,28 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
// KQ + KQV multi-batch
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
} else if (use_dequantize_mul_mat_vec){
// use ggml_sycl_op_dequantize_mul_mat_vec
//GGML_SYCL_DEBUG(""ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n"");
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
} else if (use_mul_mat_vec_q){
// use ggml_sycl_op_mul_mat_vec_q
}else if (src0->type == GGML_TYPE_F32) {
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
// GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
#ifdef GGML_SYCL_FORCE_DMMV
const bool use_mul_mat_vec_q = false;
#else
const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
#endif // GGML_SYCL_FORCE_DMMV
if (use_mul_mat_vec_q) {
// NOTE: this kernel does not support ggml_nrows(src1) > 1
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
} else if (use_mul_mat_q){
} else {
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
}
} else {
bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
use_mul_mat_q = false;
@ -15418,9 +15416,7 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
}
} else if (src0->type == GGML_TYPE_F32){
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
}
} else {
GGML_ASSERT(false);
}