ggml : refactor rope norm/neox (#7634)

* ggml : unify rope norm/neox (CPU)

* ggml : fix compile warning

* ggml : remove GLM rope mode

ggml-ci

* metal : better rope implementation

ggml-ci

* cuda : better rope implementation

ggml-ci

* naming : n_orig_ctx -> n_ctx_orig

ggml-ci

* dev : add reminders to update backends

ggml-ci

* vulkan : fix ggml_rope_ext() usage

* cuda : fix array size + indents

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-06-05 11:29:20 +03:00 committed by GitHub
parent 9973e81c5c
commit 2b3389677a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 485 additions and 732 deletions

View file

@ -162,12 +162,12 @@ int main(int /*argc*/, const char ** /*argv*/) {
x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
// 100, 101, 102, ..., 172
struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode, 1024);
struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
// -67, -67, -67, ..., -67
struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
// 33, 34, 35, ..., 105
struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode, 1024);
struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode);
ggml_cgraph * gf = ggml_new_graph(ctx0);