llama: propagating the results of graph_compute to the user interface

commit 5e354e3ca2
parent 6423c65aa8
Author: Michael Podvitskiy
Date:   2024-09-17 21:43:01 +02:00

@@ -17181,7 +17181,7 @@ static void llama_output_reorder(struct llama_context * ctx) {
     }
 }
 
-static void llama_graph_compute(
+static enum ggml_status llama_graph_compute(
           llama_context & lctx,
             ggml_cgraph * gf,
                     int n_threads,
@@ -17196,12 +17196,14 @@ static void llama_graph_compute(
         set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
     }
 
-    auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
-    if (err != GGML_STATUS_SUCCESS) {
-        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
+    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
     }
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
+
+    return status;
 }
 
 // decode a batch of tokens by evaluating the transformer
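Note: llama_graph_compute now returns the ggml_status it receives from the scheduler instead of only logging it. For reference, ggml_status is declared in ggml.h roughly as follows (values quoted from the ggml headers of this period; verify against your checkout):

enum ggml_status {
    GGML_STATUS_ALLOC_FAILED = -2,  // a required allocation failed
    GGML_STATUS_FAILED       = -1,  // graph computation failed
    GGML_STATUS_SUCCESS      =  0,
    GGML_STATUS_ABORTED      =  1,  // computation stopped by the abort callback
};

The switch statements added in the hunks below map these values onto the integer codes that llama_decode and llama_encode already return to callers.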
@@ -17387,7 +17389,18 @@ static int llama_decode_internal(
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads, threadpool);
+        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+        switch (compute_status) {
+            case GGML_STATUS_SUCCESS:
+                break;
+            case GGML_STATUS_ABORTED:
+                return 2;
+            case GGML_STATUS_ALLOC_FAILED:
+                return -2;
+            case GGML_STATUS_FAILED:
+            default:
+                return -3;
+        }
 
         // update the kv ring buffer
         {
@@ -17624,7 +17637,18 @@ static int llama_encode_internal(
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads, threadpool);
+        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+        switch (compute_status) {
+            case GGML_STATUS_SUCCESS:
+                break;
+            case GGML_STATUS_ABORTED:
+                return 2;
+            case GGML_STATUS_ALLOC_FAILED:
+                return -2;
+            case GGML_STATUS_FAILED:
+            default:
+                return -3;
+        }
 
         // extract embeddings
         if (embd) {
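
With the compute status propagated, a caller of llama_decode (or llama_encode) can now tell a user-requested abort apart from a hard failure. A minimal, hypothetical sketch of caller-side handling follows; the helper name decode_or_report is illustrative and not part of this commit, and return code 1 (no KV slot found) predates this change:

#include <stdbool.h>
#include <stdio.h>

#include "llama.h"

// Hypothetical helper: interpret llama_decode's return codes after this change.
//  0 -> success
//  1 -> could not find a KV slot for the batch (pre-existing warning code)
//  2 -> graph compute was aborted via the abort callback (new)
// -2 -> allocation failed (new)
// -3 -> graph compute failed (new)
static bool decode_or_report(struct llama_context * ctx, struct llama_batch batch) {
    const int ret = llama_decode(ctx, batch);
    switch (ret) {
        case 0:
            return true;
        case 2:
            // user-requested abort - usually not fatal, generation can be stopped cleanly
            fprintf(stderr, "decode aborted by the abort callback\n");
            return false;
        default:
            // 1 (no KV slot) or a negative status (-2 alloc failed, -3 compute failed)
            fprintf(stderr, "llama_decode returned %d\n", ret);
            return false;
    }
}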