check abort_callback on main thread only
This commit is contained in:
parent
d27f26ea0c
commit
486d06106c
1 changed files with 92 additions and 138 deletions
230
ggml.c
230
ggml.c
|
@ -1744,13 +1744,14 @@ struct ggml_compute_state_shared {
|
||||||
void * abort_callback_data;
|
void * abort_callback_data;
|
||||||
|
|
||||||
atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads
|
atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads
|
||||||
|
|
||||||
|
enum ggml_status ec;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_compute_state {
|
struct ggml_compute_state {
|
||||||
ggml_thread_t thrd;
|
ggml_thread_t thrd;
|
||||||
int ith;
|
int ith;
|
||||||
struct ggml_compute_state_shared * shared;
|
struct ggml_compute_state_shared * shared;
|
||||||
enum ggml_status ec;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_compute_params {
|
struct ggml_compute_params {
|
||||||
|
@ -3001,7 +3002,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED(numa_flag);
|
UNUSED(numa_flag);
|
||||||
// TODO
|
// TODO
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -15980,7 +15981,7 @@ static void ggml_compute_forward_unary(
|
||||||
static void ggml_compute_forward_get_rel_pos_f16(
|
static void ggml_compute_forward_get_rel_pos_f16(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
GGML_UNUSED(params);
|
UNUSED(params);
|
||||||
|
|
||||||
const struct ggml_tensor * src0 = dst->src[0];
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
@ -18317,8 +18318,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
case GGML_UNARY_OP_ELU:
|
case GGML_UNARY_OP_ELU:
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
{
|
{
|
||||||
n_tasks = 1;
|
n_tasks = 1;
|
||||||
} break;
|
} break;
|
||||||
|
@ -18341,24 +18342,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
case GGML_OP_RMS_NORM_BACK:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_OUT_PROD:
|
case GGML_OP_OUT_PROD:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
// FIXME: the cost of launching additional threads decreases performance with GPU offloading
|
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
|
||||||
|
// decreases performance with GPU offloading
|
||||||
//n_tasks = n_threads;
|
//n_tasks = n_threads;
|
||||||
n_tasks = 1;
|
n_tasks = 1;
|
||||||
} break;
|
} break;
|
||||||
|
@ -18390,14 +18383,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
{
|
{
|
||||||
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
|
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
{
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
|
@ -18408,33 +18395,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
n_tasks = 1;
|
n_tasks = 1;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_ARANGE:
|
case GGML_OP_ARANGE:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_FLASH_ATTN_BACK:
|
case GGML_OP_FLASH_ATTN_BACK:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_SSM_CONV:
|
case GGML_OP_SSM_CONV:
|
||||||
case GGML_OP_SSM_SCAN:
|
case GGML_OP_SSM_SCAN:
|
||||||
{
|
{
|
||||||
|
@ -18482,9 +18448,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
{
|
|
||||||
n_tasks = n_threads;
|
|
||||||
} break;
|
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
|
@ -18514,37 +18477,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
return n_tasks;
|
return n_tasks;
|
||||||
}
|
}
|
||||||
|
|
||||||
static thread_ret_t ggml_graph_compute_thread(void * data) {
|
|
||||||
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
|
|
||||||
|
|
||||||
const struct ggml_cgraph * cgraph = state->shared->cgraph;
|
|
||||||
const struct ggml_cplan * cplan = state->shared->cplan;
|
|
||||||
|
|
||||||
set_numa_thread_affinity(state->ith);
|
|
||||||
|
|
||||||
struct ggml_compute_params params = {
|
|
||||||
/*.ith =*/ state->ith,
|
|
||||||
/*.nth =*/ state->shared->n_threads,
|
|
||||||
/*.wsize =*/ cplan->work_size,
|
|
||||||
/*.wdata =*/ cplan->work_data,
|
|
||||||
/*.shared=*/ state->shared,
|
|
||||||
};
|
|
||||||
|
|
||||||
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
|
|
||||||
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
|
|
||||||
state->ec = GGML_STATUS_ABORTED;
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
|
||||||
|
|
||||||
ggml_compute_forward(¶ms, node);
|
|
||||||
|
|
||||||
ggml_barrier(state->shared);
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
|
struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
|
||||||
if (n_threads <= 0) {
|
if (n_threads <= 0) {
|
||||||
n_threads = GGML_DEFAULT_N_THREADS;
|
n_threads = GGML_DEFAULT_N_THREADS;
|
||||||
|
@ -18713,8 +18645,59 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||||
return cplan;
|
return cplan;
|
||||||
}
|
}
|
||||||
|
|
||||||
static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) {
|
static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||||
enum ggml_status compute_status = GGML_STATUS_SUCCESS;
|
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
|
||||||
|
|
||||||
|
const struct ggml_cgraph * cgraph = state->shared->cgraph;
|
||||||
|
const struct ggml_cplan * cplan = state->shared->cplan;
|
||||||
|
|
||||||
|
set_numa_thread_affinity(state->ith);
|
||||||
|
|
||||||
|
struct ggml_compute_params params = {
|
||||||
|
/*.ith =*/ state->ith,
|
||||||
|
/*.nth =*/ state->shared->n_threads,
|
||||||
|
/*.wsize =*/ cplan->work_size,
|
||||||
|
/*.wdata =*/ cplan->work_data,
|
||||||
|
/*.shared=*/ state->shared,
|
||||||
|
};
|
||||||
|
|
||||||
|
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
|
||||||
|
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||||
|
|
||||||
|
ggml_compute_forward(¶ms, node);
|
||||||
|
|
||||||
|
if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
|
||||||
|
state->shared->ec = GGML_STATUS_ABORTED;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_barrier(state->shared);
|
||||||
|
|
||||||
|
if (state->shared->ec != GGML_STATUS_SUCCESS) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
|
||||||
|
GGML_ASSERT(cplan);
|
||||||
|
GGML_ASSERT(cplan->n_threads > 0);
|
||||||
|
GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
|
||||||
|
|
||||||
|
int n_threads = cplan->n_threads;
|
||||||
|
|
||||||
|
struct ggml_compute_state_shared state_shared = {
|
||||||
|
/*.cgraph =*/ cgraph,
|
||||||
|
/*.cgraph_plan =*/ cplan,
|
||||||
|
/*.n_threads =*/ n_threads,
|
||||||
|
/*.n_barrier =*/ 0,
|
||||||
|
/*.n_barrier_passed =*/ 0,
|
||||||
|
/*.abort_callback =*/ NULL,
|
||||||
|
/*.abort_callback_data =*/ NULL,
|
||||||
|
/*.current_chunk =*/ 0,
|
||||||
|
/*.ec =*/ GGML_STATUS_SUCCESS,
|
||||||
|
};
|
||||||
|
|
||||||
#ifdef GGML_USE_OPENMP
|
#ifdef GGML_USE_OPENMP
|
||||||
if (n_threads > 1) {
|
if (n_threads > 1) {
|
||||||
|
@ -18724,22 +18707,40 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
|
||||||
{
|
{
|
||||||
// update the number of threads from the actual number of threads that we got from OpenMP
|
// update the number of threads from the actual number of threads that we got from OpenMP
|
||||||
n_threads = omp_get_num_threads();
|
n_threads = omp_get_num_threads();
|
||||||
workers[0].shared->n_threads = n_threads;
|
state_shared.n_threads = n_threads;
|
||||||
workers[0].shared->current_chunk = n_threads;
|
|
||||||
}
|
}
|
||||||
ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
|
|
||||||
|
struct ggml_compute_state worker = {
|
||||||
|
.thrd = 0,
|
||||||
|
.ith = omp_get_thread_num(),
|
||||||
|
.shared = &state_shared,
|
||||||
|
};
|
||||||
|
ggml_graph_compute_thread(&worker);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ggml_graph_compute_thread(&workers[0]);
|
struct ggml_compute_state worker = {
|
||||||
|
.thrd = 0,
|
||||||
|
.ith = 0,
|
||||||
|
.shared = &state_shared,
|
||||||
|
};
|
||||||
|
ggml_graph_compute_thread(&worker);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
|
||||||
|
|
||||||
|
for (int j = 0; j < n_threads; ++j) {
|
||||||
|
workers[j] = (struct ggml_compute_state) {
|
||||||
|
.thrd = 0,
|
||||||
|
.ith = j,
|
||||||
|
.shared = &state_shared,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// create thread pool
|
// create thread pool
|
||||||
if (n_threads > 1) {
|
for (int j = 1; j < n_threads; ++j) {
|
||||||
for (int j = 1; j < n_threads; ++j) {
|
const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
|
||||||
const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
|
GGML_ASSERT(rc == 0);
|
||||||
GGML_ASSERT(rc == 0);
|
UNUSED(rc);
|
||||||
UNUSED(rc);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is a work thread too
|
// this is a work thread too
|
||||||
|
@ -18754,58 +18755,11 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// don't leave affinity set on the main thread
|
// don't leave affinity set on the main thread
|
||||||
clear_numa_thread_affinity();
|
clear_numa_thread_affinity();
|
||||||
|
|
||||||
for (int j = 0; j < n_threads; j++) {
|
return state_shared.ec;
|
||||||
if (workers[j].ec != GGML_STATUS_SUCCESS) {
|
|
||||||
compute_status = workers[j].ec;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return compute_status;
|
|
||||||
}
|
|
||||||
|
|
||||||
enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
|
|
||||||
{
|
|
||||||
GGML_ASSERT(cplan);
|
|
||||||
GGML_ASSERT(cplan->n_threads > 0);
|
|
||||||
|
|
||||||
if (cplan->work_size > 0) {
|
|
||||||
GGML_ASSERT(cplan->work_data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int n_threads = cplan->n_threads;
|
|
||||||
|
|
||||||
#if defined(GGML_USE_OPENMP)
|
|
||||||
n_threads = MIN(n_threads, omp_get_max_threads());
|
|
||||||
#endif
|
|
||||||
|
|
||||||
struct ggml_compute_state_shared state_shared = {
|
|
||||||
/*.cgraph =*/ cgraph,
|
|
||||||
/*.cgraph_plan =*/ cplan,
|
|
||||||
/*.n_threads =*/ n_threads,
|
|
||||||
/*.n_barrier =*/ 0,
|
|
||||||
/*.n_barrier_passed =*/ 0,
|
|
||||||
/*.abort_callback =*/ NULL,
|
|
||||||
/*.abort_callback_data =*/ NULL,
|
|
||||||
/*.current_chunk =*/ 0,
|
|
||||||
};
|
|
||||||
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
|
|
||||||
|
|
||||||
for (int j = 0; j < n_threads; ++j) {
|
|
||||||
workers[j] = (struct ggml_compute_state) {
|
|
||||||
.thrd = 0,
|
|
||||||
.ith = j,
|
|
||||||
.shared = &state_shared,
|
|
||||||
.ec = GGML_STATUS_SUCCESS,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads);
|
|
||||||
|
|
||||||
return compute_status;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
|
enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue