ggml : fix handling of ops with n_threads > n_tasks > 1

This commit is contained in:
Georgi Gerganov 2023-06-26 20:50:50 +03:00
parent 4a555b4539
commit 81a40e9d61
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

12
ggml.c
View file

@@ -16774,11 +16774,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
node_n = atomic_load(&state->shared->node_n); node_n = atomic_load(&state->shared->node_n);
} while (node_n == last); } while (node_n == last);
} }
// check if we should stop // check if we should stop
if (node_n >= cgraph->n_nodes) break; if (node_n >= cgraph->n_nodes) break;
/* COMPUTE */ /* COMPUTE */
struct ggml_tensor * node = cgraph->nodes[node_n]; struct ggml_tensor * node = cgraph->nodes[node_n];
struct ggml_compute_params params = { struct ggml_compute_params params = {
/*.type =*/ GGML_TASK_COMPUTE, /*.type =*/ GGML_TASK_COMPUTE,
/*.ith =*/ state->ith, /*.ith =*/ state->ith,
@@ -16789,8 +16791,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
if (state->ith < node->n_tasks) { if (state->ith < node->n_tasks) {
ggml_compute_forward(&params, node); ggml_compute_forward(&params, node);
- } else {
- break;
} }
} }
@@ -16952,7 +16952,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break; } break;
case GGML_OP_SCALE: case GGML_OP_SCALE:
{ {
- node->n_tasks = n_threads;
+ node->n_tasks = 1;
} break; } break;
case GGML_OP_SET: case GGML_OP_SET:
case GGML_OP_CONT: case GGML_OP_CONT:
@@ -17165,9 +17165,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.shared = &state_shared, .shared = &state_shared,
}; };
- int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+ const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
GGML_ASSERT(rc == 0); GGML_ASSERT(rc == 0);
- UNUSED(rc);
} }
} }
workers[0].ith = 0; workers[0].ith = 0;
@@ -17185,9 +17184,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
// join thread pool // join thread pool
if (n_threads > 1) { if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) { for (int j = 1; j < n_threads; j++) {
- int rc = ggml_thread_join(workers[j].thrd, NULL);
+ const int rc = ggml_thread_join(workers[j].thrd, NULL);
GGML_ASSERT(rc == 0); GGML_ASSERT(rc == 0);
- UNUSED(rc);
} }
} }