test-backend-ops : add null pos test to soft_max

test-backend-ops : replace soft_max tests

ggml-ci
This commit is contained in:
slaren 2024-02-17 18:16:27 +01:00
parent 1657f92d2f
commit aaa20e1f10

View file

@@ -1102,10 +1102,15 @@ struct test_soft_max : public test_case {
// Builds the soft_max test graph. The mask tensor is only created when the
// test's `mask` flag is set, and the positions tensor only when max_bias > 0
// (ggml_soft_max_ext accepts nullptr for both).
ggml_tensor * build_graph(ggml_context * ctx) override {
    ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
    ggml_tensor * mask = nullptr;
    // `this->` disambiguates the member flag from the local tensor of the same name
    if (this->mask) {
        mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
    }
    ggml_tensor * pos = nullptr;
    if (max_bias > 0.0f) {
        // positions vector, one entry per column — only used when max_bias > 0
        pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
    }
    ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
    return out;
}
};
@@ -2061,6 +2066,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
#if 0
std::uniform_int_distribution<> dist_ne1(1, 50);
int exponent = 1;
while (exponent < (1 << 17)) {
@@ -2074,6 +2080,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
exponent <<= 1;
}
#endif
for (bool mask : {false, true}) {
for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
for (int64_t ne0 : {16, 1024}) {
for (int64_t ne1 : {16, 1024}) {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
}
}
}
}
}
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));