ggml : refactor rope norm/neox (#7634)
* ggml : unify rope norm/neox (CPU) * ggml : fix compile warning * ggml : remove GLM rope mode ggml-ci * metal : better rope implementation ggml-ci * cuda : better rope implementation ggml-ci * naming : n_orig_ctx -> n_ctx_orig ggml-ci * dev : add reminders to update backends ggml-ci * vulkan : fix ggml_rope_ext() usage * cuda : fix array size + indents ggml-ci
This commit is contained in:
parent
9973e81c5c
commit
2b3389677a
19 changed files with 485 additions and 732 deletions
|
@ -1192,7 +1192,7 @@ static void ggml_vk_rope(
|
|||
const std::shared_ptr<kp::Tensor>& inB,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
|
||||
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
|
||||
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
|
||||
int32_t ne01, int32_t ne02, int32_t ne03,
|
||||
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
|
||||
|
@ -1221,14 +1221,14 @@ static void ggml_vk_rope(
|
|||
|
||||
struct PushConstants {
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
int32_t n_dims, mode, n_orig_ctx;
|
||||
int32_t n_dims, mode, n_ctx_orig;
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
uint32_t nb00, nb01, nb02, nb03;
|
||||
int32_t ne0;
|
||||
uint32_t nb0, nb1, nb2, nb3;
|
||||
} pushConsts {
|
||||
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
|
||||
n_dims, mode, n_orig_ctx,
|
||||
n_dims, mode, n_ctx_orig,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
||||
nb00, nb01, nb02, nb03,
|
||||
ne0,
|
||||
|
@ -1692,13 +1692,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
|
||||
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
|
||||
|
||||
#pragma message("TODO: update rope NORM mode to match NEOX mode")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
|
||||
|
||||
GGML_ASSERT(ne10 == ne02);
|
||||
GGML_ASSERT(src0t == dstt);
|
||||
// const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
|
@ -1708,7 +1711,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
|||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
ggml_vk_rope(
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
|
||||
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
|
||||
);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue