From 0faf92e74cab1d68070d37508c2445b5ccaeee49 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 10 May 2024 17:13:11 +0300
Subject: [PATCH] ggml : require mask when using ALiBi

ggml-ci
---
 ggml.c                     | 9 +++++++++
 tests/test-backend-ops.cpp | 2 +-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/ggml.c b/ggml.c
index 376a67cfb..4ee5d24af 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5657,6 +5657,10 @@ static struct ggml_tensor * ggml_soft_max_impl(
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
 
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6440,6 +6444,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
         float                 max_bias) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
+
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
@@ -6449,6 +6454,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
     }
 
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index ab94abc72..731788b95 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2126,6 +2126,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #endif
     for (bool mask : {false, true}) {
         for (float max_bias : {0.0f, 8.0f}) {
+            if (!mask && max_bias > 0.0f) continue;
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
@@ -2139,7 +2140,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
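
Editor's note (not part of the patch): the new asserts change the contract of the mask argument whenever ALiBi is enabled (max_bias > 0.0f). Below is a minimal sketch of the resulting call pattern, assuming the public ggml_soft_max_ext() wrapper forwards to ggml_soft_max_impl() and keeps the (ctx, a, mask, scale, max_bias) signature at this revision; the helper name and the tensor shapes in the comments are illustrative only.

    #include "ggml.h"

    // Hypothetical helper: softmax over attention scores with ALiBi enabled.
    // With this patch, calling with mask == NULL while max_bias > 0.0f hits
    // GGML_ASSERT(mask) inside ggml_soft_max_impl(), so a valid mask tensor
    // must always be supplied when ALiBi is requested.
    static struct ggml_tensor * alibi_soft_max(
            struct ggml_context * ctx,
            struct ggml_tensor  * scores,   // e.g. [n_kv, n_tokens, n_head, 1]
            struct ggml_tensor  * mask) {   // e.g. [n_kv, n_tokens], F16 or F32
        const float scale    = 1.0f;
        const float max_bias = 8.0f;        // max_bias > 0.0f => ALiBi => mask required

        GGML_ASSERT(mask != NULL);          // mirror the new requirement at the call site
        return ggml_soft_max_ext(ctx, scores, mask, scale, max_bias);
    }

ggml_flash_attn_ext() gains the same requirement, and the test matrix in test-backend-ops.cpp is trimmed accordingly: the mask=false / max_bias=8.0f combinations are no longer valid inputs and are skipped.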