minor : style

This commit is contained in:
Georgi Gerganov 2024-05-16 13:33:01 +03:00
parent 352c3859a7
commit f4cb482c62
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 15 additions and 19 deletions

View file

@ -182,22 +182,19 @@ static void rope_neox_cuda(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
else {
} else {
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
}
else {
} else {
if (freq_factors == nullptr) {
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
);
}
else {
} else {
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors

View file

@ -6995,7 +6995,6 @@ struct llm_build_context {
}
struct ggml_tensor * build_freq_factors() {
if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) {
lctx.freq_factors = nullptr;
return nullptr;
@ -10968,18 +10967,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
if (lctx.freq_factors) {
auto freq_dim = hparams.n_embd_head_k / 2;
// TODO: this might have to be hparams.n_rot instead of hparams.n_embd_head_k, but maybe it does not matter
const auto freq_dim = hparams.n_embd_head_k / 2;
GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim);
GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim);
GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim);
// choose long/short freq factors based on the context size
auto n_ctx = llama_n_ctx(&lctx);
const auto n_ctx = llama_n_ctx(&lctx);
if (n_ctx > hparams.n_yarn_orig_ctx) {
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
}
else {
} else {
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
}
}