ggml : ggml_soft_max support F16/F32 mask/pos

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-22 13:46:23 +03:00
parent c11d05fec0
commit f725ca90fb
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 105 additions and 34 deletions

View file

@@ -1120,11 +1120,11 @@ struct test_soft_max : public test_case {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * mask = nullptr;
if (this->mask) {
-mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]);
+mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
}
ggml_tensor * pos = nullptr;
if (max_bias > 0.0f) {
-pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, ne[0]);
+pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
}
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
return out;