make freq factors only depend on ctx size

This commit is contained in:
liuwei 2024-05-11 17:25:08 +00:00 committed by Georgi Gerganov
parent c5569311a4
commit 6333ed1a30
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -10978,12 +10978,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim); GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim);
GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim); GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim);
auto max_pos = batch.n_tokens > 0 && batch.pos != nullptr ? *std::max_element(batch.pos, batch.pos + batch.n_tokens) : batch.n_tokens - 1; auto n_ctx = llama_n_ctx(&lctx);
if ((uint32_t)(max_pos + 1) > hparams.n_yarn_orig_ctx) { if (n_ctx > hparams.n_yarn_orig_ctx) {
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
} }
else { else {
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
} }
} }