llama : return enum for llama_decode
and llama_encode
This commit is contained in:
parent
0996c5597f
commit
3dbd2eeb34
2 changed files with 35 additions and 17 deletions
|
@ -253,6 +253,21 @@ extern "C" {
|
|||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
||||
} llama_batch;
|
||||
|
||||
enum llama_decode_result {
|
||||
LLAMA_DECODE_RESULT_OK = 0,
|
||||
LLAMA_DECODE_RESULT_ERR_ALLOC_KV = 1,
|
||||
LLAMA_DECODE_RESULT_ERR_RESERVE_OUTPUT = -1,
|
||||
LLAMA_DECODE_RESULT_INVALID_BATCH = -2,
|
||||
};
|
||||
|
||||
enum llama_encode_result {
|
||||
LLAMA_ENCODE_RESULT_OK = 0,
|
||||
LLAMA_ENCODE_RESULT_ERR_ALLOC_KV = 1,
|
||||
LLAMA_ENCODE_RESULT_ERR_NO_ENCODER = 2,
|
||||
LLAMA_ENCODE_RESULT_ERR_RESERVE_OUTPUT = -1,
|
||||
LLAMA_ENCODE_RESULT_INVALID_BATCH = -2,
|
||||
};
|
||||
|
||||
enum llama_model_kv_override_type {
|
||||
LLAMA_KV_OVERRIDE_TYPE_INT,
|
||||
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
|
||||
|
@ -801,7 +816,7 @@ extern "C" {
|
|||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error
|
||||
LLAMA_API int32_t llama_encode(
|
||||
LLAMA_API enum llama_encode_result llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
|
@ -809,7 +824,7 @@ extern "C" {
|
|||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// < 0 - error
|
||||
LLAMA_API int32_t llama_decode(
|
||||
LLAMA_API enum llama_decode_result llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
|
||||
|
|
|
@ -16064,7 +16064,7 @@ static void llama_graph_compute(
|
|||
// return positive int on warning
|
||||
// return negative int on error
|
||||
//
|
||||
static int llama_decode_internal(
|
||||
static enum llama_decode_result llama_decode_internal(
|
||||
llama_context & lctx,
|
||||
llama_batch batch_all) { // TODO: rename back to batch
|
||||
|
||||
|
@ -16073,13 +16073,13 @@ static int llama_decode_internal(
|
|||
|
||||
if (n_tokens_all == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
|
||||
return -1;
|
||||
return LLAMA_DECODE_RESULT_INVALID_BATCH;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= lctx.model.vocab.n_vocab) {
|
||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch_all.token[i]);
|
||||
return -1;
|
||||
return LLAMA_DECODE_RESULT_INVALID_BATCH;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -16132,7 +16132,7 @@ static int llama_decode_internal(
|
|||
// reserve output buffer
|
||||
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
|
||||
return -2;
|
||||
return LLAMA_DECODE_RESULT_ERR_RESERVE_OUTPUT;
|
||||
};
|
||||
|
||||
while (lctx.sbatch.n_tokens > 0) {
|
||||
|
@ -16184,7 +16184,7 @@ static int llama_decode_internal(
|
|||
}
|
||||
|
||||
if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
|
||||
return 1;
|
||||
return LLAMA_DECODE_RESULT_ERR_ALLOC_KV;
|
||||
}
|
||||
|
||||
if (!kv_self.recurrent) {
|
||||
|
@ -16350,7 +16350,7 @@ static int llama_decode_internal(
|
|||
// overlap with device computation.
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
|
||||
return 0;
|
||||
return LLAMA_DECODE_RESULT_OK;
|
||||
}
|
||||
|
||||
// encode a batch of tokens by evaluating the encoder part of the transformer
|
||||
|
@ -16362,9 +16362,12 @@ static int llama_decode_internal(
|
|||
// return positive int on warning
|
||||
// return negative int on error
|
||||
//
|
||||
static int llama_encode_internal(
|
||||
static enum llama_encode_result llama_encode_internal(
|
||||
llama_context & lctx,
|
||||
llama_batch batch) {
|
||||
if (!llama_model_has_encoder(&lctx.model)) {
|
||||
return LLAMA_ENCODE_RESULT_ERR_NO_ENCODER;
|
||||
}
|
||||
|
||||
lctx.is_encoding = true;
|
||||
|
||||
|
@ -16372,13 +16375,13 @@ static int llama_encode_internal(
|
|||
|
||||
if (n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
|
||||
return -1;
|
||||
return LLAMA_ENCODE_RESULT_INVALID_BATCH;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= lctx.model.vocab.n_vocab) {
|
||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch.token[i]);
|
||||
return -1;
|
||||
return LLAMA_ENCODE_RESULT_INVALID_BATCH;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -16406,7 +16409,7 @@ static int llama_encode_internal(
|
|||
// reserve output buffer
|
||||
if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
||||
return -2;
|
||||
return LLAMA_ENCODE_RESULT_ERR_RESERVE_OUTPUT;
|
||||
};
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
|
@ -16516,7 +16519,7 @@ static int llama_encode_internal(
|
|||
// overlap with device computation.
|
||||
ggml_backend_sched_reset(lctx.sched);
|
||||
|
||||
return 0;
|
||||
return LLAMA_ENCODE_RESULT_OK;
|
||||
}
|
||||
|
||||
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
|
||||
|
@ -20038,10 +20041,10 @@ void llama_batch_free(struct llama_batch batch) {
|
|||
if (batch.logits) free(batch.logits);
|
||||
}
|
||||
|
||||
int32_t llama_encode(
|
||||
enum llama_encode_result llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch) {
|
||||
const int ret = llama_encode_internal(*ctx, batch);
|
||||
const enum llama_encode_result ret = llama_encode_internal(*ctx, batch);
|
||||
if (ret < 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
|
@ -20049,10 +20052,10 @@ int32_t llama_encode(
|
|||
return ret;
|
||||
}
|
||||
|
||||
int32_t llama_decode(
|
||||
enum llama_decode_result llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch) {
|
||||
const int ret = llama_decode_internal(*ctx, batch);
|
||||
const enum llama_decode_result ret = llama_decode_internal(*ctx, batch);
|
||||
if (ret < 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue