diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu
index 4b0d2e5ad..cba5050fa 100644
--- a/ggml-cuda/rope.cu
+++ b/ggml-cuda/rope.cu
@@ -58,10 +58,10 @@ static __global__ void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos>
+template<typename T, bool has_pos, bool has_freq_facs>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
@@ -88,7 +88,9 @@ static __global__ void rope_neox(
     float cur_rot = inv_ndims * ic - ib;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
+    const float freq_factor = has_freq_facs ? freq_factors[col/2] : 1.0f;
+
+    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -164,7 +166,7 @@ static void rope_cuda(
 template<typename T>
 static void rope_neox_cuda(
     const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@@ -175,15 +177,32 @@ static void rope_neox_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
     const float inv_ndims = -1.0f / n_dims;
 
     if (pos == nullptr) {
-        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-            theta_scale, inv_ndims
-        );
-    } else {
-        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-            theta_scale, inv_ndims
-        );
+        if (freq_factors == nullptr) {
+            rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+            );
+        }
+        else {
+            rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+            );
+        }
+    }
+    else {
+        if (freq_factors == nullptr) {
+            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+            );
+        }
+        else {
+            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+            );
+        }
     }
 }
@@ -214,17 +233,17 @@ static void rope_cuda_f32(
 
 static void rope_neox_cuda_f16(
     const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f32(
     const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
 
-    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -259,11 +278,18 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&beta_fast, (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 
+    const float * freq_factors = nullptr;
     const int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
         GGML_ASSERT(src1->ne[0] == ne2);
         pos = (const int32_t *) src1_d;
+
+        if (dst->src[2] != nullptr) {
+            GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32);
+            GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) dst->src[2]->data;
+        }
     }
 
     const bool is_neox = mode & 2;
@@ -280,12 +306,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda_f32(
                 (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda_f16(
                 (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else {
             GGML_ASSERT(false);
diff --git a/ggml.c b/ggml.c
index 4bd911528..ecbb0db80 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6275,6 +6275,13 @@ static struct ggml_tensor * ggml_rope_impl(
     return result;
 }
 
+struct ggml_tensor * ggml_rope_with_freq_factors(
+        struct ggml_tensor * rope_tensor,
+        struct ggml_tensor * freq_factors) {
+    rope_tensor->src[2] = freq_factors;
+    return rope_tensor;
+}
+
 struct ggml_tensor * ggml_rope(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -18915,21 +18922,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
-                            ggml_rope_back(ctx,
-                                tensor->grad,
-                                src1,
-                                n_dims,
-                                mode,
-                                n_ctx,
-                                n_orig_ctx,
-                                freq_base,
-                                freq_scale,
-                                ext_factor,
-                                attn_factor,
-                                beta_fast,
-                                beta_slow,
-                                xpos_base,
-                                xpos_down),
+                            ggml_rope_with_freq_factors(
+                                ggml_rope_back(ctx,
+                                    tensor->grad,
+                                    src1,
+                                    n_dims,
+                                    mode,
+                                    n_ctx,
+                                    n_orig_ctx,
+                                    freq_base,
+                                    freq_scale,
+                                    ext_factor,
+                                    attn_factor,
+                                    beta_fast,
+                                    beta_slow,
+                                    xpos_base,
+                                    xpos_down),
+                                tensor->src[2]),
                             zero_table);
                 }
             } break;
@@ -18954,22 +18963,24 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
-                            ggml_rope_impl(ctx,
-                                tensor->grad,
-                                src1,
-                                n_dims,
-                                mode,
-                                n_ctx,
-                                n_orig_ctx,
-                                freq_base,
-                                freq_scale,
-                                ext_factor,
-                                attn_factor,
-                                beta_fast,
-                                beta_slow,
-                                xpos_base,
-                                xpos_down,
-                                false),
+                            ggml_rope_with_freq_factors(
+                                ggml_rope_impl(ctx,
+                                    tensor->grad,
+                                    src1,
+                                    n_dims,
+                                    mode,
+                                    n_ctx,
+                                    n_orig_ctx,
+                                    freq_base,
+                                    freq_scale,
+                                    ext_factor,
+                                    attn_factor,
+                                    beta_fast,
+                                    beta_slow,
+                                    xpos_base,
+                                    xpos_down,
+                                    false),
+                                tensor->src[2]),
                             zero_table);
                 }
             } break;
diff --git a/llama.cpp b/llama.cpp
index d26fe559a..de2b91ac4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -304,6 +304,9 @@ enum llm_kv {
     LLM_KV_ROPE_SCALE_LINEAR,
     LLM_KV_ROPE_SCALING_TYPE,
     LLM_KV_ROPE_SCALING_FACTOR,
+    LLM_KV_ROPE_SCALING_LONG_FACTORS,
+    LLM_KV_ROPE_SCALING_SHORT_FACTORS,
+    LLM_KV_ROPE_SCALING_ATTN_FACTOR,
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
 
@@ -381,6 +384,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"                    },
    { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type"                    },
    { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor"                  },
+   { LLM_KV_ROPE_SCALING_LONG_FACTORS,     "%s.rope.scaling.freq_long_factors"       },
+   { LLM_KV_ROPE_SCALING_SHORT_FACTORS,    "%s.rope.scaling.freq_short_factors"      },
+   { LLM_KV_ROPE_SCALING_ATTN_FACTOR,      "%s.rope.scaling.attn_factor"             },
    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,     "%s.rope.scaling.original_context_length" },
    { LLM_KV_ROPE_SCALING_FINETUNED,        "%s.rope.scaling.finetuned"               },
 
@@ -1754,6 +1760,10 @@ struct llama_hparams {
     float rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
 
+    std::vector<float> rope_long_factors;
+    std::vector<float> rope_short_factors;
+    float rope_attn_factor = 1.0f;
+
     // for State Space Models
     uint32_t ssm_d_conv  = 0;
     uint32_t ssm_d_inner = 0;
@@ -1789,6 +1799,10 @@ struct llama_hparams {
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 
+        if (this->rope_long_factors  != other.rope_long_factors)  return true;
+        if (this->rope_short_factors != other.rope_short_factors) return true;
+        if (this->rope_attn_factor   != other.rope_attn_factor)   return true;
+
         if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
         if (this->ssm_d_state != other.ssm_d_state) return true;
@@ -2246,6 +2260,8 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask;    // F32 [1, n_kv]
     struct ggml_tensor * inp_s_seq;     // I32 [n_kv, n_batch]
 
+    struct ggml_tensor * freq_factors = nullptr; // F32 [kv_size / 2]
+
     // control vectors
     struct llama_control_vector cvec;
 };
@@ -3306,6 +3322,39 @@ struct llama_model_loader {
         return get_arr_n(llm_kv(kid), result, required);
     }
 
+    template<typename T>
+    bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
+        const int kid = gguf_find_key(meta, key.c_str());
+
+        if (kid < 0) {
+            if (required) {
+                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+            }
+            return false;
+        }
+
+        struct GGUFMeta::ArrayInfo arr_info =
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
+
+        if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) {
+            throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str()));
+        }
+
+        // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
+        GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same<T, float>::value));
+        GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32   || std::is_same<T, int>::value));
+
+        result.resize(arr_info.length);
+        result.assign((T *) arr_info.data, (T *) arr_info.data + arr_info.length);
+
+        return true;
+    }
+
+    template<typename T>
+    bool get_arr(const enum llm_kv kid, T & result, const bool required = true) {
+        return get_arr(llm_kv(kid), result, required);
+    }
+
     template<typename T>
     bool get_key(const std::string & key, T & result, const bool required = true) {
         auto it = kv_overrides.find(key);
@@ -3849,6 +3898,14 @@ static void llm_load_hparams(
     }
     hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
+    ml.get_arr(LLM_KV_ROPE_SCALING_LONG_FACTORS,  hparams.rope_long_factors,  false);
+    ml.get_arr(LLM_KV_ROPE_SCALING_SHORT_FACTORS, hparams.rope_short_factors, false);
+
+    GGML_ASSERT(hparams.rope_long_factors.size() == 0 || hparams.rope_long_factors.size() == hparams.n_embd / hparams.n_head / 2);
+    GGML_ASSERT(hparams.rope_long_factors.size() == hparams.rope_short_factors.size());
+
+    ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
+
     // sanity check for n_rot (optional)
    {
        hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
@@ -6821,17 +6878,22 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
 
+        lctx.freq_factors = build_freq_factors();
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
             // we rotate only the first n_rot dimensions
             ggml_rope_custom_inplace(ctx0,
                     ggml_view_3d(ctx0, kv_self.k_l[il],
                         n_embd_head_k, n_head_kv, n_ctx,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                         0),
                     lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
+
+            tmp = ggml_rope_with_freq_factors(tmp, lctx.freq_factors);
+
             cb(tmp, "K_shifted", il);
             ggml_build_forward_expand(gf, tmp);
         }
@@ -6934,6 +6996,20 @@ struct llm_build_context {
         return lctx.inp_pos;
     }
 
+    struct ggml_tensor * build_freq_factors() {
+
+        if (hparams.rope_long_factors.empty() || hparams.rope_short_factors.empty()) {
+            lctx.freq_factors = nullptr;
+            return nullptr;
+        }
+
+        lctx.freq_factors = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_embd_head_k / 2);
+        cb(lctx.freq_factors, "freq_factors", -1);
+        ggml_set_input(lctx.freq_factors);
+
+        return lctx.freq_factors;
+    }
+
     struct ggml_tensor * build_inp_out_ids() {
         lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
         cb(lctx.inp_out_ids, "inp_out_ids", -1);
@@ -9052,6 +9128,9 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        // rope freq factors for 128k context
+        struct ggml_tensor * freq_factors = build_freq_factors();
+
         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
 
@@ -9092,6 +9171,7 @@ struct llm_build_context {
                     ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
+                Qcur = ggml_rope_with_freq_factors(Qcur, freq_factors);
                 cb(Qcur, "Qcur", il);
 
                 Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
@@ -9101,6 +9181,7 @@ struct llm_build_context {
                     ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
+                Kcur = ggml_rope_with_freq_factors(Kcur, freq_factors);
                 cb(Kcur, "Kcur", il);
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
@@ -10890,6 +10971,22 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
+    if (lctx.freq_factors) {
+        auto freq_dim = hparams.n_embd_head_k / 2;
+
+        GGML_ASSERT(lctx.freq_factors->ne[0] == freq_dim);
+        GGML_ASSERT(hparams.rope_long_factors.size() == freq_dim);
+        GGML_ASSERT(hparams.rope_short_factors.size() == freq_dim);
+
+        auto max_pos = batch.n_tokens > 0 && batch.pos != nullptr ?
+            *std::max_element(batch.pos, batch.pos + batch.n_tokens) : batch.n_tokens - 1;
+        if (max_pos + 1 > hparams.n_yarn_orig_ctx) {
+            ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_long_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
+        }
+        else {
+            ggml_backend_tensor_set(lctx.freq_factors, hparams.rope_short_factors.data(), 0, freq_dim * ggml_element_size(lctx.freq_factors));
+        }
+    }
+
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
@@ -15417,6 +15514,7 @@ struct llama_context * llama_new_context_with_model(
             cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
         }
 
+        cparams.yarn_attn_factor *= hparams.rope_attn_factor;
        cparams.causal_attn = hparams.causal_attn;
 
        if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
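
The numerical core of the patch is the new division in rope_neox: the base angle of each rotated column pair is divided by a per-dimension frequency factor before the YaRN correction in rope_yarn. Below is a minimal host-side sketch of that per-pair angle, not part of the patch; rope_neox_theta_ref is a hypothetical name, and theta_scale is assumed to be set up as in rope_neox_cuda:

#include <cmath>

// Reference angle for the column pair at index col/2 (col even) and position p.
// Passing freq_factors == nullptr divides by 1.0f, i.e. the pre-patch formula,
// matching the kernel's has_freq_facs == false branch.
static float rope_neox_theta_ref(int p, int col, int n_dims, float freq_base,
                                 float freq_scale, const float * freq_factors) {
    // same setup as rope_neox_cuda: theta_scale = freq_base^(-2/n_dims)
    const float theta_scale = powf(freq_base, -2.0f/n_dims);
    const float freq_factor = freq_factors ? freq_factors[col/2] : 1.0f;
    return p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
}

Factors greater than 1.0 slow the rotation of that dimension, which is how the long-factor table stretches the usable position range for the 128k context the Phi-3 builder comments refer to.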
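On the llama.cpp side, llama_set_inputs picks a factor table per batch: once the highest position the batch touches crosses n_yarn_orig_ctx (the original training context), the long factors are uploaded into lctx.freq_factors, otherwise the short ones. A self-contained sketch of that selection rule follows; select_rope_freq_factors is a hypothetical helper (the patch inlines this logic and uploads the chosen table with ggml_backend_tensor_set):

#include <algorithm>
#include <cstdint>
#include <vector>

// Mirrors the llama_set_inputs logic above: take the highest position in the
// batch (with the same fallback the patch uses when batch.pos is missing) and
// compare it against the original training context length.
static const std::vector<float> & select_rope_freq_factors(
        const std::vector<float> & long_factors,
        const std::vector<float> & short_factors,
        const int32_t * pos, int32_t n_tokens, uint32_t n_orig_ctx) {
    const int32_t max_pos = (n_tokens > 0 && pos != nullptr)
        ? *std::max_element(pos, pos + n_tokens)
        : n_tokens - 1;
    return (max_pos + 1 > (int32_t) n_orig_ctx) ? long_factors : short_factors;
}

Note that the comparison uses max_pos + 1, a count of positions rather than an index, so a batch that exactly fills the original context still selects the short factors.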