ggml : refactor rope norm/neox (#7634)

* ggml : unify rope norm/neox (CPU)

* ggml : fix compile warning

* ggml : remove GLM rope mode

ggml-ci

* metal : better rope implementation

ggml-ci

* cuda : better rope implementation

ggml-ci

* naming : n_orig_ctx -> n_ctx_orig

ggml-ci

* dev : add reminders to update backends

ggml-ci

* vulkan : fix ggml_rope_ext() usage

* cuda : fix array size + indents

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-06-05 11:29:20 +03:00 committed by GitHub
parent 9973e81c5c
commit 2b3389677a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 485 additions and 732 deletions

View file

@ -1141,7 +1141,7 @@ struct test_rope : public test_case {
const std::array<int64_t, 4> ne_a;
int n_dims;
int mode;
int n_ctx;
int n_ctx; // used to generate positions
float fs; // freq_scale
float ef; // ext_factor
float af; // attn_factor
@ -1168,7 +1168,7 @@ struct test_rope : public test_case {
}
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
return out;
}
@ -1615,7 +1615,7 @@ struct llama_hparams {
// cparams
static constexpr uint32_t n_ctx = 512; // user-specified context size
static constexpr uint32_t n_orig_ctx = n_ctx;
static constexpr uint32_t n_ctx_orig = n_ctx;
// batch
int32_t n_tokens;
@ -1806,13 +1806,13 @@ struct test_llama : public test_llm {
Qcur = ggml_rope_ext(
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -1931,12 +1931,12 @@ struct test_falcon : public test_llm {
// using mode = 2 for neox mode
Qcur = ggml_rope_ext(
ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
@ -2236,15 +2236,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (float ef : { 0.0f, 0.7465f }) {
for (float af : { 1.0f, 1.4245f }) {
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
// TODO: ff not supported yet for !neox
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B
if (all) {
test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B
test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B
test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B
}
for (bool ff : {false, true}) { // freq_factors
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
if (all) {
test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
}
if (all) {
test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
@ -2256,6 +2256,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
}
}
all = false;
}
}