move batch_allocr inside decode/encode_internal

This commit is contained in:
Xuan Son Nguyen 2024-10-22 13:01:22 +02:00
parent bd697ca77d
commit 6ab116ac5a

View file

@ -17108,16 +17108,19 @@ static void llama_graph_compute(
//
static int llama_decode_internal(
llama_context & lctx,
llama_batch batch) {
llama_batch inp_batch) {
lctx.is_encoding = false;
const uint32_t n_tokens_all = batch.n_tokens;
if (n_tokens_all == 0) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
llama_batch_allocr batch_allocr(lctx, inp_batch);
llama_batch batch = batch_allocr.batch;
const uint32_t n_tokens_all = batch.n_tokens;
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
@ -17422,17 +17425,19 @@ static int llama_decode_internal(
//
static int llama_encode_internal(
llama_context & lctx,
llama_batch batch) {
llama_batch inp_batch) {
lctx.is_encoding = true;
const uint32_t n_tokens = batch.n_tokens;
if (n_tokens == 0) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
llama_batch_allocr batch_allocr(lctx, inp_batch);
llama_batch batch = batch_allocr.batch;
const uint32_t n_tokens = batch.n_tokens;
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
@ -21137,16 +21142,12 @@ struct llama_batch_allocr {
std::vector<int8_t> logits;
struct llama_batch batch;
// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
batch = in_batch;
if (batch.n_tokens == 0) {
// llama_(de|en)code_internal will return an error in this case
return;
}
if (!batch.pos) {
// determine the last position in KV cache
llama_pos last_pos = -1;
for (const auto & cell : ctx->kv_self.cells) {
for (const auto & cell : ctx.kv_self.cells) {
if (cell.has_seq_id(batch_default_seq_id)) {
last_pos = std::max(last_pos, cell.pos);
}
@ -21184,8 +21185,7 @@ struct llama_batch_allocr {
int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch) {
llama_batch_allocr batch_allocr(ctx, batch);
const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
const int ret = llama_encode_internal(*ctx, batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
}
@ -21196,8 +21196,7 @@ int32_t llama_encode(
int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
llama_batch_allocr batch_allocr(ctx, batch);
const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
const int ret = llama_decode_internal(*ctx, batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}