llama : simplify llama_encode_internal

parent 03ab5dd67c
commit 88270a3613

1 changed file with 100 additions and 140 deletions

src/llama.cpp (240 changed lines: +100 −140)
@@ -2697,9 +2697,10 @@ struct llama_context {
 
     // whether we are computing encoder output or decoder output
     bool is_encoding = false;
 
     // output of the encoder part of the encoder-decoder models
-    std::vector<float> encoder_output;
-    std::vector<std::set<llama_seq_id> > encoder_output_seq_ids;
+    std::vector<float> embd_enc;
+    std::vector<std::set<llama_seq_id>> seq_ids_enc;
 
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
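Both renamed fields describe the cached encoder result that the decoder's cross-attention later consumes: `embd_enc` holds the encoder output as a flat row-major `[n_outputs_enc, n_embd]` float buffer, and `seq_ids_enc` stores, for each encoder output row, the set of sequence ids it belongs to. A minimal standalone sketch of that layout follows; the `encoder_output_mock` struct and its helpers are illustrative only and not part of llama.cpp, and `llama_seq_id` is assumed to be `int32_t` as in llama.h.

```cpp
// Illustrative sketch of the renamed fields: embd_enc is a flat row-major
// [n_outputs_enc, n_embd] float buffer, seq_ids_enc maps each encoder output
// row to the sequence ids it belongs to. Not the llama.cpp API.
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

using llama_seq_id = int32_t; // assumption: matches the typedef in llama.h

struct encoder_output_mock {
    int64_t n_embd = 4;
    std::vector<float> embd_enc;                      // n_outputs_enc * n_embd floats
    std::vector<std::set<llama_seq_id>> seq_ids_enc;  // one set per encoder output row

    int64_t n_outputs_enc() const { return (int64_t) embd_enc.size() / n_embd; }

    // pointer to the embedding of encoder output row i
    const float * row(int64_t i) const { return embd_enc.data() + i*n_embd; }

    // does encoder output row i belong to sequence seq_id?
    // (this is the test used when building the cross-attention mask)
    bool belongs_to(int64_t i, llama_seq_id seq_id) const {
        return seq_ids_enc[i].find(seq_id) != seq_ids_enc[i].end();
    }
};

int main() {
    encoder_output_mock enc;
    enc.embd_enc    = { 0.1f, 0.2f, 0.3f, 0.4f,   0.5f, 0.6f, 0.7f, 0.8f };
    enc.seq_ids_enc = { {0}, {0} };

    printf("n_outputs_enc = %lld, row1[0] = %.1f, belongs to seq 0: %d\n",
           (long long) enc.n_outputs_enc(), enc.row(1)[0], (int) enc.belongs_to(1, 0));
    return 0;
}
```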
@@ -7974,7 +7975,7 @@ struct llm_build_context {
         n_tokens (batch.n_tokens),
         n_kv (worst_case ? kv_self.size : kv_self.n),
         n_outputs (worst_case ? n_tokens : lctx.n_outputs),
-        n_outputs_enc (worst_case ? n_tokens : lctx.encoder_output.size() / hparams.n_embd),
+        n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_ctx_orig (cparams.n_ctx_orig_yarn),
         flash_attn (cparams.flash_attn),
@@ -12716,7 +12717,7 @@ struct llm_build_context {
                         model.layers[il].ffn_down_enc, NULL, NULL,
                         NULL,
                         model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
                         model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
                         cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -12742,8 +12743,8 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, -1);
             cb(cur, "result_norm", -1);
         } else {
             struct ggml_tensor * embd_enc = llm_build_inp_embd_enc();
-            struct ggml_tensor * pos_buckets_dec = llm_build_pos_bucket(true);
+            struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
 
             struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask();
             struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
@@ -12794,7 +12795,7 @@ struct llm_build_context {
             cb(kq, "kq", il);
 
             struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
-            struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_buckets_dec, attn_rel_b);
+            struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
             struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
             cb(kq_b, "kq_b", il);
 
@@ -12838,7 +12839,7 @@ struct llm_build_context {
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_cross, embd_enc);
             cb(Vcur, "Vcur", il);
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
 
             struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
@@ -13298,26 +13299,26 @@ static void llama_set_s_copy(llama_context & lctx) {
     }
 }
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t num_buckets, bool bidirectional) {
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
     const int64_t max_distance = 128;
 
     if (bidirectional) {
-        num_buckets >>= 1;
+        n_buckets >>= 1;
     }
 
-    const int64_t max_exact = num_buckets >> 1;
+    const int64_t max_exact = n_buckets >> 1;
 
     int32_t relative_position = x - y;
     int32_t relative_bucket = 0;
     if (bidirectional) {
-        relative_bucket += (relative_position > 0) * num_buckets;
+        relative_bucket += (relative_position > 0) * n_buckets;
         relative_position = abs(relative_position);
     } else {
         relative_position = -std::min<int32_t>(relative_position, 0);
     }
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
     relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
     return relative_bucket;
 }
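The function keeps the T5-style relative-position bucketing scheme unchanged; only `num_buckets` is renamed to `n_buckets`. Below is a self-contained copy of the same logic with a small demo `main`; the free-standing `relative_position_bucket` name and the demo values are illustrative, and `llama_pos` is assumed to be `int32_t`.

```cpp
// Standalone copy of the T5-style relative position bucketing shown above,
// with a tiny demo. llama_pos is assumed to be int32_t, as in llama.h.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

using llama_pos = int32_t;

static int32_t relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;
    if (bidirectional) {
        // the upper half of the buckets is reserved for positions after y
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        // causal case: only positions at or before y get distinct buckets
        relative_position = -std::min<int32_t>(relative_position, 0);
    }
    // small distances get their own bucket, larger distances are binned logarithmically
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
    return relative_bucket;
}

int main() {
    // bucket of key position x relative to query position 0, 32 buckets, bidirectional (encoder)
    const int xs[] = {0, 1, 7, 40, 200};
    for (int x : xs) {
        printf("bucket(%3d, 0) = %d\n", x, relative_position_bucket(x, 0, 32, true));
    }
    return 0;
}
```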
@@ -13634,13 +13635,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
     if (!lctx.is_encoding && lctx.inp_embd_enc) {
         assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
-        assert(ggml_nelements(lctx.inp_embd_enc) == lctx.encoder_output.size());
+        assert(ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
 
-        ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.encoder_output.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
+        ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
     }
 
     if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
-        const int64_t n_encoder_output = lctx.encoder_output.size() / hparams.n_embd;
+        const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
@@ -13649,21 +13650,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_encoder_output; ++i) {
+                for (int i = 0; i < n_output_enc; ++i) {
                     float f = -INFINITY;
                     for (int s = 0; s < batch.n_seq_id[j]; ++s) {
                         const llama_seq_id seq_id = batch.seq_id[j][s];
-                        if (lctx.encoder_output_seq_ids[i].find(seq_id) != lctx.encoder_output_seq_ids[i].end()) {
+                        if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
                             f = 0.0f;
                         }
                     }
-                    data[h*(n_encoder_output*n_tokens) + j*n_encoder_output + i] = f;
+                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
                 }
             }
 
             for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_encoder_output; ++j) {
+                for (int j = 0; j < n_output_enc; ++j) {
-                    data[h*(n_encoder_output*n_tokens) + i*n_encoder_output + j] = -INFINITY;
+                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
                 }
             }
         }
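The loops above fill the cross-attention mask on the host: entry `(j, i)` is `0.0f` when decoder token `j` shares a sequence id with encoder output `i`, and `-INFINITY` otherwise, with the padded rows masked out entirely. A standalone sketch of the same fill, using `std::vector` in place of the ggml tensor buffer; all names here are illustrative, not llama.cpp's.

```cpp
// Host-side sketch of the cross-attention mask fill shown above: row j of the
// mask corresponds to decoder token j, column i to encoder output i. Plain
// std::vector stands in for the ggml tensor buffer; all names are illustrative.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

using llama_seq_id = int32_t; // assumption: matches llama.h

int main() {
    // two decoder tokens, the first on sequence 0, the second on sequence 1
    const std::vector<std::vector<llama_seq_id>> token_seq_ids = { {0}, {1} };
    // three encoder outputs: the first two belong to sequence 0, the last to sequence 1
    const std::vector<std::set<llama_seq_id>>    seq_ids_enc   = { {0}, {0}, {1} };

    const int n_tokens     = (int) token_seq_ids.size();
    const int n_output_enc = (int) seq_ids_enc.size();

    // start fully masked, then open the positions that share a sequence id
    std::vector<float> mask(n_tokens*n_output_enc, -INFINITY);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_output_enc; ++i) {
            for (llama_seq_id seq_id : token_seq_ids[j]) {
                if (seq_ids_enc[i].count(seq_id) > 0) {
                    mask[j*n_output_enc + i] = 0.0f; // allow attention
                }
            }
        }
    }

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_output_enc; ++i) {
            printf("%6.1f ", mask[j*n_output_enc + i]);
        }
        printf("\n");
    }
    return 0;
}
```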
@@ -13809,6 +13810,7 @@ static int llama_decode_internal(
 
     const auto n_ubatch = cparams.n_ubatch;
 
+    // TODO: simplify or deprecate
     std::vector<llama_pos> pos;
     std::vector<int32_t> n_seq_id;
     std::vector<llama_seq_id *> seq_id_arr;
@@ -14083,12 +14085,13 @@ static int llama_decode_internal(
 //
 static int llama_encode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   batch) {
 
     lctx.is_encoding = true;
-    const uint32_t n_tokens_all = batch_all.n_tokens;
+
+    const uint32_t n_tokens = batch.n_tokens;
 
-    if (n_tokens_all == 0) {
+    if (n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
         return -1;
     }
@@ -14097,147 +14100,105 @@ static int llama_encode_internal(
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-    GGML_ASSERT(cparams.n_ubatch >= n_tokens_all && "encoder requires n_ubatch >= n_tokens");
+    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
+    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
     }
-    lctx.n_queued_tokens += n_tokens_all;
 
-    const int64_t n_embd = hparams.n_embd;
+    lctx.n_queued_tokens += n_tokens;
 
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
+    const int64_t n_embd = hparams.n_embd;
 
+    // TODO: simplify or deprecate
     std::vector<llama_pos> pos;
     std::vector<int32_t> n_seq_id;
    std::vector<llama_seq_id *> seq_id_arr;
     std::vector<std::vector<llama_seq_id>> seq_id;
 
-    n_outputs = n_tokens_all;
-
     // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+    if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
         return -2;
     };
 
-    for (uint32_t i = 0; i < n_outputs; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         lctx.output_ids[i] = i;
     }
 
     lctx.inp_embd_enc = NULL;
+    lctx.n_outputs = n_tokens;
 
-    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
-        const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
-        llama_batch u_batch = {
-            /* .n_tokens = */ (int32_t) n_tokens,
-            /* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
-            /* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr,
-            /* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr,
-            /* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr,
-            /* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr,
-            /* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr,
-            /* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
-            /* .all_pos_1 = */ batch_all.all_pos_1,
-            /* .all_seq_id = */ batch_all.all_seq_id,
-        };
+    const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    GGML_ASSERT(n_threads > 0);
 
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
-            n_outputs_new = n_tokens;
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
+    // helpers for smoother batch API transition
+    // after deprecating the llama_eval calls, these will be removed
+    if (batch.pos == nullptr) {
+        pos.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        batch.pos = pos.data();
-        GGML_ASSERT(n_threads > 0);
-
-        // helpers for smoother batch API transition
-        // after deprecating the llama_eval calls, these will be removed
-        if (u_batch.pos == nullptr) {
-            pos.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
-            }
-
-            u_batch.pos = pos.data();
-        }
-
-        if (u_batch.seq_id == nullptr) {
-            n_seq_id.resize(n_tokens);
-            seq_id.resize(n_tokens);
-            seq_id_arr.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                n_seq_id[i] = 1;
-                seq_id[i].resize(1);
-                seq_id[i][0] = u_batch.all_seq_id;
-                seq_id_arr[i] = seq_id[i].data();
-            }
-
-            u_batch.n_seq_id = n_seq_id.data();
-            u_batch.seq_id = seq_id_arr.data();
-        }
-
-        ggml_backend_sched_reset(lctx.sched);
-        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
-        // the output is always the last tensor in the graph
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
-
-        // token or sequence embeddings
-        embd = gf->nodes[gf->n_nodes - 1];
-
-        GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
-
-        ggml_backend_sched_alloc_graph(lctx.sched, gf);
-
-        llama_set_inputs(lctx, u_batch);
-
-        llama_graph_compute(lctx, gf, n_threads);
-
-        // extract embeddings
-        if (embd) {
-            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
-            GGML_ASSERT(backend_embd != nullptr);
-
-            // extract token embeddings
-            GGML_ASSERT(lctx.embd != nullptr);
-            const int32_t n_outputs_new = lctx.n_outputs;
-            lctx.encoder_output.resize((n_outputs_prev + n_outputs_new)*n_embd);
-            float * embd_out = lctx.encoder_output.data() + n_outputs_prev*n_embd;
-
-            if (n_outputs_new) {
-                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
-                GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
-                ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
-            }
-
-            // extract output embeddings mask
-            lctx.encoder_output_seq_ids.resize(n_outputs_prev + n_outputs_new);
-            for (int i = 0; i < n_outputs_new; i++) {
-                for (int s = 0; s < u_batch.n_seq_id[i]; s++) {
-                    llama_seq_id seq_id = u_batch.seq_id[i][s];
-                    lctx.encoder_output_seq_ids[i].insert(seq_id);
-                }
-            }
-        }
-        n_outputs_prev += lctx.n_outputs;
     }
 
-    // set to total number of outputs in the batch, for use in llama_get_logits_ith
-    lctx.n_outputs = n_outputs;
+    if (batch.seq_id == nullptr) {
+        n_seq_id.resize(n_tokens);
+        seq_id.resize(n_tokens);
+        seq_id_arr.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            n_seq_id[i] = 1;
+            seq_id[i].resize(1);
+            seq_id[i][0] = batch.all_seq_id;
+            seq_id_arr[i] = seq_id[i].data();
+        }
+
+        batch.n_seq_id = n_seq_id.data();
+        batch.seq_id = seq_id_arr.data();
+    }
+
+    ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+
+    ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
+
+    // the output embeddings after the final encoder normalization
+    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1];
+
+    GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+
+    ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+    llama_set_inputs(lctx, batch);
+
+    llama_graph_compute(lctx, gf, n_threads);
+
+    // extract embeddings
+    if (embd) {
+        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+        GGML_ASSERT(backend_embd != nullptr);
+
+        // extract token embeddings
+        GGML_ASSERT(lctx.embd != nullptr);
+
+        lctx.embd_enc.resize(n_tokens*n_embd);
+        float * embd_out = lctx.embd_enc.data();
+
+        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+
+        // remember the sequence ids used during the encoding - needed for cross attention later
+        lctx.seq_ids_enc.resize(n_tokens);
+        for (int i = 0; i < n_tokens; i++) {
+            for (int s = 0; s < batch.n_seq_id[i]; s++) {
+                llama_seq_id seq_id = batch.seq_id[i][s];
+                lctx.seq_ids_enc[i].insert(seq_id);
+            }
+        }
+    }
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
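The retained "helpers for smoother batch API transition" block defaults `batch.pos` and `batch.seq_id` when the caller leaves them null: positions are derived from `all_pos_0`/`all_pos_1`, and every token is assigned the single sequence `all_seq_id`. A standalone sketch of that defaulting follows, using a hypothetical `mini_batch` struct in place of the real `llama_batch`.

```cpp
// Sketch of the pos/seq_id defaulting done in the rewritten encode path.
// mini_batch is a hypothetical stand-in for llama_batch, reduced to the
// fields used here; it is not the llama.cpp struct.
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

struct mini_batch {
    int32_t         n_tokens   = 0;
    llama_pos     * pos        = nullptr;
    int32_t       * n_seq_id   = nullptr;
    llama_seq_id ** seq_id     = nullptr;
    llama_pos       all_pos_0  = 0; // used if pos == nullptr
    llama_pos       all_pos_1  = 0; // used if pos == nullptr
    llama_seq_id    all_seq_id = 0; // used if seq_id == nullptr
};

int main() {
    mini_batch batch;
    batch.n_tokens  = 4;
    batch.all_pos_0 = 0;
    batch.all_pos_1 = 1;

    // backing storage must stay alive for as long as the batch is used
    std::vector<llama_pos>                 pos;
    std::vector<int32_t>                   n_seq_id;
    std::vector<llama_seq_id *>            seq_id_arr;
    std::vector<std::vector<llama_seq_id>> seq_id;

    if (batch.pos == nullptr) {
        pos.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
        }
        batch.pos = pos.data();
    }

    if (batch.seq_id == nullptr) {
        n_seq_id.resize(batch.n_tokens);
        seq_id.resize(batch.n_tokens);
        seq_id_arr.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i]   = 1;
            seq_id[i]     = { batch.all_seq_id };
            seq_id_arr[i] = seq_id[i].data();
        }
        batch.n_seq_id = n_seq_id.data();
        batch.seq_id   = seq_id_arr.data();
    }

    for (int32_t i = 0; i < batch.n_tokens; i++) {
        printf("token %d: pos=%d seq_id=%d\n", i, batch.pos[i], batch.seq_id[i][0]);
    }
    return 0;
}
```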
@@ -14246,7 +14207,6 @@ static int llama_encode_internal(
     return 0;
 }
 
-
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
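With the internal micro-batch loop removed, the encoder evaluates the whole batch in a single pass, so the caller has to size `n_ubatch` (and `n_batch`) to cover the full prompt before calling `llama_encode`. Below is a minimal caller-side sketch against the llama.h API of roughly this revision; the model path is a placeholder and exact signatures may differ in other versions.

```cpp
// Minimal caller-side sketch: encode a prompt with an encoder-decoder model in
// one shot. The model path is a placeholder and the API calls reflect the
// llama.h of roughly this revision; signatures may differ in other versions.
#include <cstdio>
#include <string>
#include <vector>

#include "llama.h"

int main() {
    const std::string prompt = "translate English to German: The house is wonderful.";

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("t5-model.gguf" /* placeholder */, mparams);
    if (model == nullptr) {
        return 1;
    }

    // tokenize the prompt first so the context can be sized to it
    std::vector<llama_token> tokens(prompt.size() + 8);
    const int n_tokens = llama_tokenize(model, prompt.c_str(), (int32_t) prompt.size(),
                                        tokens.data(), (int32_t) tokens.size(),
                                        /*add_special*/ true, /*parse_special*/ true);
    if (n_tokens <= 0) {
        return 1;
    }
    tokens.resize(n_tokens);

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx    = 512;
    cparams.n_batch  = n_tokens; // the encoder is evaluated in a single shot,
    cparams.n_ubatch = n_tokens; // so n_ubatch must cover the whole prompt
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    llama_batch batch = llama_batch_get_one(tokens.data(), n_tokens, 0, 0);
    if (llama_encode(ctx, batch) != 0) {
        fprintf(stderr, "llama_encode failed\n");
    }

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```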