ggml : ggml_soft_max support F16/F32 mask/pos
ggml-ci
This commit is contained in:
parent
c11d05fec0
commit
f725ca90fb
6 changed files with 105 additions and 34 deletions
|
@ -1120,11 +1120,11 @@ struct test_soft_max : public test_case {
|
|||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * mask = nullptr;
|
||||
if (this->mask) {
|
||||
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]);
|
||||
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
|
||||
}
|
||||
ggml_tensor * pos = nullptr;
|
||||
if (max_bias > 0.0f) {
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, ne[0]);
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
|
||||
}
|
||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
|
||||
return out;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue