From aaa20e1f10e6bb3c6880317e0f60c0847f15555b Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 17 Feb 2024 18:16:27 +0100 Subject: [PATCH] test-backend-ops : add null pos test to soft_max test-backend-ops : replace soft_max tests ggml-ci --- tests/test-backend-ops.cpp | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 88638c164..30a7d1f5a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1102,10 +1102,15 @@ struct test_soft_max : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * b = nullptr; - if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } - ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); - ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, c, scale, max_bias); + ggml_tensor * mask = nullptr; + if (this->mask) { + mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); + } + ggml_tensor * pos = nullptr; + if (max_bias > 0.0f) { + pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); + } + ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias); return out; } }; @@ -2061,6 +2066,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5)); test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5)); +#if 0 std::uniform_int_distribution<> dist_ne1(1, 50); int exponent = 1; while (exponent < (1 << 17)) { @@ -2074,6 +2080,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op exponent <<= 1; } +#endif + for (bool mask : {false, true}) { + for (float max_bias : {0.0f, 8.0f}) { + for (float scale : {1.0f, 0.1f}) { + for (int64_t ne0 : {16, 1024}) { + for (int64_t ne1 : {16, 1024}) { + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias)); + } + } + } + } + } test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));