From 3f7233b62e04056ff8d59e8f6dc816b292ec3bf0 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Mon, 29 Jan 2024 13:33:27 -0500
Subject: [PATCH] ggml : parallelize ggml_exp

This results in 8% faster token generation for Mamba-130M.
---
 ggml.c | 33 +++++++++++++++++++--------------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/ggml.c b/ggml.c
index a5b337d96..8f351d823 100644
--- a/ggml.c
+++ b/ggml.c
@@ -8671,27 +8671,32 @@ static void ggml_compute_forward_exp_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
     }
 
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-    GGML_TENSOR_UNARY_OP_LOCALS
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
 
-    for (int64_t i3 = 0; i3 < ne03; i3++) {
-        for (int64_t i2 = 0; i2 < ne02; i2++) {
-            for (int64_t i1 = 0; i1 < ne01; i1++) {
-                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
-                float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
-                ggml_vec_exp_f32(ne00, dst_row, src_row);
-            }
-        }
-    }
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *)  dst->data + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+    }
 }
 
 static void ggml_compute_forward_exp(
@@ -17413,13 +17418,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
+        case GGML_OP_EXP:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_SUB:
         case GGML_OP_SQR:
        case GGML_OP_SQRT:
-        case GGML_OP_EXP:
         case GGML_OP_LOG:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
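
Below is a minimal standalone sketch of the row-partitioning scheme the
patch introduces: each of nth threads takes a contiguous chunk of
ceil(nr/nth) rows, with the upper bound clamped so the last thread never
runs past nr. The names nr, nth, dr, ir0, and ir1 mirror the patch; the
main() driver and the printed output are illustrative assumptions, not
part of ggml.

    // sketch: even row split across threads, as in the patch above
    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr  = 10; // total rows (ggml_nrows(src0) in the patch)
        const int nth = 4;  // number of threads (params->nth)

        // rows per thread, rounded up so every row is covered
        const int dr = (nr + nth - 1)/nth;

        for (int ith = 0; ith < nth; ith++) {
            // half-open row range [ir0, ir1) for thread ith;
            // MIN clamps the last chunk when nr is not a multiple of nth
            const int ir0 = dr*ith;
            const int ir1 = MIN(ir0 + dr, nr);
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }

With nr = 10 and nth = 4 this prints [0, 3), [3, 6), [6, 9), and
[9, 10): the chunks are disjoint and cover every row, which is what lets
each thread call ggml_vec_exp_f32 on its own rows without any
synchronization.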