metal : implement soft_max_ext
This commit is contained in:
parent
1f5cd83275
commit
e89597c062
5 changed files with 75 additions and 31 deletions
18
ggml.c
18
ggml.c
|
@ -4826,6 +4826,8 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
|
|||
static struct ggml_tensor * ggml_soft_max_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * mask,
|
||||
float scale,
|
||||
bool inplace) {
|
||||
bool is_node = false;
|
||||
|
||||
|
@ -4835,9 +4837,13 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
float params[] = { scale };
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_SOFT_MAX;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = mask;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -4845,13 +4851,21 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|||
struct ggml_tensor * ggml_soft_max(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_soft_max_impl(ctx, a, false);
|
||||
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_soft_max_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_soft_max_impl(ctx, a, true);
|
||||
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_soft_max_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * mask,
|
||||
float scale) {
|
||||
return ggml_soft_max_impl(ctx, a, mask, scale, false);
|
||||
}
|
||||
|
||||
// ggml_soft_max_back
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue