From de65783ba22f8a5423684b37aa56a3df0b324f0d Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Tue, 16 May 2023 09:26:03 +0200 Subject: [PATCH] Broadcasting for ggml_mul --- ggml.c | 50 +++++++++++++++++++++++++++++--------------------- llama.cpp | 18 +++++++++++++++--- 2 files changed, 44 insertions(+), 24 deletions(-) diff --git a/ggml.c b/ggml.c index 4311ce7cf..e937c0a33 100644 --- a/ggml.c +++ b/ggml.c @@ -4643,7 +4643,7 @@ struct ggml_tensor * ggml_mul_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(a->ne[0] == b->ne[0] && ggml_can_repeat(b, a)); bool is_node = false; @@ -7945,7 +7945,16 @@ static void ggml_compute_forward_mul_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + const int nr = ggml_nrows(src0); + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne00 == ne10 && ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -7953,11 +7962,6 @@ static void ggml_compute_forward_mul_f32( const int ith = params->ith; const int nth = params->nth; - const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - const size_t nb00 = src0->nb[0]; const size_t nb01 = src0->nb[1]; const size_t nb02 = src0->nb[2]; @@ -7976,12 +7980,12 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (nb10 == sizeof(float)) { + if (nb10 == sizeof(float) && ggml_are_same_shape(src0, src1)) { for (int ir = ith; ir < nr; ir += 
nth) { // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); #ifdef GGML_USE_ACCELERATE @@ -7991,9 +7995,9 @@ static void ggml_compute_forward_mul_f32( (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); + ne00); #else - ggml_vec_mul_f32(ne0, + ggml_vec_mul_f32(ne00, (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); @@ -8004,15 +8008,19 @@ static void ggml_compute_forward_mul_f32( } else { // src1 is not contiguous for (int ir = ith; ir < nr; ir += nth) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + const int i13 = i03 % ne13; + const int i12 = i02 % ne12; + const int i11 = i01 % ne11; - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + for (int i0 = 0; i0 < ne00; i0++) { + float 
* src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); } diff --git a/llama.cpp b/llama.cpp index 98f49abd7..b594789fb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1154,10 +1154,14 @@ static bool llama_eval_internal( { cur = ggml_rms_norm(ctx0, inpL); - // cur = attention_norm*cur + // cur = cur*attention_norm(broadcasted) +#ifdef GGML_USE_CUBLAS + cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm); +#else cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].attention_norm, cur), cur); +#endif } // self-attention @@ -1264,10 +1268,14 @@ static bool llama_eval_internal( { cur = ggml_rms_norm(ctx0, inpFF); - // cur = ffn_norm*cur + // cur = cur*ffn_norm(broadcasted) +#ifdef GGML_USE_CUBLAS + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); +#else cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ffn_norm, cur), cur); +#endif } struct ggml_tensor * tmp = ggml_mul_mat(ctx0, @@ -1304,10 +1312,14 @@ static bool llama_eval_internal( inpL = ggml_rms_norm(ctx0, inpL); - // inpL = norm*inpL + // inpL = inpL*norm(broadcasted) +#ifdef GGML_USE_CUBLAS + inpL = ggml_mul(ctx0, inpL, model.norm); +#else inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm, inpL), inpL); +#endif embeddings = inpL; }