CLBlast: Add broadcast support for matrix multiplication (#3402)
Broadcast src0 into src1 across dimensions 2 and 3 when needed. This is required for models that use GQA.
This commit is contained in:
parent
29a404a951
commit
665018c749
2 changed files with 67 additions and 28 deletions
5
ggml.c
5
ggml.c
|
@ -11621,11 +11621,6 @@ static void ggml_compute_forward_mul_mat(
|
|||
|
||||
#if defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
||||
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
|
||||
// ref: https://github.com/ggerganov/ggml/pull/224
|
||||
GGML_ASSERT(ne02 == ne12);
|
||||
GGML_ASSERT(ne03 == ne13);
|
||||
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue