test-backend-ops : test mask parameter of ggml_soft_max_ext
This commit is contained in:
parent
308f279622
commit
8bd38fe32d
1 changed files with 10 additions and 5 deletions
|
@ -1057,19 +1057,23 @@ struct test_soft_max : public test_case {
|
||||||
const ggml_type type;
|
const ggml_type type;
|
||||||
const std::array<int64_t, 4> ne;
|
const std::array<int64_t, 4> ne;
|
||||||
const float scale;
|
const float scale;
|
||||||
|
const bool mask;
|
||||||
|
|
||||||
std::string vars() override {
|
std::string vars() override {
|
||||||
return VARS_TO_STR3(type, ne, scale);
|
return VARS_TO_STR4(type, ne, scale, mask);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_soft_max(ggml_type type = GGML_TYPE_F32,
|
test_soft_max(ggml_type type = GGML_TYPE_F32,
|
||||||
std::array<int64_t, 4> ne = {10, 10, 10, 10},
|
std::array<int64_t, 4> ne = {10, 10, 10, 10},
|
||||||
float scale = 1.0f)
|
float scale = 1.0f,
|
||||||
: type(type), ne(ne), scale(scale) {}
|
bool mask = false)
|
||||||
|
: type(type), ne(ne), scale(scale), mask(mask) {}
|
||||||
|
|
||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, nullptr, scale);
|
ggml_tensor * b = nullptr;
|
||||||
|
if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
|
||||||
|
ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1827,7 +1831,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||||
exponent <<= 1;
|
exponent <<= 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 16, 1, 1}, 0.1f));
|
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f));
|
||||||
|
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true));
|
||||||
|
|
||||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||||
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
|
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue