metal : implement soft_max_ext

This commit is contained in:
Georgi Gerganov 2023-11-29 12:44:47 +02:00
parent 1f5cd83275
commit e89597c062
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 75 additions and 31 deletions

8
ggml.h
View file

@@ -1282,6 +1282,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// fused soft_max(a*scale + mask)
// i.e. the scaling of a, the addition of mask and the soft_max are exposed
// through this single call
// mask is optional
// NOTE(review): presumably "optional" means NULL may be passed for mask, in
// which case the result is soft_max(a*scale) - confirm against the implementation
//
//   ctx   - ggml context
//   a     - input tensor, multiplied by scale
//   mask  - tensor added to a*scale before the soft_max (optional)
//   scale - scalar factor applied to a
//
// returns a struct ggml_tensor * (presumably the result tensor - TODO confirm)
GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale);
GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx,
struct ggml_tensor * a,