revert logic
This commit is contained in:
parent e9377baf7a
commit 69aaa3d78b
1 changed file with 33 additions and 37 deletions
@@ -15356,16 +15356,6 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }
 
-    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
-    bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && (ggml_is_quantized(src0->type)
-        || src0->type == GGML_TYPE_F16) && src1->type == GGML_TYPE_F32
-        && dst->type == GGML_TYPE_F32;
-    bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && ggml_sycl_supports_mmq(src0->type);
-
 #ifdef SYCL_USE_XMX
     const bool use_xmx = true;
 #else
@@ -15380,11 +15370,6 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-#ifdef GGML_SYCL_FORCE_DMMV
-    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-#endif // GGML_SYCL_FORCE_DMMV
-
-
     if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
         // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
@@ -15397,15 +15382,28 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         // KQ + KQV multi-batch
         // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
         ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
-    } else if (use_dequantize_mul_mat_vec){
-        // use ggml_sycl_op_dequantize_mul_mat_vec
-        //GGML_SYCL_DEBUG(""ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n"");
-        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-    } else if (use_mul_mat_vec_q){
-        // use ggml_sycl_op_mul_mat_vec_q
+    }else if (src0->type == GGML_TYPE_F32) {
+        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
+        // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
+        if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
+#ifdef GGML_SYCL_FORCE_DMMV
+            const bool use_mul_mat_vec_q = false;
+#else
+            const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
+#endif // GGML_SYCL_FORCE_DMMV
+
+            if (use_mul_mat_vec_q) {
+                // NOTE: this kernel does not support ggml_nrows(src1) > 1
                 // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
                 ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-    } else if (use_mul_mat_q){
+            } else {
+                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
+                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+            }
+        } else {
+            bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
 
             if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
                 use_mul_mat_q = false;
@@ -15418,9 +15416,7 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
                 // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
                 ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
             }
-    } else if (src0->type == GGML_TYPE_F32){
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
-        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+        }
     } else {
        GGML_ASSERT(false);
     }
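For readers skimming the hunks above, here is a minimal standalone sketch of the dispatch order that this commit restores in ggml_sycl_mul_mat. Everything in it is illustrative only: MulMatInput, choose_mul_mat_path and the *_SKETCH constants are stand-ins invented for this sketch (not ggml types or real values), and branches that fall outside the visible hunks (the other F16 special cases and the mul_mat_q call site) are only simplified or noted in comments.

// Illustrative sketch only -- stand-in types and assumed constant values, not ggml code.
#include <cstdint>
#include <iostream>
#include <string>

struct MulMatInput {                       // stand-in for the ggml_tensor fields used by the dispatch
    bool f32 = false, f16 = false, quantized = false;
    bool permuted = false;
    int64_t ne0 = 0, ne1 = 0;              // ne[0] (row length), ne[1] (number of src1 columns)
};

// Assumed placeholder values; the real constants live in the SYCL backend source.
constexpr int64_t DMMV_X_SKETCH        = 32;
constexpr int64_t XMX_MAX_BATCH_SKETCH = 32;
constexpr int     VER_4VEC_SKETCH      = 1;
constexpr int     VER_GEN9_SKETCH      = 2;

std::string choose_mul_mat_path(const MulMatInput & src0, const MulMatInput & src1,
                                bool split, bool all_on_device, bool use_xmx,
                                int min_compute_capability) {
    if (!split && all_on_device && !use_xmx && src0.f16 &&
        src0.permuted && src1.permuted && src1.ne1 == 1) {
        return "ggml_sycl_mul_mat_vec_p021";            // KQ single-batch
    }
    // ... further F16 special cases sit between the hunks and are elided here ...
    if (!split && all_on_device && src0.f16) {
        return "ggml_sycl_mul_mat_batched_sycl";        // KQ + KQV multi-batch (condition simplified)
    }
    if (src0.f32) {
        return "ggml_sycl_op_mul_mat_sycl";             // plain F32 matmul
    }
    if (src0.quantized || src0.f16) {
        if (src1.ne1 == 1 && src0.ne0 % DMMV_X_SKETCH == 0) {
            // GGML_SYCL_FORCE_DMMV would force use_mul_mat_vec_q to false here.
            const bool use_mul_mat_vec_q =
                min_compute_capability >= VER_4VEC_SKETCH && src0.quantized;
            return use_mul_mat_vec_q ? "ggml_sycl_op_mul_mat_vec_q"
                                     : "ggml_sycl_op_dequantize_mul_mat_vec";
        }
        bool use_mul_mat_q = min_compute_capability >= VER_4VEC_SKETCH && src0.quantized;
        if (use_xmx && min_compute_capability >= VER_GEN9_SKETCH && src1.ne1 > XMX_MAX_BATCH_SKETCH) {
            use_mul_mat_q = false;                      // large batches fall through to the generic GEMM path
        }
        return use_mul_mat_q ? "mul_mat_q kernel (call site elided in this diff)"
                             : "ggml_sycl_op_mul_mat_sycl";
    }
    return "unsupported (GGML_ASSERT(false))";
}

int main() {
    MulMatInput src0; src0.quantized = true; src0.ne0 = 4096;   // e.g. a quantized weight matrix
    MulMatInput src1; src1.f32 = true; src1.ne1 = 1;            // single-column activation
    std::cout << choose_mul_mat_path(src0, src1, /*split=*/false, /*all_on_device=*/true,
                                     /*use_xmx=*/true, /*min_compute_capability=*/2) << "\n";
    // Prints "ggml_sycl_op_mul_mat_vec_q" under these assumed placeholder values.
}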