ggml : fix handling of ops with n_threads > n_tasks > 1
This commit is contained in:
parent
4a555b4539
commit
81a40e9d61
1 changed file with 7 additions and 9 deletions
14
ggml.c
14
ggml.c
|
@ -16774,11 +16774,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||||
node_n = atomic_load(&state->shared->node_n);
|
node_n = atomic_load(&state->shared->node_n);
|
||||||
} while (node_n == last);
|
} while (node_n == last);
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if we should stop
|
// check if we should stop
|
||||||
if (node_n >= cgraph->n_nodes) break;
|
if (node_n >= cgraph->n_nodes) break;
|
||||||
|
|
||||||
/* COMPUTE */
|
/* COMPUTE */
|
||||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||||
|
|
||||||
struct ggml_compute_params params = {
|
struct ggml_compute_params params = {
|
||||||
/*.type =*/ GGML_TASK_COMPUTE,
|
/*.type =*/ GGML_TASK_COMPUTE,
|
||||||
/*.ith =*/ state->ith,
|
/*.ith =*/ state->ith,
|
||||||
|
@ -16787,10 +16789,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||||
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
|
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
|
||||||
};
|
};
|
||||||
|
|
||||||
if(state->ith < node->n_tasks) {
|
if (state->ith < node->n_tasks) {
|
||||||
ggml_compute_forward(&params, node);
|
ggml_compute_forward(&params, node);
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16952,7 +16952,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
{
|
{
|
||||||
node->n_tasks = n_threads;
|
node->n_tasks = 1;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SET:
|
case GGML_OP_SET:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
|
@ -17165,9 +17165,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
.shared = &state_shared,
|
.shared = &state_shared,
|
||||||
};
|
};
|
||||||
|
|
||||||
int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
|
const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
|
||||||
GGML_ASSERT(rc == 0);
|
GGML_ASSERT(rc == 0);
|
||||||
UNUSED(rc);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
workers[0].ith = 0;
|
workers[0].ith = 0;
|
||||||
|
@ -17185,9 +17184,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
// join thread pool
|
// join thread pool
|
||||||
if (n_threads > 1) {
|
if (n_threads > 1) {
|
||||||
for (int j = 1; j < n_threads; j++) {
|
for (int j = 1; j < n_threads; j++) {
|
||||||
int rc = ggml_thread_join(workers[j].thrd, NULL);
|
const int rc = ggml_thread_join(workers[j].thrd, NULL);
|
||||||
GGML_ASSERT(rc == 0);
|
GGML_ASSERT(rc == 0);
|
||||||
UNUSED(rc);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue