CLBlast: Add broadcast support for matrix multiplication (#3402)

Broadcast src0 into src1 across dimensions 2 and 3 when needed.
This is required for models that use grouped-query attention (GQA).
This commit is contained in:
shibe2 2023-10-02 23:26:15 +04:00 committed by GitHub
parent 29a404a951
commit 665018c749
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 28 deletions

5
ggml.c
View file

@@ -11621,11 +11621,6 @@ static void ggml_compute_forward_mul_mat(
#if defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
// ref: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}