From 0d7240b32045e935732985b1a46376ced95f8206 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Wed, 19 Jul 2023 14:16:27 +0800 Subject: [PATCH] modified rope for cuda --- ggml-cuda.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 923d10f76..e316260fa 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2972,13 +2972,19 @@ inline void ggml_cuda_op_rope( const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; + float freq_base; + float freq_scale; + const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; const int n_ctx = ((int32_t *) src1->data)[3]; + memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); - const float theta_scale = powf(10000.0, -2.0f/n_dims); - const float p = ((mode & 1) == 0 ? n_past + i02 : i02); + const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float p0 = ((mode & 1) == 0 ? n_past + i02 : i02); + const float p = p0 * freq_scale; bool is_glm = mode & 4;