ggml : fix handling of ops with n_threads > n_tasks > 1

This commit is contained in:
Georgi Gerganov 2023-06-26 20:50:50 +03:00
parent 4a555b4539
commit 81a40e9d61
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

12
ggml.c
View file

@@ -16774,11 +16774,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
node_n = atomic_load(&state->shared->node_n);
} while (node_n == last);
}
// check if we should stop
if (node_n >= cgraph->n_nodes) break;
/* COMPUTE */
struct ggml_tensor * node = cgraph->nodes[node_n];
struct ggml_compute_params params = {
/*.type =*/ GGML_TASK_COMPUTE,
/*.ith =*/ state->ith,
@@ -16789,8 +16791,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
if (state->ith < node->n_tasks) {
ggml_compute_forward(&params, node);
} else {
break;
}
}
@@ -16952,7 +16952,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} break;
case GGML_OP_SCALE:
{
-node->n_tasks = n_threads;
+node->n_tasks = 1;
} break;
case GGML_OP_SET:
case GGML_OP_CONT:
@@ -17165,9 +17165,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.shared = &state_shared,
};
-int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
GGML_ASSERT(rc == 0);
UNUSED(rc);
}
}
workers[0].ith = 0;
@@ -17185,9 +17184,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
// join thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) {
-int rc = ggml_thread_join(workers[j].thrd, NULL);
+const int rc = ggml_thread_join(workers[j].thrd, NULL);
GGML_ASSERT(rc == 0);
UNUSED(rc);
}
}