check abort_callback on main thread only
This commit is contained in:
parent
d27f26ea0c
commit
486d06106c
1 changed files with 92 additions and 138 deletions
222
ggml.c
222
ggml.c
|
@ -1744,13 +1744,14 @@ struct ggml_compute_state_shared {
|
|||
void * abort_callback_data;
|
||||
|
||||
atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads
|
||||
|
||||
enum ggml_status ec;
|
||||
};
|
||||
|
||||
struct ggml_compute_state {
|
||||
ggml_thread_t thrd;
|
||||
int ith;
|
||||
struct ggml_compute_state_shared * shared;
|
||||
enum ggml_status ec;
|
||||
};
|
||||
|
||||
struct ggml_compute_params {
|
||||
|
@ -3001,7 +3002,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
|
|||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(numa_flag);
|
||||
UNUSED(numa_flag);
|
||||
// TODO
|
||||
#endif
|
||||
}
|
||||
|
@ -15980,7 +15981,7 @@ static void ggml_compute_forward_unary(
|
|||
static void ggml_compute_forward_get_rel_pos_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
GGML_UNUSED(params);
|
||||
UNUSED(params);
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
|
@ -18317,8 +18318,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
|
||||
case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
{
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
|
@ -18341,24 +18342,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_OUT_PROD:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
// FIXME: the cost of launching additional threads decreases performance with GPU offloading
|
||||
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
|
||||
// decreases performance with GPU offloading
|
||||
//n_tasks = n_threads;
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
|
@ -18390,14 +18383,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
{
|
||||
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
|
||||
} break;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
@ -18408,33 +18395,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_UPSCALE:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_PAD:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_ARANGE:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_ARGSORT:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
{
|
||||
|
@ -18482,9 +18448,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
}
|
||||
} break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
@ -18514,37 +18477,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
return n_tasks;
|
||||
}
|
||||
|
||||
static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
|
||||
|
||||
const struct ggml_cgraph * cgraph = state->shared->cgraph;
|
||||
const struct ggml_cplan * cplan = state->shared->cplan;
|
||||
|
||||
set_numa_thread_affinity(state->ith);
|
||||
|
||||
struct ggml_compute_params params = {
|
||||
/*.ith =*/ state->ith,
|
||||
/*.nth =*/ state->shared->n_threads,
|
||||
/*.wsize =*/ cplan->work_size,
|
||||
/*.wdata =*/ cplan->work_data,
|
||||
/*.shared=*/ state->shared,
|
||||
};
|
||||
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
|
||||
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
|
||||
state->ec = GGML_STATUS_ABORTED;
|
||||
return 0;
|
||||
}
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
|
||||
ggml_barrier(state->shared);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
|
||||
if (n_threads <= 0) {
|
||||
n_threads = GGML_DEFAULT_N_THREADS;
|
||||
|
@ -18713,8 +18645,59 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
|||
return cplan;
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) {
|
||||
enum ggml_status compute_status = GGML_STATUS_SUCCESS;
|
||||
static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
|
||||
|
||||
const struct ggml_cgraph * cgraph = state->shared->cgraph;
|
||||
const struct ggml_cplan * cplan = state->shared->cplan;
|
||||
|
||||
set_numa_thread_affinity(state->ith);
|
||||
|
||||
struct ggml_compute_params params = {
|
||||
/*.ith =*/ state->ith,
|
||||
/*.nth =*/ state->shared->n_threads,
|
||||
/*.wsize =*/ cplan->work_size,
|
||||
/*.wdata =*/ cplan->work_data,
|
||||
/*.shared=*/ state->shared,
|
||||
};
|
||||
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
|
||||
if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
|
||||
state->shared->ec = GGML_STATUS_ABORTED;
|
||||
}
|
||||
|
||||
ggml_barrier(state->shared);
|
||||
|
||||
if (state->shared->ec != GGML_STATUS_SUCCESS) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
|
||||
GGML_ASSERT(cplan);
|
||||
GGML_ASSERT(cplan->n_threads > 0);
|
||||
GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
|
||||
|
||||
int n_threads = cplan->n_threads;
|
||||
|
||||
struct ggml_compute_state_shared state_shared = {
|
||||
/*.cgraph =*/ cgraph,
|
||||
/*.cgraph_plan =*/ cplan,
|
||||
/*.n_threads =*/ n_threads,
|
||||
/*.n_barrier =*/ 0,
|
||||
/*.n_barrier_passed =*/ 0,
|
||||
/*.abort_callback =*/ NULL,
|
||||
/*.abort_callback_data =*/ NULL,
|
||||
/*.current_chunk =*/ 0,
|
||||
/*.ec =*/ GGML_STATUS_SUCCESS,
|
||||
};
|
||||
|
||||
#ifdef GGML_USE_OPENMP
|
||||
if (n_threads > 1) {
|
||||
|
@ -18724,23 +18707,41 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
|
|||
{
|
||||
// update the number of threads from the actual number of threads that we got from OpenMP
|
||||
n_threads = omp_get_num_threads();
|
||||
workers[0].shared->n_threads = n_threads;
|
||||
workers[0].shared->current_chunk = n_threads;
|
||||
state_shared.n_threads = n_threads;
|
||||
}
|
||||
ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
|
||||
|
||||
struct ggml_compute_state worker = {
|
||||
.thrd = 0,
|
||||
.ith = omp_get_thread_num(),
|
||||
.shared = &state_shared,
|
||||
};
|
||||
ggml_graph_compute_thread(&worker);
|
||||
}
|
||||
} else {
|
||||
ggml_graph_compute_thread(&workers[0]);
|
||||
struct ggml_compute_state worker = {
|
||||
.thrd = 0,
|
||||
.ith = 0,
|
||||
.shared = &state_shared,
|
||||
};
|
||||
ggml_graph_compute_thread(&worker);
|
||||
}
|
||||
#else
|
||||
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
|
||||
|
||||
for (int j = 0; j < n_threads; ++j) {
|
||||
workers[j] = (struct ggml_compute_state) {
|
||||
.thrd = 0,
|
||||
.ith = j,
|
||||
.shared = &state_shared,
|
||||
};
|
||||
}
|
||||
|
||||
// create thread pool
|
||||
if (n_threads > 1) {
|
||||
for (int j = 1; j < n_threads; ++j) {
|
||||
const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
|
||||
GGML_ASSERT(rc == 0);
|
||||
UNUSED(rc);
|
||||
}
|
||||
}
|
||||
|
||||
// this is a work thread too
|
||||
ggml_graph_compute_thread(&workers[0]);
|
||||
|
@ -18754,58 +18755,11 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
|
|||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// don't leave affinity set on the main thread
|
||||
clear_numa_thread_affinity();
|
||||
|
||||
for (int j = 0; j < n_threads; j++) {
|
||||
if (workers[j].ec != GGML_STATUS_SUCCESS) {
|
||||
compute_status = workers[j].ec;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return compute_status;
|
||||
}
|
||||
|
||||
enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
|
||||
{
|
||||
GGML_ASSERT(cplan);
|
||||
GGML_ASSERT(cplan->n_threads > 0);
|
||||
|
||||
if (cplan->work_size > 0) {
|
||||
GGML_ASSERT(cplan->work_data);
|
||||
}
|
||||
}
|
||||
|
||||
int n_threads = cplan->n_threads;
|
||||
|
||||
#if defined(GGML_USE_OPENMP)
|
||||
n_threads = MIN(n_threads, omp_get_max_threads());
|
||||
#endif
|
||||
|
||||
struct ggml_compute_state_shared state_shared = {
|
||||
/*.cgraph =*/ cgraph,
|
||||
/*.cgraph_plan =*/ cplan,
|
||||
/*.n_threads =*/ n_threads,
|
||||
/*.n_barrier =*/ 0,
|
||||
/*.n_barrier_passed =*/ 0,
|
||||
/*.abort_callback =*/ NULL,
|
||||
/*.abort_callback_data =*/ NULL,
|
||||
/*.current_chunk =*/ 0,
|
||||
};
|
||||
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
|
||||
|
||||
for (int j = 0; j < n_threads; ++j) {
|
||||
workers[j] = (struct ggml_compute_state) {
|
||||
.thrd = 0,
|
||||
.ith = j,
|
||||
.shared = &state_shared,
|
||||
.ec = GGML_STATUS_SUCCESS,
|
||||
};
|
||||
}
|
||||
|
||||
enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads);
|
||||
|
||||
return compute_status;
|
||||
return state_shared.ec;
|
||||
}
|
||||
|
||||
enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue