threadpool: move all pause/resume logic into ggml

Max Krasnyansky 2024-08-27 13:19:45 -07:00
parent 5d4c0a1327
commit e3c2202049
6 changed files with 19 additions and 88 deletions


@@ -1537,7 +1537,7 @@ int main(int argc, char ** argv) {
             exit(1);
         }

-        llama_attach_threadpool(ctx, threadpool);
+        llama_attach_threadpool(ctx, threadpool, NULL);

         // warmup run
         if (t.n_prompt > 0) {


@@ -240,11 +240,6 @@ int main(int argc, char ** argv) {
             exit(1);
         }

-        llama_attach_batch_threadpool(ctx, threadpool_batch);
-        if (ctx_guidance) {
-            llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
-        }
-
         // Start the non-batch threadpool in the paused state
         tpp.paused = true;
     }
@@ -255,9 +250,9 @@ int main(int argc, char ** argv) {
         exit(1);
     }

-    llama_attach_threadpool(ctx, threadpool);
+    llama_attach_threadpool(ctx, threadpool, threadpool_batch);
     if (ctx_guidance) {
-        llama_attach_threadpool(ctx_guidance, threadpool);
+        llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
    }

     const int n_ctx_train = llama_n_ctx_train(model);


@@ -910,6 +910,11 @@ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_compute_threadpool_t threadpool) {
     GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));

     struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;

+    if (ctx->threadpool && ctx->threadpool != threadpool) {
+        // already had a different threadpool, pause/suspend it before switching
+        ggml_pause_threadpool(ctx->threadpool);
+    }
+
     ctx->threadpool = threadpool;
 }
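
With the pause folded into ggml_backend_cpu_set_threadpool(), a caller that juggles more than one pool no longer has to pause the outgoing pool before switching. A minimal caller-side sketch, assuming the function is declared in ggml-backend.h alongside the rest of the CPU backend API (the helper name is hypothetical):

#include "ggml-backend.h"

// Hypothetical helper: point the CPU backend at a different threadpool.
// After this commit the backend pauses the previously attached pool itself
// when it sees a different pool being attached, so no explicit
// ggml_pause_threadpool() call is needed here.
static void switch_cpu_threadpool(ggml_backend_t backend_cpu,
                                  ggml_compute_threadpool_t next_pool) {
    ggml_backend_cpu_set_threadpool(backend_cpu, next_pool);
}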


@@ -19198,9 +19198,6 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
                 state->pending = false;

                 ggml_graph_compute_thread(state);
-                if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
-                    break;
-                }
             }
         }


@ -431,16 +431,9 @@ extern "C" {
// Optional: an auto threadpool gets created in ggml if not passed explicitly // Optional: an auto threadpool gets created in ggml if not passed explicitly
LLAMA_API void llama_attach_threadpool( LLAMA_API void llama_attach_threadpool(
struct llama_context * ctx, struct llama_context * ctx,
ggml_compute_threadpool_t threadpool); ggml_compute_threadpool_t threadpool,
LLAMA_API void llama_attach_batch_threadpool( ggml_compute_threadpool_t threadpool_batch);
struct llama_context * ctx,
ggml_compute_threadpool_t threadpool);
LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
LLAMA_API void llama_detach_batch_threadpool(struct llama_context * ctx);
LLAMA_API void llama_detach_threadpools(struct llama_context * ctx);
// Pauses all attached threadpools
LLAMA_API void llama_pause_threadpools(struct llama_context * ctx);
// Call once at the end of the program - currently only used for MPI // Call once at the end of the program - currently only used for MPI
LLAMA_API void llama_backend_free(void); LLAMA_API void llama_backend_free(void);
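
The two attach entry points collapse into a single call, and one detach now clears both slots. A hedged usage sketch built only from the calls visible in this diff (context and threadpool creation are elided because those APIs are not part of this hunk):

#include "llama.h"

// assumes ctx, threadpool and threadpool_batch were created elsewhere;
// threadpool_batch may be NULL, as in the llama-bench hunk above
static void attach_pools(struct llama_context * ctx,
                         ggml_compute_threadpool_t threadpool,
                         ggml_compute_threadpool_t threadpool_batch) {
    // one call attaches both pools; a NULL batch pool makes the context
    // fall back to the main pool (see the llama.cpp hunk below)
    llama_attach_threadpool(ctx, threadpool, threadpool_batch);
}

static void detach_pools(struct llama_context * ctx) {
    // a single detach now clears both the generation and the batch pool
    llama_detach_threadpool(ctx);
}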


@@ -15523,39 +15523,6 @@ static void llama_graph_compute(
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }

-// Optionally swaps the batch and single-tok threadpools.
-// Returns the number of threads, and if a valid threadpool exists, returns it too.
-static std::pair<int32_t, ggml_compute_threadpool_t> llama_swap_threadpools(
-        llama_context & lctx,
-        int32_t n_tokens) {
-    const auto & cparams = lctx.cparams;
-
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-
-    ggml_compute_threadpool_t threadpool = nullptr; // nullptr -> disposable threadpool
-
-    // A batch threadpool without a non-batch threadpool isn't supported.
-    GGML_ASSERT(!lctx.threadpool_batch || lctx.threadpool);
-
-    if (lctx.threadpool_batch && lctx.threadpool) {
-        // Switch between the 2 threadpools as needed
-        if (n_tokens > 1) {
-            ggml_pause_threadpool(lctx.threadpool);
-            threadpool = lctx.threadpool_batch;
-            n_threads  = cparams.n_threads_batch;
-        } else {
-            ggml_pause_threadpool(lctx.threadpool_batch);
-            threadpool = lctx.threadpool;
-            n_threads  = cparams.n_threads;
-        }
-    } else if (lctx.threadpool) {
-        threadpool = lctx.threadpool;
-        n_threads  = cparams.n_threads;
-    }
-    return std::make_pair(n_threads, threadpool);
-}
-
 // decode a batch of tokens by evaluating the transformer
 //
 //   - lctx:      llama context
@@ -15662,11 +15629,8 @@ static int llama_decode_internal(
             lctx.n_outputs = n_outputs_new;
         }

-        std::pair<int32_t, ggml_compute_threadpool_t> threads =
-            llama_swap_threadpools(lctx, n_tokens);
-
-        int n_threads = threads.first;
-        ggml_compute_threadpool_t threadpool = threads.second;
+        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

         GGML_ASSERT(n_threads > 0);
@@ -15906,11 +15870,9 @@ static int llama_encode_internal(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;

-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-        llama_swap_threadpools(lctx, n_tokens);
-
-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

     GGML_ASSERT(n_threads > 0);

     ggml_backend_sched_reset(lctx.sched);
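
With llama_swap_threadpools() removed, decode and encode select the pool with the same one-line ternary and leave pausing the idle pool to ggml. For illustration only, the duplicated selection could be factored into a helper like the hypothetical one below, built from the fields visible in this diff (it assumes the llama.cpp-internal llama_context and llama_cparams types):

#include <utility>

// Hypothetical helper, not part of the commit: pick the thread count and
// threadpool for a call, mirroring the two new lines in
// llama_decode_internal() and llama_encode_internal().
static std::pair<int32_t, ggml_compute_threadpool_t> llama_select_threadpool(
        const llama_context & lctx, int32_t n_tokens) {
    const auto & cparams = lctx.cparams;
    int32_t n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
    ggml_compute_threadpool_t threadpool =
        n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
    return std::make_pair(n_threads, threadpool);
}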
@@ -17500,36 +17462,15 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
 void llama_attach_threadpool(
              struct llama_context * ctx,
-        ggml_compute_threadpool_t threadpool) {
-    ctx->threadpool = threadpool;
-}
-
-void llama_attach_batch_threadpool(
-             struct llama_context * ctx,
-        ggml_compute_threadpool_t threadpool_batch) {
-    ctx->threadpool_batch = threadpool_batch;
+        ggml_compute_threadpool_t threadpool,
+        ggml_compute_threadpool_t threadpool_batch) {
+    ctx->threadpool       = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }

 void llama_detach_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_batch_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_threadpools(struct llama_context * ctx) {
-    llama_detach_threadpool(ctx);
-    llama_detach_batch_threadpool(ctx);
-}
-
-void llama_pause_threadpools(struct llama_context * ctx) {
-    if (ctx->threadpool) {
-        ggml_pause_threadpool(ctx->threadpool);
-    }
-    if (ctx->threadpool_batch) {
-        ggml_pause_threadpool(ctx->threadpool_batch);
-    }
-}
+    ctx->threadpool       = nullptr;
+    ctx->threadpool_batch = nullptr;
 }

 void llama_backend_free(void) {