Refactor kernel-selection logic in ggml_sycl_mul_mat
This commit is contained in:
parent
32589a642f
commit
a553def52e
1 changed file with 56 additions and 33 deletions
|
@ -4694,6 +4694,24 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
|||
#endif
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
|
@ -14082,6 +14100,7 @@ inline void ggml_sycl_op_mul_mat_vec_q(
|
|||
const dpct::queue_ptr &stream) {
|
||||
|
||||
//GGML_ASSERT(ggml_nrows(src1) == 1);
|
||||
//GGML_ASSERT(ne10 % QK8_1 == 0);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
@ -15594,7 +15613,18 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||
min_compute_capability = g_device_caps[i].cc;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
|
||||
bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src1->ne[1] <= XMX_MAX_BATCH_SIZE;
|
||||
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) && src1->type == GGML_TYPE_F32
|
||||
&& dst->type == GGML_TYPE_F32;
|
||||
|
||||
|
||||
#ifdef SYCL_USE_XMX
|
||||
const bool use_xmx = true;
|
||||
#else
|
||||
|
@ -15609,6 +15639,11 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
#ifdef GGML_SYCL_FORCE_DMMV
|
||||
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
||||
#endif // GGML_SYCL_FORCE_DMMV
|
||||
|
||||
|
||||
if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// KQ single-batch
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
|
||||
|
@ -15621,43 +15656,31 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||
// KQ + KQV multi-batch
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
|
||||
ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F32) {
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
|
||||
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
||||
// GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
|
||||
if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
|
||||
#ifdef GGML_SYCL_FORCE_DMMV
|
||||
const bool use_mul_mat_vec_q = false;
|
||||
#else
|
||||
const bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
|
||||
#endif // GGML_SYCL_FORCE_DMMV
|
||||
|
||||
if (use_mul_mat_vec_q) {
|
||||
// NOTE: this kernel does not support ggml_nrows(src1) > 1
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
|
||||
} else {
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
|
||||
}
|
||||
} else {
|
||||
bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
|
||||
|
||||
if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
|
||||
} else if (use_dequantize_mul_mat_vec){
|
||||
// use ggml_sycl_op_dequantize_mul_mat_vec
|
||||
//GGML_SYCL_DEBUG(""ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n"");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
|
||||
} else if (use_mul_mat_vec_q){
|
||||
// use ggml_sycl_op_mul_mat_vec_q
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
|
||||
|
||||
} else if (use_mul_mat_q){
|
||||
|
||||
if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
|
||||
use_mul_mat_q = false;
|
||||
}
|
||||
|
||||
if (use_mul_mat_q) {
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
|
||||
} else {
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
|
||||
}
|
||||
if (use_mul_mat_q) {
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
|
||||
} else {
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
|
||||
ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue