threadpool: move all pause/resume logic into ggml

parent 5d4c0a1327
commit e3c2202049

6 changed files with 19 additions and 88 deletions
@@ -1537,7 +1537,7 @@ int main(int argc, char ** argv) {
             exit(1);
         }

-        llama_attach_threadpool(ctx, threadpool);
+        llama_attach_threadpool(ctx, threadpool, NULL);

         // warmup run
         if (t.n_prompt > 0) {
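Note: llama-bench drives everything with a single pool, so it passes NULL for the new batch argument; per the llama.cpp hunk at the end of this commit, a NULL batch pool falls back to the main pool. A minimal single-pool call site would look like the sketch below (the creation helper and the tpp params struct are assumed from elsewhere on this branch, not shown in this diff):

    ggml_compute_threadpool_t threadpool = ggml_create_threadpool(&tpp); // assumed helper
    llama_attach_threadpool(ctx, threadpool, NULL); // NULL -> batch work reuses 'threadpool'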
@@ -240,11 +240,6 @@ int main(int argc, char ** argv) {
             exit(1);
         }

-        llama_attach_batch_threadpool(ctx, threadpool_batch);
-        if (ctx_guidance) {
-            llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
-        }
-
         // Start the non-batch threadpool in the paused state
         tpp.paused = true;
     }
@@ -255,9 +250,9 @@ int main(int argc, char ** argv) {
         exit(1);
     }

-    llama_attach_threadpool(ctx, threadpool);
+    llama_attach_threadpool(ctx, threadpool, threadpool_batch);
     if (ctx_guidance) {
-        llama_attach_threadpool(ctx_guidance, threadpool);
+        llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
     }

     const int n_ctx_train = llama_n_ctx_train(model);
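Note: with attach folded into a single call, the dual-pool setup reduces to: create the batch pool, mark the generation pool's params paused so it does not spin during prompt processing, create it, and hand both pools to each context. A sketch of the resulting flow (creation helper assumed, as above):

    ggml_compute_threadpool_t threadpool_batch = ggml_create_threadpool(&tpp_batch); // assumed helper

    tpp.paused = true; // start the non-batch pool paused; ggml now pauses/resumes pools as graphs run
    ggml_compute_threadpool_t threadpool = ggml_create_threadpool(&tpp); // assumed helper

    llama_attach_threadpool(ctx, threadpool, threadpool_batch);
    if (ctx_guidance) {
        llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
    }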
@@ -910,6 +910,11 @@ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_compute_th
     GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));

     struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+
+    if (ctx->threadpool && ctx->threadpool != threadpool) {
+        // already had a different threadpool, pause/suspend it before switching
+        ggml_pause_threadpool(ctx->threadpool);
+    }
     ctx->threadpool = threadpool;
 }
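Note: this function is now the single choke point for pause-on-switch: whenever a different pool is set on the CPU backend, the previously attached pool is suspended first. Assembled from the hunk above (parameter name completed from the body, since the header line is truncated by the diff display), the whole function reads:

    void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_compute_threadpool_t threadpool) {
        GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));

        struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;

        if (ctx->threadpool && ctx->threadpool != threadpool) {
            // already had a different threadpool, pause/suspend it before switching
            ggml_pause_threadpool(ctx->threadpool);
        }
        ctx->threadpool = threadpool;
    }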
@@ -19198,9 +19198,6 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
             state->pending = false;

             ggml_graph_compute_thread(state);
-            if (state->threadpool->ec != GGML_STATUS_SUCCESS) {
-                break;
-            }
         }
     }
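Note: secondary threads no longer bail out of their loop when a graph ends with an error; the surviving loop body is just the two lines below. Presumably the intent is that workers stay parked for the next graph (or a pause) instead of terminating, with the pool's ec field still carrying the error status for the main thread:

    state->pending = false;

    ggml_graph_compute_thread(state);
    // no early 'break' on state->threadpool->ec != GGML_STATUS_SUCCESS anymore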
@@ -431,16 +431,9 @@ extern "C" {
     // Optional: an auto threadpool gets created in ggml if not passed explicitly
     LLAMA_API void llama_attach_threadpool(
             struct llama_context * ctx,
-            ggml_compute_threadpool_t threadpool);
-    LLAMA_API void llama_attach_batch_threadpool(
-            struct llama_context * ctx,
+            ggml_compute_threadpool_t threadpool,
             ggml_compute_threadpool_t threadpool_batch);
     LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
-    LLAMA_API void llama_detach_batch_threadpool(struct llama_context * ctx);
-    LLAMA_API void llama_detach_threadpools(struct llama_context * ctx);
-
-    // Pauses all attached threadpools
-    LLAMA_API void llama_pause_threadpools(struct llama_context * ctx);

     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);
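Note: the public threadpool surface shrinks from six entry points (attach, attach_batch, detach, detach_batch, detach_threadpools, pause_threadpools) to two. After the hunk, the declarations read:

    // Optional: an auto threadpool gets created in ggml if not passed explicitly
    LLAMA_API void llama_attach_threadpool(
            struct llama_context * ctx,
            ggml_compute_threadpool_t threadpool,
            ggml_compute_threadpool_t threadpool_batch);
    LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);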
@@ -15523,39 +15523,6 @@ static void llama_graph_compute(
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }

-// Optionally swaps the batch and single-tok threadpools.
-// Returns the number of threads, and if a valid threadpool exists, returns it too.
-static std::pair<int32_t, ggml_compute_threadpool_t> llama_swap_threadpools(
-        llama_context & lctx,
-        int32_t n_tokens) {
-
-    const auto & cparams = lctx.cparams;
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-
-    ggml_compute_threadpool_t threadpool = nullptr; // nullptr -> disposable threadpool
-
-    // A batch threadpool without a non-batch threadpool isn't supported.
-    GGML_ASSERT(!lctx.threadpool_batch || lctx.threadpool);
-
-    if (lctx.threadpool_batch && lctx.threadpool) {
-        // Switch between the 2 threadpools as needed
-        if (n_tokens > 1) {
-            ggml_pause_threadpool(lctx.threadpool);
-            threadpool = lctx.threadpool_batch;
-            n_threads = cparams.n_threads_batch;
-        } else {
-            ggml_pause_threadpool(lctx.threadpool_batch);
-            threadpool = lctx.threadpool;
-            n_threads = cparams.n_threads;
-        }
-    } else if (lctx.threadpool) {
-        threadpool = lctx.threadpool;
-        n_threads = cparams.n_threads;
-    }
-    return std::make_pair(n_threads, threadpool);
-}
-
-
 // decode a batch of tokens by evaluating the transformer
 //
 // - lctx: llama context
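Note: the removed helper had two jobs: pause whichever pool was not about to run, and pick the (n_threads, threadpool) pair. The first job now lives in ggml_backend_cpu_set_threadpool (see the ggml-backend hunk above); the second collapses into plain ternaries in the next two hunks. Its GGML_ASSERT also becomes true by construction:

    // after llama_attach_threadpool(ctx, tp, tp_batch), per the hunk at the end of this commit:
    //   ctx->threadpool       == tp
    //   ctx->threadpool_batch == (tp_batch ? tp_batch : tp)
    // so threadpool_batch can never be set while threadpool is NULL, which is
    // exactly what the removed GGML_ASSERT used to check.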
@@ -15662,11 +15629,8 @@ static int llama_decode_internal(
         lctx.n_outputs = n_outputs_new;
     }

-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-            llama_swap_threadpools(lctx, n_tokens);
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
-
     GGML_ASSERT(n_threads > 0);
@@ -15906,11 +15870,9 @@ static int llama_encode_internal(
     lctx.inp_embd_enc = NULL;
     lctx.n_outputs = n_tokens;

-    std::pair<int32_t, ggml_compute_threadpool_t> threads =
-            llama_swap_threadpools(lctx, n_tokens);
+    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    ggml_compute_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;

-    int n_threads = threads.first;
-    ggml_compute_threadpool_t threadpool = threads.second;
     GGML_ASSERT(n_threads > 0);

     ggml_backend_sched_reset(lctx.sched);
@@ -17500,36 +17462,15 @@ void llama_numa_init(enum ggml_numa_strategy numa) {

 void llama_attach_threadpool(
         struct llama_context * ctx,
-        ggml_compute_threadpool_t threadpool) {
-    ctx->threadpool = threadpool;
-}
-
-void llama_attach_batch_threadpool(
-        struct llama_context * ctx,
+        ggml_compute_threadpool_t threadpool,
         ggml_compute_threadpool_t threadpool_batch) {
-    ctx->threadpool_batch = threadpool_batch;
+    ctx->threadpool = threadpool;
+    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }

 void llama_detach_threadpool(struct llama_context * ctx) {
     ctx->threadpool = nullptr;
-}
-
-void llama_detach_batch_threadpool(struct llama_context * ctx) {
-    ctx->threadpool = nullptr;
-}
-
-void llama_detach_threadpools(struct llama_context * ctx) {
-    llama_detach_threadpool(ctx);
-    llama_detach_batch_threadpool(ctx);
-}
-
-void llama_pause_threadpools(struct llama_context * ctx) {
-    if (ctx->threadpool) {
-        ggml_pause_threadpool(ctx->threadpool);
-    }
-    if (ctx->threadpool_batch) {
-        ggml_pause_threadpool(ctx->threadpool_batch);
-    }
+    ctx->threadpool_batch = nullptr;
 }

 void llama_backend_free(void) {
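Note: the two surviving implementations, assembled from this hunk: attach populates both fields in one shot with the NULL-batch fallback, and detach clears both, so a context can no longer be left holding only a batch pool:

    void llama_attach_threadpool(
             struct llama_context * ctx,
             ggml_compute_threadpool_t threadpool,
             ggml_compute_threadpool_t threadpool_batch) {
        ctx->threadpool       = threadpool;
        ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
    }

    void llama_detach_threadpool(struct llama_context * ctx) {
        ctx->threadpool       = nullptr;
        ctx->threadpool_batch = nullptr;
    }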