minor : style
This commit is contained in:
parent
352c3859a7
commit
f4cb482c62
2 changed files with 15 additions and 19 deletions
|
@ -182,22 +182,19 @@ static void rope_neox_cuda(
|
||||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||||
theta_scale, inv_ndims, freq_factors
|
theta_scale, inv_ndims, freq_factors
|
||||||
);
|
);
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
|
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||||
theta_scale, inv_ndims, freq_factors
|
theta_scale, inv_ndims, freq_factors
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
if (freq_factors == nullptr) {
|
if (freq_factors == nullptr) {
|
||||||
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
|
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||||
theta_scale, inv_ndims, freq_factors
|
theta_scale, inv_ndims, freq_factors
|
||||||
);
|
);
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
|
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
|
||||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||||
theta_scale, inv_ndims, freq_factors
|
theta_scale, inv_ndims, freq_factors
|
||||||
|
|
|
@ -6995,7 +6995,6 @@ struct llm_build_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * build_freq_factors() {
|
struct ggml_tensor * build_freq_factors() {
|
||||||
|
|
||||||
if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) {
|
if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) {
|
||||||
lctx.freq_factors = nullptr;
|
lctx.freq_factors = nullptr;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -10968,18 +10967,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lctx.freq_factors) {
|
if (lctx.freq_factors) {
|
||||||
auto freq_dim = hparams.n_embd_head_k / 2;
|
// TODO: this might have to be hparams.n_rot instead of hparams.n_embd_head_k, but maybe it does not matter
|
||||||
|
const auto freq_dim = hparams.n_embd_head_k / 2;
|
||||||
|
|
||||||
GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim);
|
GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim);
|
||||||
GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim);
|
GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim);
|
||||||
GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim);
|
GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim);
|
||||||
|
|
||||||
// choose long/short freq factors based on the context size
|
// choose long/short freq factors based on the context size
|
||||||
auto n_ctx = llama_n_ctx(&lctx);
|
const auto n_ctx = llama_n_ctx(&lctx);
|
||||||
if (n_ctx > hparams.n_yarn_orig_ctx) {
|
if (n_ctx > hparams.n_yarn_orig_ctx) {
|
||||||
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
|
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
|
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue