diff --git a/src/llama.cpp b/src/llama.cpp
index 2b9ace285..a4feb055b 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13983,38 +13983,36 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
             // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_pos    pos    = batch.pos[j];
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
 
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                            f = -INFINITY;
+                for (int i = 0; i < n_kv; ++i) {
+                    float f;
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                        f = -INFINITY;
+                    } else {
+                        if (hparams.use_alibi) {
+                            f = -fabs(lctx.kv_self.cells[i].pos - pos);
                         } else {
-                            if (hparams.use_alibi) {
-                                f = -fabs(lctx.kv_self.cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
+                            f = 0.0f;
                         }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                    }
+                    data[j*n_kv + i] = f;
 
-                        // may need to cut off old tokens for sliding window
-                        if (data_swa) {
-                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
-                                f = -INFINITY;
-                            }
-                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                    // may need to cut off old tokens for sliding window
+                    if (data_swa) {
+                        if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                            f = -INFINITY;
                         }
+                        data_swa[j*n_kv + i] = f;
                     }
                 }
+            }
 
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
+            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (int j = 0; j < n_kv; ++j) {
+                    data[i*n_kv + j] = -INFINITY;
                 }
             }
         } else {
@@ -14026,29 +14024,27 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             float * data = (float *) lctx.inp_KQ_mask->data;
 
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[j][0];
 
-                    for (int i = 0; i < n_tokens; ++i) {
-                        float f = -INFINITY;
-                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                            if (batch.seq_id[i][s] == seq_id) {
-                                if (hparams.use_alibi) {
-                                    f = -fabs(batch.pos[i] - batch.pos[j]);
-                                } else {
-                                    f = 0.0f;
-                                }
-                                break;
+                for (int i = 0; i < n_tokens; ++i) {
+                    float f = -INFINITY;
+                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                        if (batch.seq_id[i][s] == seq_id) {
+                            if (hparams.use_alibi) {
+                                f = -fabs(batch.pos[i] - batch.pos[j]);
+                            } else {
+                                f = 0.0f;
                             }
+                            break;
                         }
-
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
                     }
 
-                    for (int i = n_tokens; i < n_stride; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
-                    }
+                    data[j*n_stride + i] = f;
+                }
+
+                for (int i = n_tokens; i < n_stride; ++i) {
+                    data[j*n_stride + i] = -INFINITY;
                 }
             }
         }
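
Note on why the index simplification is behavior-preserving: the removed loop runs with a bound of 1, so h is always 0 and the h*(n_kv*n_tokens) offset (respectively h*(n_tokens*n_tokens) in the non-causal branch) is always zero; dropping it leaves the mask layout untouched. The snippet below is a standalone sketch of that equivalence, not part of the patch or of llama.cpp; the buffer names and toy sizes are made up for illustration.

// Standalone sketch (not part of the patch): with a loop bound of 1, h can only
// be 0, so h*(n_kv*n_tokens) + j*n_kv + i reduces to j*n_kv + i and the two
// fill orders produce identical buffers. Sizes below are arbitrary toy values.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int n_kv     = 8; // hypothetical number of KV cells
    const int n_tokens = 4; // hypothetical number of batch tokens

    std::vector<float> mask_old(n_kv*n_tokens, 0.0f);
    std::vector<float> mask_new(n_kv*n_tokens, 0.0f);

    // old layout: explicit (degenerate) h loop, h is always 0
    for (int h = 0; h < 1; ++h) {
        for (int j = 0; j < n_tokens; ++j) {
            for (int i = 0; i < n_kv; ++i) {
                mask_old[h*(n_kv*n_tokens) + j*n_kv + i] = (float)(j*n_kv + i);
            }
        }
    }

    // new layout: same indexing with the always-zero h term dropped
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            mask_new[j*n_kv + i] = (float)(j*n_kv + i);
        }
    }

    assert(mask_old == mask_new);
    printf("mask layouts match for h == 0\n");
    return 0;
}

Compiling and running this prints the confirmation line; it is only a sanity check that removing the degenerate loop does not change which mask element each (j, i) pair writes.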