cuda : implement soft_max_ext

This commit is contained in:
Georgi Gerganov 2023-11-29 15:34:20 +02:00
parent e89597c062
commit 88519fbf97
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 28 additions and 14 deletions

6
ggml.c
View file

@@ -4829,6 +4829,12 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * mask,
float scale,
bool inplace) {
if (mask) {
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[3] == 1);
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
}
bool is_node = false;
if (a->grad) {