move batch_allocr inside decode/encode_internal
This commit is contained in:
parent
bd697ca77d
commit
6ab116ac5a
1 changed files with 16 additions and 17 deletions
|
@ -17108,16 +17108,19 @@ static void llama_graph_compute(
|
|||
//
|
||||
static int llama_decode_internal(
|
||||
llama_context & lctx,
|
||||
llama_batch batch) {
|
||||
llama_batch inp_batch) {
|
||||
|
||||
lctx.is_encoding = false;
|
||||
const uint32_t n_tokens_all = batch.n_tokens;
|
||||
|
||||
if (n_tokens_all == 0) {
|
||||
if (inp_batch.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
llama_batch_allocr batch_allocr(lctx, inp_batch);
|
||||
llama_batch batch = batch_allocr.batch;
|
||||
const uint32_t n_tokens_all = batch.n_tokens;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
@ -17422,17 +17425,19 @@ static int llama_decode_internal(
|
|||
//
|
||||
static int llama_encode_internal(
|
||||
llama_context & lctx,
|
||||
llama_batch batch) {
|
||||
llama_batch inp_batch) {
|
||||
|
||||
lctx.is_encoding = true;
|
||||
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
|
||||
if (n_tokens == 0) {
|
||||
if (inp_batch.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
llama_batch_allocr batch_allocr(lctx, inp_batch);
|
||||
llama_batch batch = batch_allocr.batch;
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
@ -21137,16 +21142,12 @@ struct llama_batch_allocr {
|
|||
std::vector<int8_t> logits;
|
||||
struct llama_batch batch;
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
|
||||
llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
|
||||
batch = in_batch;
|
||||
if (batch.n_tokens == 0) {
|
||||
// llama_(de|en)code_internal will return an error in this case
|
||||
return;
|
||||
}
|
||||
if (!batch.pos) {
|
||||
// determine the last position in KV cache
|
||||
llama_pos last_pos = -1;
|
||||
for (const auto & cell : ctx->kv_self.cells) {
|
||||
for (const auto & cell : ctx.kv_self.cells) {
|
||||
if (cell.has_seq_id(batch_default_seq_id)) {
|
||||
last_pos = std::max(last_pos, cell.pos);
|
||||
}
|
||||
|
@ -21184,8 +21185,7 @@ struct llama_batch_allocr {
|
|||
int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch) {
|
||||
llama_batch_allocr batch_allocr(ctx, batch);
|
||||
const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
|
||||
const int ret = llama_encode_internal(*ctx, batch);
|
||||
if (ret != 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
|
@ -21196,8 +21196,7 @@ int32_t llama_encode(
|
|||
int32_t llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch) {
|
||||
llama_batch_allocr batch_allocr(ctx, batch);
|
||||
const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
|
||||
const int ret = llama_decode_internal(*ctx, batch);
|
||||
if (ret != 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue