llama : return enum for llama_decode and llama_encode

This commit is contained in:
Xuan Son Nguyen 2024-09-11 15:37:22 +02:00
parent 0996c5597f
commit 3dbd2eeb34
2 changed files with 35 additions and 17 deletions

View file

@ -253,6 +253,21 @@ extern "C" {
llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_batch;
enum llama_decode_result {
LLAMA_DECODE_RESULT_OK = 0,
LLAMA_DECODE_RESULT_ERR_ALLOC_KV = 1,
LLAMA_DECODE_RESULT_ERR_RESERVE_OUTPUT = -1,
LLAMA_DECODE_RESULT_INVALID_BATCH = -2,
};
enum llama_encode_result {
LLAMA_ENCODE_RESULT_OK = 0,
LLAMA_ENCODE_RESULT_ERR_ALLOC_KV = 1,
LLAMA_ENCODE_RESULT_ERR_NO_ENCODER = 2,
LLAMA_ENCODE_RESULT_ERR_RESERVE_OUTPUT = -1,
LLAMA_ENCODE_RESULT_INVALID_BATCH = -2,
};
enum llama_model_kv_override_type {
LLAMA_KV_OVERRIDE_TYPE_INT,
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
@ -801,7 +816,7 @@ extern "C" {
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success
// < 0 - error
LLAMA_API int32_t llama_encode(
LLAMA_API enum llama_encode_result llama_encode(
struct llama_context * ctx,
struct llama_batch batch);
@ -809,7 +824,7 @@ extern "C" {
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// < 0 - error
LLAMA_API int32_t llama_decode(
LLAMA_API enum llama_decode_result llama_decode(
struct llama_context * ctx,
struct llama_batch batch);

View file

@ -16064,7 +16064,7 @@ static void llama_graph_compute(
// return positive int on warning
// return negative int on error
//
static int llama_decode_internal(
static enum llama_decode_result llama_decode_internal(
llama_context & lctx,
llama_batch batch_all) { // TODO: rename back to batch
@ -16073,13 +16073,13 @@ static int llama_decode_internal(
if (n_tokens_all == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
return -1;
return LLAMA_DECODE_RESULT_INVALID_BATCH;
}
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= lctx.model.vocab.n_vocab) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch_all.token[i]);
return -1;
return LLAMA_DECODE_RESULT_INVALID_BATCH;
}
}
@ -16132,7 +16132,7 @@ static int llama_decode_internal(
// reserve output buffer
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
return -2;
return LLAMA_DECODE_RESULT_ERR_RESERVE_OUTPUT;
};
while (lctx.sbatch.n_tokens > 0) {
@ -16184,7 +16184,7 @@ static int llama_decode_internal(
}
if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
return 1;
return LLAMA_DECODE_RESULT_ERR_ALLOC_KV;
}
if (!kv_self.recurrent) {
@ -16350,7 +16350,7 @@ static int llama_decode_internal(
// overlap with device computation.
ggml_backend_sched_reset(lctx.sched);
return 0;
return LLAMA_DECODE_RESULT_OK;
}
// encode a batch of tokens by evaluating the encoder part of the transformer
@ -16362,9 +16362,12 @@ static int llama_decode_internal(
// return positive int on warning
// return negative int on error
//
static int llama_encode_internal(
static enum llama_encode_result llama_encode_internal(
llama_context & lctx,
llama_batch batch) {
if (!llama_model_has_encoder(&lctx.model)) {
return LLAMA_ENCODE_RESULT_ERR_NO_ENCODER;
}
lctx.is_encoding = true;
@ -16372,13 +16375,13 @@ static int llama_encode_internal(
if (n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
return -1;
return LLAMA_ENCODE_RESULT_INVALID_BATCH;
}
for (uint32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= lctx.model.vocab.n_vocab) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch.token[i]);
return -1;
return LLAMA_ENCODE_RESULT_INVALID_BATCH;
}
}
@ -16406,7 +16409,7 @@ static int llama_encode_internal(
// reserve output buffer
if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
return -2;
return LLAMA_ENCODE_RESULT_ERR_RESERVE_OUTPUT;
};
for (uint32_t i = 0; i < n_tokens; ++i) {
@ -16516,7 +16519,7 @@ static int llama_encode_internal(
// overlap with device computation.
ggml_backend_sched_reset(lctx.sched);
return 0;
return LLAMA_ENCODE_RESULT_OK;
}
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
@ -20038,10 +20041,10 @@ void llama_batch_free(struct llama_batch batch) {
if (batch.logits) free(batch.logits);
}
int32_t llama_encode(
enum llama_encode_result llama_encode(
struct llama_context * ctx,
struct llama_batch batch) {
const int ret = llama_encode_internal(*ctx, batch);
const enum llama_encode_result ret = llama_encode_internal(*ctx, batch);
if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
}
@ -20049,10 +20052,10 @@ int32_t llama_encode(
return ret;
}
int32_t llama_decode(
enum llama_decode_result llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
const int ret = llama_decode_internal(*ctx, batch);
const enum llama_decode_result ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}