From e098171aa772c83e7b6d5ef5721e0761697f6b94 Mon Sep 17 00:00:00 2001
From: Kunnis
Date: Wed, 8 May 2024 23:10:50 -0500
Subject: [PATCH] Passing around the state

---
 ggml.c | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/ggml.c b/ggml.c
index 8a5149c1b..06d2c1bf6 100644
--- a/ggml.c
+++ b/ggml.c
@@ -11771,7 +11771,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
 
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              struct ggml_compute_state* state) {
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -17481,7 +17482,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 
 /////////////////////////////////
 
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
     GGML_ASSERT(params);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -17579,7 +17580,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor);
+                ggml_compute_forward_mul_mat(params, tensor, state);
             } break;
         case GGML_OP_MUL_MAT_ID:
            {
@@ -19639,7 +19640,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                struct ggml_tensor * node = cgraph->nodes[node_n];
                if (GGML_OP_HAS_FINALIZE[node->op]) {
                    params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-                   ggml_compute_forward(&params, node);
+                   ggml_compute_forward(&params, node, state);
                }
                ggml_graph_compute_perf_stats_node(node, state->shared);
            }
@@ -19659,17 +19660,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                    /* INIT */
                    if (GGML_OP_HAS_INIT[node->op]) {
                        params.type = GGML_TASK_TYPE_INIT;
-                       ggml_compute_forward(&params, node);
+                       ggml_compute_forward(&params, node, state);
                    }
 
                    // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
                    // they do something more efficient than spinning (?)
                    params.type = GGML_TASK_TYPE_COMPUTE;
-                   ggml_compute_forward(&params, node);
+                   ggml_compute_forward(&params, node, state);
 
                    if (GGML_OP_HAS_FINALIZE[node->op]) {
                        params.type = GGML_TASK_TYPE_FINALIZE;
-                       ggml_compute_forward(&params, node);
+                       ggml_compute_forward(&params, node, state);
                    }
 
                    ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -19708,7 +19709,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
        if (state->ith < n_tasks) {
            if (GGML_OP_HAS_INIT[node->op]) {
-               ggml_compute_forward(&params, node);
+               ggml_compute_forward(&params, node, state);
            }
        }
 
@@ -19729,7 +19730,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
        if (state->ith < n_tasks) {
            params.type = GGML_TASK_TYPE_COMPUTE;
-           ggml_compute_forward(&params, node);
+           ggml_compute_forward(&params, node, state);
        }
 
        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {