ggml : refactor rope norm/neox (#7634)
* ggml : unify rope norm/neox (CPU) * ggml : fix compile warning * ggml : remove GLM rope mode ggml-ci * metal : better rope implementation ggml-ci * cuda : better rope implementation ggml-ci * naming : n_orig_ctx -> n_ctx_orig ggml-ci * dev : add reminders to update backends ggml-ci * vulkan : fix ggml_rope_ext() usage * cuda : fix array size + indents ggml-ci
This commit is contained in:
parent
9973e81c5c
commit
2b3389677a
19 changed files with 485 additions and 732 deletions
|
@ -3898,11 +3898,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
{
|
||||
const int mode = ((const int32_t *) dst->op_params)[2];
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
if (is_glm) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (is_neox) {
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
|
@ -4401,7 +4396,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
|||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
const float freq_base = ((float *) dst->op_params)[5];
|
||||
const float freq_scale = ((float *) dst->op_params)[6];
|
||||
const float ext_factor = ((float *) dst->op_params)[7];
|
||||
|
@ -4410,12 +4405,12 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
|
|||
const float beta_slow = ((float *) dst->op_params)[10];
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
GGML_ASSERT(!is_glm);
|
||||
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
||||
|
||||
float corr_dims[2];
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
||||
|
||||
if (is_neox) {
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
|
@ -6485,9 +6480,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
|||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
return !is_glm;
|
||||
return true;
|
||||
} break;
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
|
@ -6992,15 +6986,15 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
|
|||
} else if (tensor->op == GGML_OP_ROPE) {
|
||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ggml_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_orig_ggml_ctx = ((int32_t *) tensor->op_params)[4];
|
||||
//const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
|
||||
const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
|
||||
float freq_base = ((float *) tensor->op_params)[5];
|
||||
float freq_scale = ((float *) tensor->op_params)[6];
|
||||
float ext_factor = ((float *) tensor->op_params)[7];
|
||||
float attn_factor = ((float *) tensor->op_params)[8];
|
||||
float beta_fast = ((float *) tensor->op_params)[9];
|
||||
float beta_slow = ((float *) tensor->op_params)[10];
|
||||
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ggml_ctx, n_orig_ggml_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
} else if (tensor->op == GGML_OP_UNARY) {
|
||||
switch (ggml_get_unary_op(tensor)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue