llama : simplify llama_encode_internal

Georgi Gerganov 2024-07-04 12:10:32 +03:00
parent 03ab5dd67c
commit 88270a3613
No known key found for this signature in database
GPG key ID: 449E073F9DC10735


@@ -2697,9 +2697,10 @@ struct llama_context {
// whether we are computing encoder output or decoder output
bool is_encoding = false;
// output of the encoder part of the encoder-decoder models
std::vector<float> encoder_output;
std::vector<std::set<llama_seq_id> > encoder_output_seq_ids;
std::vector<float> embd_enc;
std::vector<std::set<llama_seq_id>> seq_ids_enc;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
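The renamed members keep the same layout: embd_enc holds the encoder output as a flat buffer of n_outputs_enc x n_embd floats (so n_outputs_enc = embd_enc.size() / hparams.n_embd, as used in the next hunk), and seq_ids_enc[i] records which sequence ids encoder position i belongs to, which later drives the cross-attention mask. A minimal standalone sketch of that layout, assuming row-major storage; the struct and its helpers are hypothetical, not llama.cpp code:

    #include <cstdint>
    #include <set>
    #include <vector>

    struct encoder_output_sketch {
        std::vector<float>             embd_enc;    // flat [n_outputs_enc * n_embd], one row per encoder position
        std::vector<std::set<int32_t>> seq_ids_enc; // sequence ids that produced each encoder position

        int64_t n_outputs_enc(int64_t n_embd) const {
            return (int64_t) embd_enc.size() / n_embd;
        }

        // embedding vector of encoder output position i
        const float * row(int64_t i, int64_t n_embd) const {
            return embd_enc.data() + i*n_embd;
        }
    };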
@@ -7974,7 +7975,7 @@ struct llm_build_context {
n_tokens (batch.n_tokens),
n_kv (worst_case ? kv_self.size : kv_self.n),
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
n_outputs_enc (worst_case ? n_tokens : lctx.encoder_output.size() / hparams.n_embd),
n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_ctx_orig (cparams.n_ctx_orig_yarn),
flash_attn (cparams.flash_attn),
@@ -12716,7 +12717,7 @@ struct llm_build_context {
model.layers[il].ffn_down_enc, NULL, NULL,
NULL,
model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
cb, il);
cb(cur, "ffn_out", il);
}
@@ -12742,8 +12743,8 @@ struct llm_build_context {
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
} else {
struct ggml_tensor * embd_enc = llm_build_inp_embd_enc();
struct ggml_tensor * pos_buckets_dec = llm_build_pos_bucket(true);
struct ggml_tensor * embd_enc = llm_build_inp_embd_enc();
struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
struct ggml_tensor * KQ_mask_dec = build_inp_KQ_mask();
struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
@@ -12794,7 +12795,7 @@ struct llm_build_context {
cb(kq, "kq", il);
struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_buckets_dec, attn_rel_b);
struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
cb(kq_b, "kq_b", il);
@@ -12838,7 +12839,7 @@ struct llm_build_context {
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_cross, embd_enc);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
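For reference, K and V in this cross-attention block are projected from the encoder output (n_outputs_enc positions) while Q comes from the current decoder tokens, so each head scores an n_outputs_enc x n_tokens matrix. A shape-only sketch in plain C++, assuming n_head_kv == n_head as in the T5 models this path targets; cross_kq_shape is a hypothetical helper:

    #include <cstdint>

    struct shape3 { int64_t ne0, ne1, ne2; };

    static shape3 cross_kq_shape(int64_t n_embd_head, int64_t n_head, int64_t n_tokens, int64_t n_outputs_enc) {
        const shape3 q = {n_embd_head, n_tokens,      n_head}; // Qcur after reshape + permute(0, 2, 1, 3)
        const shape3 k = {n_embd_head, n_outputs_enc, n_head}; // Kcur after reshape + permute(0, 2, 1, 3)
        (void) q; (void) k;
        return {n_outputs_enc, n_tokens, n_head};              // kq = K^T Q, one n_outputs_enc x n_tokens slice per head
    }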
@@ -13298,26 +13299,26 @@ static void llama_set_s_copy(llama_context & lctx) {
}
}
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t num_buckets, bool bidirectional) {
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
// TODO move to hparams if a T5 variant appears that uses a different value
const int64_t max_distance = 128;
if (bidirectional) {
num_buckets >>= 1;
n_buckets >>= 1;
}
const int64_t max_exact = num_buckets >> 1;
const int64_t max_exact = n_buckets >> 1;
int32_t relative_position = x - y;
int32_t relative_bucket = 0;
if (bidirectional) {
relative_bucket += (relative_position > 0) * num_buckets;
relative_bucket += (relative_position > 0) * n_buckets;
relative_position = abs(relative_position);
} else {
relative_position = -std::min<int32_t>(relative_position, 0);
}
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
return relative_bucket;
}
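The bucketing above follows the T5 relative attention bias scheme: small offsets get their own bucket, larger offsets share logarithmically spaced buckets up to max_distance, and the bidirectional (encoder) case splits the range by sign. A standalone sketch with the same math and a tiny usage example; the values in the comments are illustrative for n_buckets = 32:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>

    static int32_t relative_position_bucket(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
        const int64_t max_distance = 128;

        if (bidirectional) {
            n_buckets >>= 1;
        }

        const int64_t max_exact = n_buckets >> 1;

        int32_t relative_position = x - y;
        int32_t relative_bucket = 0;
        if (bidirectional) {
            relative_bucket += (relative_position > 0) * n_buckets;
            relative_position = abs(relative_position);
        } else {
            relative_position = -std::min<int32_t>(relative_position, 0);
        }
        int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
        relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
        relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
        return relative_bucket;
    }

    int main(void) {
        printf("bucket( 1, 0) = %d\n", relative_position_bucket( 1, 0, 32, true)); // 17: exact bucket, positive offset
        printf("bucket( 0, 4) = %d\n", relative_position_bucket( 0, 4, 32, true)); //  4: exact bucket, negative offset
        printf("bucket(50, 0) = %d\n", relative_position_bucket(50, 0, 32, true)); // 29: logarithmic bucket
        return 0;
    }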
@@ -13634,13 +13635,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
if (!lctx.is_encoding && lctx.inp_embd_enc) {
assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
assert(ggml_nelements(lctx.inp_embd_enc) == lctx.encoder_output.size());
assert(ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.encoder_output.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
}
if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
const int64_t n_encoder_output = lctx.encoder_output.size() / hparams.n_embd;
const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
@@ -13649,21 +13650,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_encoder_output; ++i) {
for (int i = 0; i < n_output_enc; ++i) {
float f = -INFINITY;
for (int s = 0; s < batch.n_seq_id[j]; ++s) {
const llama_seq_id seq_id = batch.seq_id[j][s];
if (lctx.encoder_output_seq_ids[i].find(seq_id) != lctx.encoder_output_seq_ids[i].end()) {
if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
f = 0.0f;
}
}
data[h*(n_encoder_output*n_tokens) + j*n_encoder_output + i] = f;
data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
}
}
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_encoder_output; ++j) {
data[h*(n_encoder_output*n_tokens) + i*n_encoder_output + j] = -INFINITY;
for (int j = 0; j < n_output_enc; ++j) {
data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
}
}
}
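Put differently, the cross-attention mask lets decoder token j attend to encoder output i only when they share a sequence id, while the padding rows up to GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) stay fully masked. A hedged sketch of that rule outside ggml; build_cross_mask and its parameters are hypothetical, and the padding rows are omitted:

    #include <cmath>
    #include <cstdint>
    #include <set>
    #include <vector>

    // mask[j*n_output_enc + i] == 0.0f means "decoder token j may attend to encoder output i",
    // everything else stays at -INFINITY (masked out before the softmax)
    static std::vector<float> build_cross_mask(
            const std::vector<std::vector<int32_t>> & dec_seq_ids,   // sequence ids per decoder token
            const std::vector<std::set<int32_t>>    & seq_ids_enc) { // sequence ids per encoder output position
        const size_t n_tokens     = dec_seq_ids.size();
        const size_t n_output_enc = seq_ids_enc.size();

        std::vector<float> mask(n_tokens*n_output_enc, -INFINITY);

        for (size_t j = 0; j < n_tokens; ++j) {
            for (size_t i = 0; i < n_output_enc; ++i) {
                for (int32_t s : dec_seq_ids[j]) {
                    if (seq_ids_enc[i].count(s) > 0) {
                        mask[j*n_output_enc + i] = 0.0f;
                        break;
                    }
                }
            }
        }

        return mask;
    }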
@@ -13809,6 +13810,7 @@ static int llama_decode_internal(
const auto n_ubatch = cparams.n_ubatch;
// TODO: simplify or deprecate
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id_arr;
@@ -14083,12 +14085,13 @@ static int llama_decode_internal(
//
static int llama_encode_internal(
llama_context & lctx,
llama_batch batch_all) { // TODO: rename back to batch
llama_batch batch) {
lctx.is_encoding = true;
const uint32_t n_tokens_all = batch_all.n_tokens;
if (n_tokens_all == 0) {
const uint32_t n_tokens = batch.n_tokens;
if (n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
return -1;
}
@@ -14097,147 +14100,105 @@ static int llama_encode_internal(
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT(cparams.n_ubatch >= n_tokens_all && "encoder requires n_ubatch >= n_tokens");
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
lctx.n_queued_tokens += n_tokens_all;
const int64_t n_embd = hparams.n_embd;
lctx.n_queued_tokens += n_tokens;
uint32_t n_outputs = 0;
uint32_t n_outputs_prev = 0;
const auto n_ubatch = cparams.n_ubatch;
const int64_t n_embd = hparams.n_embd;
// TODO: simplify or deprecate
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
n_outputs = n_tokens_all;
// reserve output buffer
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
return -2;
};
for (uint32_t i = 0; i < n_outputs; ++i) {
for (uint32_t i = 0; i < n_tokens; ++i) {
lctx.output_ids[i] = i;
}
lctx.inp_embd_enc = NULL;
lctx.n_outputs = n_tokens;
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
llama_batch u_batch = {
/* .n_tokens = */ (int32_t) n_tokens,
/* .token = */ batch_all.token ? batch_all.token + cur_token : nullptr,
/* .embd = */ batch_all.embd ? batch_all.embd + cur_token*n_embd : nullptr,
/* .pos = */ batch_all.pos ? batch_all.pos + cur_token : nullptr,
/* .n_seq_id = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token : nullptr,
/* .seq_id = */ batch_all.seq_id ? batch_all.seq_id + cur_token : nullptr,
/* .logits = */ batch_all.logits ? batch_all.logits + cur_token : nullptr,
/* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
/* .all_pos_1 = */ batch_all.all_pos_1,
/* .all_seq_id = */ batch_all.all_seq_id,
};
const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
GGML_ASSERT(n_threads > 0);
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
n_outputs_new = n_tokens;
// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
// helpers for smoother batch API transition
// after deprecating the llama_eval calls, these will be removed
if (batch.pos == nullptr) {
pos.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
}
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
GGML_ASSERT(n_threads > 0);
// helpers for smoother batch API transition
// after deprecating the llama_eval calls, these will be removed
if (u_batch.pos == nullptr) {
pos.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
}
u_batch.pos = pos.data();
}
if (u_batch.seq_id == nullptr) {
n_seq_id.resize(n_tokens);
seq_id.resize(n_tokens);
seq_id_arr.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
n_seq_id[i] = 1;
seq_id[i].resize(1);
seq_id[i][0] = u_batch.all_seq_id;
seq_id_arr[i] = seq_id[i].data();
}
u_batch.n_seq_id = n_seq_id.data();
u_batch.seq_id = seq_id_arr.data();
}
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
// the output is always the last tensor in the graph
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
ggml_backend_sched_alloc_graph(lctx.sched, gf);
llama_set_inputs(lctx, u_batch);
llama_graph_compute(lctx, gf, n_threads);
// extract embeddings
if (embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
// extract token embeddings
GGML_ASSERT(lctx.embd != nullptr);
const int32_t n_outputs_new = lctx.n_outputs;
lctx.encoder_output.resize((n_outputs_prev + n_outputs_new)*n_embd);
float * embd_out = lctx.encoder_output.data() + n_outputs_prev*n_embd;
if (n_outputs_new) {
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
}
// extract output embeddings mask
lctx.encoder_output_seq_ids.resize(n_outputs_prev + n_outputs_new);
for (int i = 0; i < n_outputs_new; i++) {
for (int s = 0; s < u_batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = u_batch.seq_id[i][s];
lctx.encoder_output_seq_ids[i].insert(seq_id);
}
}
}
n_outputs_prev += lctx.n_outputs;
batch.pos = pos.data();
}
// set to total number of outputs in the batch, for use in llama_get_logits_ith
lctx.n_outputs = n_outputs;
if (batch.seq_id == nullptr) {
n_seq_id.resize(n_tokens);
seq_id.resize(n_tokens);
seq_id_arr.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
n_seq_id[i] = 1;
seq_id[i].resize(1);
seq_id[i][0] = batch.all_seq_id;
seq_id_arr[i] = seq_id[i].data();
}
batch.n_seq_id = n_seq_id.data();
batch.seq_id = seq_id_arr.data();
}
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
// the output embeddings after the final encoder normalization
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
ggml_backend_sched_alloc_graph(lctx.sched, gf);
llama_set_inputs(lctx, batch);
llama_graph_compute(lctx, gf, n_threads);
// extract embeddings
if (embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
// extract token embeddings
GGML_ASSERT(lctx.embd != nullptr);
lctx.embd_enc.resize(n_tokens*n_embd);
float * embd_out = lctx.embd_enc.data();
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
// remember the sequence ids used during the encoding - needed for cross attention later
lctx.seq_ids_enc.resize(n_tokens);
for (int i = 0; i < n_tokens; i++) {
for (int s = 0; s < batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = batch.seq_id[i][s];
lctx.seq_ids_enc[i].insert(seq_id);
}
}
}
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
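With the micro-batch loop gone, the encoder consumes the whole prompt in a single llama_encode() call and the decoder is driven as usual afterwards. A hedged usage sketch against the public API of this period (llama_encode, llama_model_decoder_start_token and the four-argument llama_batch_get_one); model loading, error handling and sampling are omitted, and sample_next() is a hypothetical placeholder:

    #include "llama.h"
    #include <vector>

    static void encode_then_decode(llama_context * ctx, const llama_model * model,
                                   std::vector<llama_token> prompt) {
        // encode the whole prompt in one shot (requires n_ubatch >= prompt.size())
        if (llama_encode(ctx, llama_batch_get_one(prompt.data(), (int32_t) prompt.size(), 0, 0)) != 0) {
            return; // encoding failed
        }

        // the decoder starts from the model's decoder start token, falling back to BOS if unset
        llama_token tok = llama_model_decoder_start_token(model);
        if (tok == -1) {
            tok = llama_token_bos(model);
        }

        for (llama_pos pos = 0; pos < 32; ++pos) { // arbitrary length cap for the sketch
            if (llama_decode(ctx, llama_batch_get_one(&tok, 1, pos, 0)) != 0) {
                break;
            }
            // tok = sample_next(ctx); // hypothetical sampling step; stop on EOS in real code
        }
    }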
@@ -14246,7 +14207,6 @@ static int llama_encode_internal(
return 0;
}
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
auto & kv_self = lctx.kv_self;