tests : add rope tests
ggml-ci
This commit is contained in:
parent cce3dcffc5
commit 9d5605f965

1 changed file with 26 additions and 17 deletions
@@ -1142,22 +1142,25 @@ struct test_rope : public test_case {
     int n_dims;
     int mode;
     int n_ctx;
+    float fs; // freq_scale
+    float ef; // ext_factor
+    float af; // attn_factor
     bool ff;
 
     std::string vars() override {
-        return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
+        return VARS_TO_STR9(type, ne, n_dims, mode, n_ctx, fs, ef, af, ff);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
-        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
+            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false)
+        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
         ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
-        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         return out;
     }
 
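For reference, the three new members feed straight into the freq_scale, ext_factor and attn_factor arguments of ggml_rope_ext in build_graph above. As a rough illustration of what freq_scale does (a minimal sketch with made-up names and pair layout, not the ggml implementation), plain RoPE with linear frequency scaling rotates each pair of values by an angle proportional to the token position:

    #include <cmath>

    // Rotate one (x0, x1) pair at token position `pos`, pair index `i` out of n_dims/2.
    // freq_base and freq_scale correspond to the 10000.0f and `fs` arguments above;
    // ext_factor and attn_factor only matter for YaRN-style context extension and are
    // omitted here.
    static void rope_pair(float & x0, float & x1, int pos, int i, int n_dims,
                          float freq_base, float freq_scale) {
        const float theta = (float) pos * freq_scale * std::pow(freq_base, -2.0f*i/n_dims);
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float r0 = x0*c - x1*s;
        const float r1 = x0*s + x1*c;
        x0 = r0;
        x1 = r1;
    }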
@@ -2213,20 +2216,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
 
-    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        // TODO: ff not supported yet for !neox
-        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, false)); // llama 7B
-        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, false)); // llama 13B
-        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, false)); // llama 30B
-        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, false)); // llama 65B
+    for (float fs : { 1.0f, 1.4245f }) {
+        for (float ef : { 0.0f, 0.7465f }) {
+            for (float af : { 1.0f, 1.4245f }) {
+                for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                    // TODO: ff not supported yet for !neox
+                    test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 7B
+                    test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 13B
+                    test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 30B
+                    test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 65B
 
-        for (bool ff : {false, true}) { // freq_factors
-            test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
-            test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
-            test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
-            test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
-            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, ff)); // neox (stablelm)
-            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, ff)); // neox (phi-2)
+                    for (bool ff : {false, true}) { // freq_factors
+                        test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, fs, ef, af, ff)); // neox (falcon 7B)
+                        test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, fs, ef, af, ff)); // neox (falcon 7B)
+                        test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, fs, ef, af, ff)); // neox (falcon 40B)
+                        test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, fs, ef, af, ff)); // neox (falcon 40B)
+                        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, fs, ef, af, ff)); // neox (stablelm)
+                        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, fs, ef, af, ff)); // neox (phi-2)
+                    }
+                }
+            }
         }
     }
 
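The three new loops run every rope case under 2 x 2 x 2 = 8 combinations of freq_scale, ext_factor and attn_factor, for both F32 and F16, instead of only the previous defaults (fs = 1.0, ef = 0.0, af = 1.0). For intuition about what a non-zero ext_factor and a non-unit attn_factor change, the sketch below is a heavily simplified YaRN-style adjustment of the rotation angle; the actual ggml kernels also apply a per-dimension ramp between the interpolated and extrapolated angles, so treat this only as an illustration:

    #include <cmath>

    // theta_extrap: the unscaled RoPE angle pos * freq_base^(-2*i/n_dims).
    // ext_factor blends between the frequency-scaled (interpolated) angle and the
    // original (extrapolated) one; attn_factor scales the magnitude of cos/sin.
    static void yarn_adjust(float theta_extrap, float freq_scale, float ext_factor,
                            float attn_factor, float & cos_theta, float & sin_theta) {
        const float theta_interp = freq_scale * theta_extrap;
        // ext_factor == 0 reduces to plain linear frequency scaling
        const float theta = theta_interp + ext_factor * (theta_extrap - theta_interp);
        cos_theta = std::cos(theta) * attn_factor;
        sin_theta = std::sin(theta) * attn_factor;
    }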