diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a529dfdb5..a6486f34e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1057,19 +1057,23 @@ struct test_soft_max : public test_case { const ggml_type type; const std::array ne; const float scale; + const bool mask; std::string vars() override { - return VARS_TO_STR3(type, ne, scale); + return VARS_TO_STR4(type, ne, scale, mask); } test_soft_max(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, - float scale = 1.0f) - : type(type), ne(ne), scale(scale) {} + float scale = 1.0f, + bool mask = false) + : type(type), ne(ne), scale(scale), mask(mask) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * out = ggml_soft_max_ext(ctx, a, nullptr, scale); + ggml_tensor * b = nullptr; + if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } + ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale); return out; } }; @@ -1827,7 +1831,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op exponent <<= 1; } - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 16, 1, 1}, 0.1f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true)); for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B