diff --git a/src/llama.cpp b/src/llama.cpp
index df377b7d7..d1fa683aa 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2697,9 +2697,10 @@ struct llama_context {
     // whether we are computing encoder output or decoder output
     bool is_encoding = false;
 
-    std::vector<float>                  encoder_output;
-    std::vector<std::set<llama_seq_id>> encoder_output_seq_ids;
+    // output of the encoder part of the encoder-decoder models
+    std::vector<float> embd_enc;
+    std::vector<std::set<llama_seq_id>> seq_ids_enc;
 
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
@@ -7974,7 +7975,7 @@ struct llm_build_context {
         n_tokens         (batch.n_tokens),
         n_kv             (worst_case ? kv_self.size : kv_self.n),
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
-        n_outputs_enc    (worst_case ? n_tokens : lctx.encoder_output.size() / hparams.n_embd),
+        n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_ctx_orig       (cparams.n_ctx_orig_yarn),
         flash_attn       (cparams.flash_attn),
@@ -12716,7 +12717,7 @@ struct llm_build_context {
                         model.layers[il].ffn_down_enc, NULL, NULL,
                         NULL,
                         model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
                         cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -12742,8 +12743,8 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, -1);
             cb(cur, "result_norm", -1);
         } else {
-            struct ggml_tensor * embd_enc        = llm_build_inp_embd_enc();
-            struct ggml_tensor * pos_buckets_dec = llm_build_pos_bucket(true);
+            struct ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
+            struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
 
             struct ggml_tensor * KQ_mask_dec   = build_inp_KQ_mask();
             struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
@@ -12794,7 +12795,7 @@ struct llm_build_context {
                 cb(kq, "kq", il);
 
                 struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
-                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_buckets_dec, attn_rel_b);
+                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
 
                 struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
                 cb(kq_b, "kq_b", il);
@@ -12838,7 +12839,7 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_cross, embd_enc);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
 
                 struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
@@ -13298,26 +13299,26 @@ static void llama_set_s_copy(llama_context & lctx) {
     }
 }
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t num_buckets, bool bidirectional) {
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
     const int64_t max_distance = 128;
 
     if (bidirectional) {
-        num_buckets >>= 1;
+        n_buckets >>= 1;
     }
 
-    const int64_t max_exact = num_buckets >> 1;
+    const int64_t max_exact = n_buckets >> 1;
 
     int32_t relative_position = x - y;
     int32_t relative_bucket = 0;
 
     if (bidirectional) {
-        relative_bucket += (relative_position > 0) * num_buckets;
+        relative_bucket += (relative_position > 0) * n_buckets;
         relative_position = abs(relative_position);
     } else {
         relative_position = -std::min(relative_position, 0);
     }
 
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
     relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
 
     return relative_bucket;
 }
@@ -13634,13 +13635,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
     if (!lctx.is_encoding && lctx.inp_embd_enc) {
         assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
-        assert(ggml_nelements(lctx.inp_embd_enc) == lctx.encoder_output.size());
+        assert(ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
 
-        ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.encoder_output.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
+        ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
     }
 
     if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
-        const int64_t n_encoder_output = lctx.encoder_output.size() / hparams.n_embd;
+        const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
@@ -13649,21 +13650,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_encoder_output; ++i) {
+                for (int i = 0; i < n_output_enc; ++i) {
                     float f = -INFINITY;
                     for (int s = 0; s < batch.n_seq_id[j]; ++s) {
                         const llama_seq_id seq_id = batch.seq_id[j][s];
-                        if (lctx.encoder_output_seq_ids[i].find(seq_id) != lctx.encoder_output_seq_ids[i].end()) {
+                        if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
                             f = 0.0f;
                         }
                     }
-                    data[h*(n_encoder_output*n_tokens) + j*n_encoder_output + i] = f;
+                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
                 }
             }
 
             for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_encoder_output; ++j) {
-                    data[h*(n_encoder_output*n_tokens) + i*n_encoder_output + j] = -INFINITY;
+                for (int j = 0; j < n_output_enc; ++j) {
+                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
                 }
             }
         }
@@ -13809,6 +13810,7 @@ static int llama_decode_internal(
 
     const auto n_ubatch = cparams.n_ubatch;
 
+    // TODO: simplify or deprecate
     std::vector<llama_pos> pos;
     std::vector<int32_t>                   n_seq_id;
     std::vector<llama_seq_id *>            seq_id_arr;
     std::vector<std::vector<llama_seq_id>> seq_id;
@@ -14083,12 +14085,13 @@ static int llama_decode_internal(
 //
 static int llama_encode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   batch) {
 
     lctx.is_encoding = true;
 
-    const uint32_t n_tokens_all = batch_all.n_tokens;
-    if (n_tokens_all == 0) {
+    const uint32_t n_tokens = batch.n_tokens;
+
+    if (n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
         return -1;
     }
@@ -14097,147 +14100,105 @@ static int llama_encode_internal(
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-    GGML_ASSERT(cparams.n_ubatch >= n_tokens_all && "encoder requires n_ubatch >= n_tokens");
+    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
+    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
     }
 
-    lctx.n_queued_tokens += n_tokens_all;
-    const int64_t n_embd = hparams.n_embd;
+    lctx.n_queued_tokens += n_tokens;
 
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
+    const int64_t n_embd = hparams.n_embd;
 
+    // TODO: simplify or deprecate
     std::vector<llama_pos> pos;
     std::vector<int32_t>                   n_seq_id;
     std::vector<llama_seq_id *>            seq_id_arr;
     std::vector<std::vector<llama_seq_id>> seq_id;
 
-    n_outputs = n_tokens_all;
-
     // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+    if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
         return -2;
     };
 
-    for (uint32_t i = 0; i < n_outputs; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         lctx.output_ids[i] = i;
     }
 
     lctx.inp_embd_enc = NULL;
+    lctx.n_outputs = n_tokens;
 
-    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
-        const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
-        llama_batch u_batch = {
-            /* .n_tokens   = */ (int32_t) n_tokens,
-            /* .token      = */ batch_all.token    ? batch_all.token    + cur_token        : nullptr,
-            /* .embd       = */ batch_all.embd     ? batch_all.embd     + cur_token*n_embd : nullptr,
-            /* .pos        = */ batch_all.pos      ? batch_all.pos      + cur_token        : nullptr,
-            /* .n_seq_id   = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token        : nullptr,
-            /* .seq_id     = */ batch_all.seq_id   ? batch_all.seq_id   + cur_token        : nullptr,
-            /* .logits     = */ batch_all.logits   ? batch_all.logits   + cur_token        : nullptr,
-            /* .all_pos_0  = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
-            /* .all_pos_1  = */ batch_all.all_pos_1,
-            /* .all_seq_id = */ batch_all.all_seq_id,
-        };
+    const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    GGML_ASSERT(n_threads > 0);
 
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
-
-            n_outputs_new = n_tokens;
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
+    // helpers for smoother batch API transition
+    // after deprecating the llama_eval calls, these will be removed
+    if (batch.pos == nullptr) {
+        pos.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        GGML_ASSERT(n_threads > 0);
-
-        // helpers for smoother batch API transition
-        // after deprecating the llama_eval calls, these will be removed
-        if (u_batch.pos == nullptr) {
-            pos.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                pos[i] = u_batch.all_pos_0 + i*u_batch.all_pos_1;
-            }
-
-            u_batch.pos = pos.data();
-        }
-
-        if (u_batch.seq_id == nullptr) {
-            n_seq_id.resize(n_tokens);
-            seq_id.resize(n_tokens);
-            seq_id_arr.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                n_seq_id[i] = 1;
-                seq_id[i].resize(1);
-                seq_id[i][0] = u_batch.all_seq_id;
-                seq_id_arr[i] = seq_id[i].data();
-            }
-
-            u_batch.n_seq_id = n_seq_id.data();
-            u_batch.seq_id = seq_id_arr.data();
-        }
-
-        ggml_backend_sched_reset(lctx.sched);
-        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
-
-        // the output is always the last tensor in the graph
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
-
-        // token or sequence embeddings
-        embd = gf->nodes[gf->n_nodes - 1];
-
-        GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
-
-        ggml_backend_sched_alloc_graph(lctx.sched, gf);
-
-        llama_set_inputs(lctx, u_batch);
-
-        llama_graph_compute(lctx, gf, n_threads);
-
-        // extract embeddings
-        if (embd) {
-            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
-            GGML_ASSERT(backend_embd != nullptr);
-
-            // extract token embeddings
-            GGML_ASSERT(lctx.embd != nullptr);
-            const int32_t n_outputs_new = lctx.n_outputs;
-            lctx.encoder_output.resize((n_outputs_prev + n_outputs_new)*n_embd);
-            float * embd_out = lctx.encoder_output.data() + n_outputs_prev*n_embd;
-
-            if (n_outputs_new) {
-                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
-                GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
-                ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
-            }
-
-            // extract output embeddings mask
-            lctx.encoder_output_seq_ids.resize(n_outputs_prev + n_outputs_new);
-            for (int i = 0; i < n_outputs_new; i++) {
-                for (int s = 0; s < u_batch.n_seq_id[i]; s++) {
-                    llama_seq_id seq_id = u_batch.seq_id[i][s];
-                    lctx.encoder_output_seq_ids[i].insert(seq_id);
-                }
-            }
-        }
-        n_outputs_prev += lctx.n_outputs;
+        batch.pos = pos.data();
     }
 
-    // set to total number of outputs in the batch, for use in llama_get_logits_ith
-    lctx.n_outputs = n_outputs;
+    if (batch.seq_id == nullptr) {
+        n_seq_id.resize(n_tokens);
+        seq_id.resize(n_tokens);
+        seq_id_arr.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            n_seq_id[i] = 1;
+            seq_id[i].resize(1);
+            seq_id[i][0] = batch.all_seq_id;
+            seq_id_arr[i] = seq_id[i].data();
+        }
+
+        batch.n_seq_id = n_seq_id.data();
+        batch.seq_id = seq_id_arr.data();
+    }
+
+    ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+
+    ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
+
+    // the output embeddings after the final encoder normalization
+    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1];
+
+    GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+
+    ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+    llama_set_inputs(lctx, batch);
+
+    llama_graph_compute(lctx, gf, n_threads);
+
+    // extract embeddings
+    if (embd) {
+        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+        GGML_ASSERT(backend_embd != nullptr);
+
+        // extract token embeddings
+        GGML_ASSERT(lctx.embd != nullptr);
+
+        lctx.embd_enc.resize(n_tokens*n_embd);
+        float * embd_out = lctx.embd_enc.data();
+
+        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+
+        // remember the sequence ids used during the encoding - needed for cross attention later
+        lctx.seq_ids_enc.resize(n_tokens);
+        for (int i = 0; i < n_tokens; i++) {
+            for (int s = 0; s < batch.n_seq_id[i]; s++) {
+                llama_seq_id seq_id = batch.seq_id[i][s];
+                lctx.seq_ids_enc[i].insert(seq_id);
+            }
+        }
+    }
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
@@ -14246,7 +14207,6 @@ static int llama_encode_internal(
     return 0;
 }
 
-
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;