minor : style

This commit is contained in:
Georgi Gerganov 2024-05-16 13:33:01 +03:00
parent 352c3859a7
commit f4cb482c62
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 15 additions and 19 deletions

View file

@ -182,22 +182,19 @@ static void rope_neox_cuda(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors theta_scale, inv_ndims, freq_factors
); );
} } else {
else {
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>( rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors theta_scale, inv_ndims, freq_factors
); );
} }
} } else {
else {
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>( rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors theta_scale, inv_ndims, freq_factors
); );
} } else {
else {
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>( rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors theta_scale, inv_ndims, freq_factors

View file

@ -1799,9 +1799,9 @@ struct llama_hparams {
if (this->rope_finetuned != other.rope_finetuned) return true; if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
if (this->rope_long_factors != other.rope_long_factors) return true; if (this->rope_long_factors != other.rope_long_factors) return true;
if (this->rope_short_factors != other.rope_short_factors) return true; if (this->rope_short_factors != other.rope_short_factors) return true;
if (this->rope_attn_factor != other.rope_attn_factor) return true; if (this->rope_attn_factor != other.rope_attn_factor) return true;
if (this->ssm_d_conv != other.ssm_d_conv) return true; if (this->ssm_d_conv != other.ssm_d_conv) return true;
if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_inner != other.ssm_d_inner) return true;
@ -3323,7 +3323,7 @@ struct llama_model_loader {
} }
template<typename T> template<typename T>
bool get_arr(const std::string& key, std::vector<T>& result, const bool required = true) { bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
const int kid = gguf_find_key(meta, key.c_str()); const int kid = gguf_find_key(meta, key.c_str());
if (kid < 0) { if (kid < 0) {
@ -3342,10 +3342,10 @@ struct llama_model_loader {
// GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same<T, float>::value)); GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same<T, float>::value));
GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same<T, int>::value)); GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same<T, int>::value));
result.resize(arr_info.length); result.resize(arr_info.length);
result.assign((const T*)arr_info.data, (const T*)arr_info.data + arr_info.length); result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
return true; return true;
} }
@ -3898,7 +3898,7 @@ static void llm_load_hparams(
} }
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS, hparams.rope_long_factors, false); ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS, hparams.rope_long_factors, false);
ml.get_arr(LLM_KV_ROPE_SCALING_SHORT_FACTORS, hparams.rope_short_factors, false); ml.get_arr(LLM_KV_ROPE_SCALING_SHORT_FACTORS, hparams.rope_short_factors, false);
GGML_ASSERT(hparams.rope_long_factors.size() == 0 || hparams.rope_long_factors.size() == hparams.n_embd / hparams.n_head / 2); GGML_ASSERT(hparams.rope_long_factors.size() == 0 || hparams.rope_long_factors.size() == hparams.n_embd / hparams.n_head / 2);
@ -6994,8 +6994,7 @@ struct llm_build_context {
return lctx.inp_pos; return lctx.inp_pos;
} }
struct ggml_tensor* build_freq_factors() { struct ggml_tensor * build_freq_factors() {
if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) { if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) {
lctx.freq_factors = nullptr; lctx.freq_factors = nullptr;
return nullptr; return nullptr;
@ -10968,18 +10967,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
} }
if (lctx.freq_factors) { if (lctx.freq_factors) {
auto freq_dim = hparams.n_embd_head_k / 2; // TODO: this might have to be hparams.n_rot instead of hparams.n_embd_head_k, but maybe it does not matter
const auto freq_dim = hparams.n_embd_head_k / 2;
GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim); GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim);
GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim); GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim);
GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim); GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim);
// choose long/short freq factors based on the context size // choose long/short freq factors based on the context size
auto n_ctx = llama_n_ctx(&lctx); const auto n_ctx = llama_n_ctx(&lctx);
if (n_ctx > hparams.n_yarn_orig_ctx) { if (n_ctx > hparams.n_yarn_orig_ctx) {
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
} } else {
else {
ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors)); ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
} }
} }