diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp
index 7f642a11f..5af320b48 100644
--- a/ggml-opencl.cpp
+++ b/ggml-opencl.cpp
@@ -1783,13 +1783,22 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
 
 
 bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
     const int64_t ne10 = src1->ne[0];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
+    // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+    // ref: https://github.com/ggerganov/ggml/pull/224
+
     // TODO: find the optimal values for these
-    if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+    if (ne02 == ne12 && ne03 == ne13 &&
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
         ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) {
diff --git a/ggml.c b/ggml.c
index f50a1202c..b8fb8ac00 100644
--- a/ggml.c
+++ b/ggml.c
@@ -10423,11 +10423,6 @@ static void ggml_compute_forward_mul_mat(
 
 #if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
-        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
-        // ref: https://github.com/ggerganov/ggml/pull/224
-        GGML_ASSERT(ne02 == ne12);
-        GGML_ASSERT(ne03 == ne13);
-
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
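
For context: the patch replaces a hard failure with a capability check. Previously, ggml_compute_forward_mul_mat() would hit GGML_ASSERT(ne02 == ne12) / GGML_ASSERT(ne03 == ne13) whenever the CLBlast path was selected for a mul_mat whose src0 would need to be broadcast across the 2nd/3rd dimensions. With the check moved into ggml_cl_can_mul_mat(), such shapes now simply fail the "can offload" test and fall through to the regular CPU implementation. The sketch below illustrates that dispatch pattern in isolation; the toy_* types and functions are simplified stand-ins for illustration only, not the real ggml API.

// Minimal sketch of the dispatch pattern above: a backend "can do this?"
// predicate gates the offload, and unsupported shapes fall back to the CPU
// path instead of hitting an assert. All names here are hypothetical.
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

typedef struct {
    int64_t ne[4];  // number of elements per dimension, as in ggml
} toy_tensor;

// Mirrors the new check in ggml_cl_can_mul_mat: only offload when the
// 2nd/3rd (batch) dimensions match exactly, since broadcasting src0 across
// src1 is not handled by the OpenCL path yet.
static bool toy_can_offload(const toy_tensor * src0, const toy_tensor * src1) {
    return src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3];
}

static void toy_mul_mat_gpu(const toy_tensor * src0, const toy_tensor * src1) {
    (void) src0; (void) src1;
    printf("offloaded to GPU path\n");
}

static void toy_mul_mat_cpu(const toy_tensor * src0, const toy_tensor * src1) {
    (void) src0; (void) src1;
    printf("handled by CPU path (supports broadcasting)\n");
}

static void toy_mul_mat(const toy_tensor * src0, const toy_tensor * src1) {
    if (toy_can_offload(src0, src1)) {
        toy_mul_mat_gpu(src0, src1);
        return;
    }
    // previously the equivalent of this case tripped GGML_ASSERT(ne02 == ne12)
    // and GGML_ASSERT(ne03 == ne13); now it just takes the generic path
    toy_mul_mat_cpu(src0, src1);
}

int main(void) {
    toy_tensor a = { .ne = { 64, 64,  1, 1 } };  // src0 with a single "batch"
    toy_tensor b = { .ne = { 64, 64, 32, 1 } };  // src1 with 32 batches (broadcast case)

    toy_mul_mat(&a, &a);  // matching dims  -> GPU path
    toy_mul_mat(&a, &b);  // broadcast case -> CPU fallback, no assert
    return 0;
}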