threadpool: move all pause/resume logic into ggml
parent 5d4c0a1327
commit e3c2202049
6 changed files with 19 additions and 88 deletions
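Notes: the caller now creates its threadpool(s) up front and hands both to llama in a single call; pausing and resuming the inactive pool is handled inside ggml. A minimal caller-side sketch (struct ggml_threadpool_params and ggml_create_threadpool are assumed from elsewhere in this PR, not shown in this diff):

    struct ggml_threadpool_params tpp;   // assumed params struct from this PR
    // ... fill tpp from the CPU params ...
    tpp.paused = true;                   // start paused; per this commit's intent, ggml resumes it when compute runs

    ggml_compute_threadpool_t threadpool = ggml_create_threadpool(&tpp);  // assumed constructor

    // One attach call now covers both pools; pass NULL to reuse `threadpool` for batches.
    llama_attach_threadpool(ctx, threadpool, NULL);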
examples/llama-bench/llama-bench.cpp
@@ -1537,7 +1537,7 @@ int main(int argc, char ** argv) {
             exit(1);
         }
 
-        llama_attach_threadpool(ctx, threadpool);
+        llama_attach_threadpool(ctx, threadpool, NULL);
 
         // warmup run
         if (t.n_prompt > 0) {
examples/main/main.cpp
@@ -240,11 +240,6 @@ int main(int argc, char ** argv) {
             exit(1);
         }
 
-        llama_attach_batch_threadpool(ctx, threadpool_batch);
-        if (ctx_guidance) {
-            llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
-        }
-
         // Start the non-batch threadpool in the paused state
         tpp.paused = true;
     }
@@ -255,9 +250,9 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
-    llama_attach_threadpool(ctx, threadpool);
+    llama_attach_threadpool(ctx, threadpool, threadpool_batch);
     if (ctx_guidance) {
-        llama_attach_threadpool(ctx_guidance, threadpool);
+        llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
     }
 
     const int n_ctx_train = llama_n_ctx_train(model);
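The two main.cpp hunks above replace the per-pool attach calls with the combined form. Before and after, side by side:

    // before: batch pool attached separately, in addition to the regular attach
    llama_attach_batch_threadpool(ctx, threadpool_batch);
    llama_attach_threadpool(ctx, threadpool);

    // after: one call per context; ggml picks the right pool and pauses the other
    llama_attach_threadpool(ctx, threadpool, threadpool_batch);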
ggml/src/ggml-backend.c
@@ -910,6 +910,11 @@ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_compute_threadpool_t threadpool) {
     GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
 
     struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+
+    if (ctx->threadpool && ctx->threadpool != threadpool) {
+        // already had a different threadpool, pause/suspend it before switching
+        ggml_pause_threadpool(ctx->threadpool);
+    }
     ctx->threadpool = threadpool;
 }
 
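The ggml_backend_cpu_set_threadpool change means switching a CPU backend to a different threadpool implicitly suspends the one it had, so the old pool's workers don't keep spinning. Illustrative call sequence (pool_a/pool_b are placeholder names):

    ggml_backend_cpu_set_threadpool(backend_cpu, pool_a); // pool_a attached
    ggml_backend_cpu_set_threadpool(backend_cpu, pool_b); // pool_a paused first, then pool_b attached
    ggml_backend_cpu_set_threadpool(backend_cpu, pool_b); // same pool: no pause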
ggml/src/ggml.c
@@ -19198,9 +19198,6 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
             state->pending = false;
 
             ggml_graph_compute_thread(state);
-            if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
-                break;
-            }
         }
     }
 
include/llama.h
@@ -431,16 +431,9 @@ extern "C" {
     // Optional: an auto threadpool gets created in ggml if not passed explicitly
     LLAMA_API void llama_attach_threadpool(
             struct llama_context * ctx,
-            ggml_compute_threadpool_t threadpool);
-    LLAMA_API void llama_attach_batch_threadpool(
-            struct llama_context * ctx,
-            ggml_compute_threadpool_t threadpool);
-    LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
-    LLAMA_API void llama_detach_batch_threadpool(struct llama_context * ctx);
+            ggml_compute_threadpool_t threadpool,
+            ggml_compute_threadpool_t threadpool_batch);
     LLAMA_API void llama_detach_threadpools(struct llama_context * ctx);
 
-    // Pauses all attached threadpools
-    LLAMA_API void llama_pause_threadpools(struct llama_context * ctx);
-
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);
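Migration for API users, in short (the commented-out calls are the ones this commit removes):

    // removed: llama_attach_batch_threadpool(), llama_detach_threadpool(),
    //          llama_detach_batch_threadpool(), llama_pause_threadpools()
    llama_attach_threadpool(ctx, threadpool, threadpool_batch); // threadpool_batch may be NULL
    llama_detach_threadpools(ctx);                              // clears both pools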
src/llama.cpp
@@ -15523,39 +15523,6 @@ static void llama_graph_compute(
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }
 
-// Optionally swaps the batch and single-tok threadpools.
-// Returns the number of threads, and if a valid threadpool exists, returns it too.
-static std::pair<int32_t, ggml_compute_threadpool_t> llama_swap_threadpools(
-        llama_context & lctx,
-        int32_t n_tokens) {
-
-    const auto & cparams = lctx.cparams;
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-
-    ggml_compute_threadpool_t threadpool = nullptr; // nullptr -> disposable threadpool
-
-    // A batch threadpool without a non-batch threadpool isn't supported.
-    GGML_ASSERT(!lctx.threadpool_batch || lctx.threadpool);
-
-    if (lctx.threadpool_batch && lctx.threadpool) {
-        // Switch between the 2 threadpools as needed
-        if (n_tokens > 1) {
-            ggml_pause_threadpool(lctx.threadpool);
-            threadpool = lctx.threadpool_batch;
-            n_threads  = cparams.n_threads_batch;
-        } else {
-            ggml_pause_threadpool(lctx.threadpool_batch);
-            threadpool = lctx.threadpool;
-            n_threads  = cparams.n_threads;
-        }
-    } else if (lctx.threadpool) {
-        threadpool = lctx.threadpool;
-        n_threads  = cparams.n_threads;
-    }
-    return std::make_pair(n_threads, threadpool);
-}
-
 
 // decode a batch of tokens by evaluating the transformer
 //
 //   - lctx:      llama context
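With llama_swap_threadpools() gone, decode/encode just select a pool; pausing the inactive pool now happens inside ggml when the scheduler attaches the selected one (see the ggml-backend hunk above). The selection that replaces the helper, as it appears in the next two hunks:

    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;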
@@ -15662,11 +15629,8 @@ static int llama_decode_internal(
         lctx.n_outputs = n_outputs_new;
     }
 
-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-        llama_swap_threadpools(lctx, n_tokens);
-
-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
     GGML_ASSERT(n_threads > 0);
 
@@ -15906,11 +15870,9 @@ static int llama_encode_internal(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;
 
-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-        llama_swap_threadpools(lctx, n_tokens);
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
 
-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
     GGML_ASSERT(n_threads > 0);
 
     ggml_backend_sched_reset(lctx.sched);
@@ -17500,36 +17462,15 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
 
 void llama_attach_threadpool(
         struct llama_context * ctx,
-        ggml_compute_threadpool_t threadpool) {
-    ctx->threadpool = threadpool;
-}
-
-void llama_attach_batch_threadpool(
-        struct llama_context * ctx,
+        ggml_compute_threadpool_t threadpool,
         ggml_compute_threadpool_t threadpool_batch) {
-    ctx->threadpool_batch = threadpool_batch;
+    ctx->threadpool = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }
 
-void llama_detach_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_batch_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
 void llama_detach_threadpools(struct llama_context * ctx) {
-    llama_detach_threadpool(ctx);
-    llama_detach_batch_threadpool(ctx);
-}
-
-void llama_pause_threadpools(struct llama_context * ctx) {
-    if (ctx->threadpool) {
-        ggml_pause_threadpool(ctx->threadpool);
-    }
-    if (ctx->threadpool_batch) {
-        ggml_pause_threadpool(ctx->threadpool_batch);
-    }
+    ctx->threadpool = nullptr;
+    ctx->threadpool_batch = nullptr;
 }
 
 void llama_backend_free(void) {
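Per the header comment ("an auto threadpool gets created in ggml if not passed explicitly"), detaching remains safe mid-session; a sketch:

    llama_detach_threadpools(ctx);  // both pool pointers become nullptr
    llama_decode(ctx, batch);       // ggml falls back to a disposable/auto threadpool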