llama: propagating the results of graph_compute to the user interface
parent 6423c65aa8
commit 5e354e3ca2
1 changed file with 30 additions and 6 deletions
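The change makes llama_graph_compute return the enum ggml_status it receives from ggml_backend_sched_graph_compute_async instead of discarding it, and has llama_decode_internal and llama_encode_internal map that status onto their integer return codes. For reference, a sketch of the status values involved, as declared in ggml.h around this revision (reproduced from memory, not part of this commit; check the ggml header actually in use):

    // ggml_status as declared in ggml.h (for reference only)
    enum ggml_status {
        GGML_STATUS_ALLOC_FAILED = -2,
        GGML_STATUS_FAILED       = -1,
        GGML_STATUS_SUCCESS      =  0,
        GGML_STATUS_ABORTED      =  1,
    };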
@@ -17181,7 +17181,7 @@ static void llama_output_reorder(struct llama_context * ctx) {
     }
 }
 
-static void llama_graph_compute(
+static enum ggml_status llama_graph_compute(
           llama_context & lctx,
             ggml_cgraph * gf,
                     int   n_threads,
@@ -17196,12 +17196,14 @@ static void llama_graph_compute(
         set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
     }
 
-    auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
-    if (err != GGML_STATUS_SUCCESS) {
-        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
+    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
     }
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
+
+    return status;
 }
 
 // decode a batch of tokens by evaluating the transformer
@@ -17387,7 +17389,18 @@ static int llama_decode_internal(
 
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads, threadpool);
+        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+        switch (compute_status) {
+            case GGML_STATUS_SUCCESS:
+                break;
+            case GGML_STATUS_ABORTED:
+                return 2;
+            case GGML_STATUS_ALLOC_FAILED:
+                return -2;
+            case GGML_STATUS_FAILED:
+            default:
+                return -3;
+        }
 
         // update the kv ring buffer
         {
@@ -17624,7 +17637,18 @@ static int llama_encode_internal(
 
     llama_set_inputs(lctx, ubatch);
 
-    llama_graph_compute(lctx, gf, n_threads, threadpool);
+    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+    switch (compute_status) {
+        case GGML_STATUS_SUCCESS:
+            break;
+        case GGML_STATUS_ABORTED:
+            return 2;
+        case GGML_STATUS_ALLOC_FAILED:
+            return -2;
+        case GGML_STATUS_FAILED:
+        default:
+            return -3;
+    }
 
     // extract embeddings
     if (embd) {
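With the status propagated, code above llama_decode_internal and llama_encode_internal can tell an aborted graph apart from an allocation or compute failure. A minimal caller-side sketch, assuming the public llama_decode wrapper returns the internal code unchanged (an assumption about the surrounding file, not shown in this diff):

    // Illustrative handling of the new return codes; not part of this commit.
    // Assumes llama_decode() forwards llama_decode_internal()'s return value.
    #include <stdbool.h>
    #include <stdio.h>
    #include "llama.h"

    static bool decode_or_report(struct llama_context * ctx, struct llama_batch batch) {
        const int ret = llama_decode(ctx, batch);
        switch (ret) {
            case  0: return true;                                             // success
            case  2: fprintf(stderr, "graph compute aborted by callback\n");  return false;
            case -2: fprintf(stderr, "compute allocation failed\n");          return false;
            case -3: fprintf(stderr, "graph compute failed\n");               return false;
            default: fprintf(stderr, "decode failed, ret = %d\n", ret);       return false;
        }
    }

Keeping the aborted case positive and the failure cases negative is consistent with the existing llama.h convention that positive llama_decode return values are warnings while negative values are errors.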