From f2e8616d1f0dedf503b7c360905712de1d9eae7b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 7 Dec 2023 19:07:27 +0200 Subject: [PATCH] ggml : restore `ggml_get_n_tasks()` logic in `ggml_graph_plan()` --- CMakeLists.txt | 8 ++++---- ggml.c | 25 +++---------------------- 2 files changed, 7 insertions(+), 26 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index baa6233de..78de2dd1a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -662,11 +662,11 @@ add_library(ggml OBJECT ggml-backend.h ggml-quants.c ggml-quants.h - ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} + ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} - ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} - ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI} - ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA} + ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} + ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI} + ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA} ) target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES}) diff --git a/ggml.c b/ggml.c index 4a4a9d51c..ca56f063c 100644 --- a/ggml.c +++ b/ggml.c @@ -8459,6 +8459,7 @@ static void ggml_compute_forward_concat_f32( GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; + const int nth = params->nth; GGML_TENSOR_BINARY_OP_LOCALS @@ -16100,18 +16101,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { // thread scheduling for the different operations + work buffer size estimation for (int i = 0; i < cgraph->n_nodes; i++) { - int n_tasks = 1; - struct ggml_tensor * node = cgraph->nodes[i]; + const int n_tasks = ggml_get_n_tasks(node, n_threads); + size_t cur = 0; switch (node->op) { case GGML_OP_CPY: case GGML_OP_DUP: { - n_tasks = n_threads; - if (ggml_is_quantized(node->type)) { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } @@ -16119,16 +16118,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { case GGML_OP_ADD: case GGML_OP_ADD1: { - n_tasks = n_threads; - if (ggml_is_quantized(node->src[0]->type)) { cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; } } break; case GGML_OP_ACC: { - n_tasks = n_threads; - if (ggml_is_quantized(node->src[0]->type)) { cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; } @@ -16173,8 +16168,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { } break; case GGML_OP_OUT_PROD: { - n_tasks = n_threads; - if (ggml_is_quantized(node->src[0]->type)) { cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; } @@ -16208,10 +16201,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { GGML_ASSERT(false); } } break; - case GGML_OP_IM2COL: - { - n_tasks = n_threads; - } break; case GGML_OP_CONV_TRANSPOSE_2D: { const int64_t ne00 = node->src[0]->ne[0]; // W @@ -16228,8 +16217,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { } break; case GGML_OP_FLASH_ATTN: { - n_tasks = n_threads; - const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); if (node->src[1]->type == GGML_TYPE_F32) { @@ -16242,8 +16229,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { } break; case GGML_OP_FLASH_FF: { - n_tasks = n_threads; - if (node->src[1]->type == GGML_TYPE_F32) { cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 @@ -16254,8 +16239,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { } break; case GGML_OP_FLASH_ATTN_BACK: { - n_tasks = n_threads; - const int64_t D = node->src[0]->ne[0]; const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back @@ -16270,8 +16253,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { case GGML_OP_CROSS_ENTROPY_LOSS: { - n_tasks = n_threads; - cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); } break; case GGML_OP_COUNT: