From 6609c229e84e7cb749a5a3902f0123033c06c523 Mon Sep 17 00:00:00 2001 From: mqy Date: Mon, 19 Jun 2023 01:05:34 +0800 Subject: [PATCH] fixed OP_OUT_PROD and OP_NONE --- ggml.c | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/ggml.c b/ggml.c index 5a9e0b33e..75a562481 100644 --- a/ggml.c +++ b/ggml.c @@ -15655,8 +15655,7 @@ int ggml_get_task_profiles( p[0].stages[1].valid = true; p[0].stages[1].parallel = true; } break; - case GGML_OP_MUL_MAT: - case GGML_OP_OUT_PROD: { // FIXME: is this correct? + case GGML_OP_MUL_MAT: { enum ggml_type src0_t = tensor->src0->type; if (src0_t == GGML_TYPE_F32) { p[0].stages[1].valid = true; @@ -15673,6 +15672,15 @@ int ggml_get_task_profiles( GGML_ASSERT(false); } } break; + case GGML_OP_OUT_PROD: { + enum ggml_type src0_t = tensor->src0->type; + if (src0_t == GGML_TYPE_F32) { + p[0].stages[1].valid = true; + p[0].stages[1].parallel = true; + } else { + GGML_ASSERT(false); + } + } break; case GGML_OP_SCALE: { p[0].stages[1].valid = true; p[0].stages[1].parallel = true; @@ -15810,13 +15818,12 @@ static void ggml_optimize_tensor_task_profile( struct ggml_tensor *tensor, struct ggml_task_profile *profiles, int n_profiles, struct ggml_mulmat_tune *tune) { - if (tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_OUT_PROD) { + if (tensor->op != GGML_OP_MUL_MAT) { return; } GGML_ASSERT(tensor); - GGML_ASSERT(tensor->op == GGML_OP_MUL_MAT || - tensor->op == GGML_OP_OUT_PROD); + GGML_ASSERT(tensor->op == GGML_OP_MUL_MAT); GGML_ASSERT(tensor->task_profile.id == n_profiles); GGML_ASSERT(profiles); @@ -15949,7 +15956,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; - GGML_ASSERT (node->op != GGML_OP_NONE); + if (node->op == GGML_OP_NONE) { + continue; + } if (node->task_profile.id == 0) { ggml_set_tensor_task_profile(node, cgraph->tune); @@ -16031,7 +16040,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { } break; case GGML_OP_MUL_MAT: - case GGML_OP_OUT_PROD: // FIXME: is this correct? + case GGML_OP_OUT_PROD: { size_t cur = 0; GGML_ASSERT(node->src1->type == GGML_TYPE_F32);