update mul_mat condition

This commit is contained in:
Meng, Hengyu 2024-06-04 16:47:23 +08:00
parent 2a034d2b41
commit abe11feab6

View file

@ -85,7 +85,6 @@ Following definition copied from DPCT head files, which are used by ggml-sycl.cp
#endif #endif
bool ggml_sycl_loaded(void); bool ggml_sycl_loaded(void);
bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
void ggml_sycl_free_data(struct ggml_tensor * tensor); void ggml_sycl_free_data(struct ggml_tensor * tensor);
void ggml_sycl_assign_buffers(struct ggml_tensor * tensor); void ggml_sycl_assign_buffers(struct ggml_tensor * tensor);
void ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor); void ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor);
@ -11375,21 +11374,6 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tenso
GGML_SYCL_DEBUG("call %s done\n", __func__); GGML_SYCL_DEBUG("call %s done\n", __func__);
} }
bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
if (!g_sycl_loaded) return false;
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
// TODO: find the optimal values for these
return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
}
static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, const ggml_tensor *src1,
ggml_tensor *dst) try { ggml_tensor *dst) try {
@ -12195,13 +12179,13 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
func = ggml_sycl_rms_norm; func = ggml_sycl_rms_norm;
break; break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
if (ggml_sycl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) { if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
return false; return false;
} }
func = ggml_sycl_mul_mat; func = ggml_sycl_mul_mat;
break; break;
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
if (ggml_sycl_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) { if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
return false; return false;
} }
func = ggml_sycl_mul_mat_id; func = ggml_sycl_mul_mat_id;