ggml : refactor rope norm/neox (#7634)

* ggml : unify rope norm/neox (CPU)

* ggml : fix compile warning

* ggml : remove GLM rope mode

ggml-ci

* metal : better rope implementation

ggml-ci

* cuda : better rope implementation

ggml-ci

* naming : n_orig_ctx -> n_ctx_orig

ggml-ci

* dev : add reminders to update backends

ggml-ci

* vulkan : fix ggml_rope_ext() usage

* cuda : fix array size + indents

ggml-ci
Author: Georgi Gerganov
Date:   2024-06-05 11:29:20 +03:00 (committed by GitHub)
Parent: 9973e81c5c
Commit: 2b3389677a

19 changed files with 485 additions and 732 deletions
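
The core of the refactor is that the NORM and NEOX RoPE paths now share one implementation and differ only in which pair of elements each rotation touches. As a minimal sketch of that pairing (illustrative only, not the actual backend kernel; the rope_pair helper and its arguments are invented for this example):

    #include <math.h>

    // NORM rotates adjacent elements (x[2i], x[2i+1]); NEOX rotates elements
    // split across the two halves of the head (x[i], x[i + n_dims/2]).
    // theta is the rotation angle for dimension pair i.
    void rope_pair(float * x, int i, int n_dims, int is_neox, float theta) {
        const int i0 = is_neox ? i            : 2*i;
        const int i1 = is_neox ? i + n_dims/2 : 2*i + 1;
        const float x0 = x[i0];
        const float x1 = x[i1];
        x[i0] = x0*cosf(theta) - x1*sinf(theta);
        x[i1] = x0*sinf(theta) + x1*cosf(theta);
    }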


@@ -3898,11 +3898,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         {
             const int mode = ((const int32_t *) dst->op_params)[2];
             const bool is_neox = mode & 2;
-            const bool is_glm  = mode & 4;
-
-            if (is_glm) {
-                return nullptr;
-            }
 
             if (is_neox) {
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -4401,7 +4396,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     const int n_dims        = ((int32_t *) dst->op_params)[1];
     const int mode          = ((int32_t *) dst->op_params)[2];
     // const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx    = ((int32_t *) dst->op_params)[4];
+    const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];
     const float freq_base   = ((float *) dst->op_params)[5];
     const float freq_scale  = ((float *) dst->op_params)[6];
     const float ext_factor  = ((float *) dst->op_params)[7];
@@ -4410,12 +4405,12 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     const float beta_slow   = ((float *) dst->op_params)[10];
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
 
-    GGML_ASSERT(!is_glm);
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
 
     float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     if (is_neox) {
         const float theta_scale = powf(freq_base, -2.0f/n_dims);
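
For reference, ggml_rope_yarn_corr_dims fills corr_dims with the [start, end] range of rotary dimensions over which YaRN ramps between interpolated and extrapolated frequencies. A simplified sketch of what it computes, based on the CPU implementation of that era (see ggml.c for the authoritative version):

    #include <math.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))
    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    // dimension index whose rotation count over the original training
    // context n_ctx_orig reaches the target number of rotations n_rot
    float rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
        return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float) M_PI)) / (2 * logf(base));
    }

    void rope_yarn_corr_dims(int n_dims, int n_ctx_orig, float freq_base,
                             float beta_fast, float beta_slow, float dims[2]) {
        // start/end correction dims, clamped to the valid range [0, n_dims - 1]
        const float start = floorf(rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
        const float end   =  ceilf(rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
        dims[0] = MAX(0, start);
        dims[1] = MIN(n_dims - 1, end);
    }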
@@ -6485,9 +6480,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
-                const bool is_glm  = mode & 4;
 
-                return !is_glm;
+                return true;
             } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -6992,15 +6986,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
     } else if (tensor->op == GGML_OP_ROPE) {
         const int n_dims      = ((int32_t *) tensor->op_params)[1];
         const int mode        = ((int32_t *) tensor->op_params)[2];
-        const int n_ggml_ctx       = ((int32_t *) tensor->op_params)[3];
-        const int n_orig_ggml_ctx  = ((int32_t *) tensor->op_params)[4];
+        //const int n_ctx_ggml       = ((int32_t *) tensor->op_params)[3];
+        const int n_ctx_orig_ggml   = ((int32_t *) tensor->op_params)[4];
         float freq_base       = ((float *) tensor->op_params)[5];
         float freq_scale      = ((float *) tensor->op_params)[6];
         float ext_factor      = ((float *) tensor->op_params)[7];
         float attn_factor     = ((float *) tensor->op_params)[8];
         float beta_fast       = ((float *) tensor->op_params)[9];
         float beta_slow       = ((float *) tensor->op_params)[10];
-        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ggml_ctx, n_orig_ggml_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+        tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
     } else if (tensor->op == GGML_OP_UNARY) {
         switch (ggml_get_unary_op(tensor)) {
             case GGML_UNARY_OP_SILU:
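
The last hunk also shows the caller-visible signature change from this PR: ggml_rope_ext no longer takes the unused n_ctx argument, and n_orig_ctx is renamed to n_ctx_orig. A sketch of the migration for downstream code (variable names are illustrative):

    // before: cur = ggml_rope_ext(ctx, cur, pos, NULL, n_dims, mode, n_ctx, n_orig_ctx,
    //                             freq_base, freq_scale, ext_factor, attn_factor,
    //                             beta_fast, beta_slow);

    // after: drop n_ctx and pass the renamed n_ctx_orig
    cur = ggml_rope_ext(ctx, cur, pos, NULL, n_dims, mode, n_ctx_orig,
                        freq_base, freq_scale, ext_factor, attn_factor,
                        beta_fast, beta_slow);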