update mul_mat condition
This commit is contained in:
parent
2a034d2b41
commit
abe11feab6
1 changed file with 2 additions and 18 deletions
|
@ -85,7 +85,6 @@ Following definition copied from DPCT head files, which are used by ggml-sycl.cp
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
bool ggml_sycl_loaded(void);
|
bool ggml_sycl_loaded(void);
|
||||||
bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
|
||||||
void ggml_sycl_free_data(struct ggml_tensor * tensor);
|
void ggml_sycl_free_data(struct ggml_tensor * tensor);
|
||||||
void ggml_sycl_assign_buffers(struct ggml_tensor * tensor);
|
void ggml_sycl_assign_buffers(struct ggml_tensor * tensor);
|
||||||
void ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor);
|
void ggml_sycl_assign_buffers_no_scratch(struct ggml_tensor * tensor);
|
||||||
|
@ -11375,21 +11374,6 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tenso
|
||||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
 * Reports whether a matrix multiplication with the given operands passes the
 * checks below.
 *
 * Returns false immediately when g_sycl_loaded is unset. Otherwise returns
 * true only if all of the following hold:
 *   - src0->type is GGML_TYPE_F32, GGML_TYPE_F16, or a type for which
 *     ggml_is_quantized() is true;
 *   - src1->type and dst->type are both GGML_TYPE_F32;
 *   - dst->ne[0], dst->ne[1] and src1->ne[0] are each at least 32.
 *
 * NOTE(review): in this diff view the function appears only on the old side,
 * i.e. it looks like the commit removes it -- confirm against the repository.
 */
bool ggml_sycl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
|
||||||
if (!g_sycl_loaded) return false; /* no further checks when the flag is unset */
|
|
||||||
|
|
||||||
const int64_t ne10 = src1->ne[0]; /* first dimension of src1 */
|
|
||||||
|
|
||||||
const int64_t ne0 = dst->ne[0]; /* first dimension of dst */
|
|
||||||
const int64_t ne1 = dst->ne[1]; /* second dimension of dst */
|
|
||||||
|
|
||||||
// TODO: find the optimal values for these
// (the "values" are the three hard-coded 32 thresholds in the return expression)
|
|
||||||
return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
|
||||||
src1->type == GGML_TYPE_F32 &&
|
|
||||||
dst->type == GGML_TYPE_F32 &&
|
|
||||||
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||||
const ggml_tensor *src1,
|
const ggml_tensor *src1,
|
||||||
ggml_tensor *dst) try {
|
ggml_tensor *dst) try {
|
||||||
|
@ -12195,13 +12179,13 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||||
func = ggml_sycl_rms_norm;
|
func = ggml_sycl_rms_norm;
|
||||||
break;
|
break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
if (ggml_sycl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
|
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
func = ggml_sycl_mul_mat;
|
func = ggml_sycl_mul_mat;
|
||||||
break;
|
break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
if (ggml_sycl_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) {
|
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
func = ggml_sycl_mul_mat_id;
|
func = ggml_sycl_mul_mat_id;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue