From 0fe2d560011be39136e882f56a4bddba26d8452c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 15 Feb 2024 13:11:13 +0200
Subject: [PATCH] ggml : deprecate ggml_alibi

---
 ggml-metal.m |  2 +-
 ggml.h       |  5 +++--
 llama.cpp    | 24 ++++++++++++++++++++++--
 3 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index eb7afd18f..a0d9aaaba 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -809,7 +809,7 @@ static bool ggml_metal_graph_compute(
 
         id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
         id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
-        id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
+        //id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
         id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
         //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
diff --git a/ggml.h b/ggml.h
index e69d1724c..b112b9917 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1483,12 +1483,13 @@ extern "C" {
 
     // alibi position embedding
     // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_alibi(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_head,
-            float                 bias_max);
+            float                 bias_max),
+        "use ggml_soft_max_ext instead");
 
     // clamp
     // in-place, returns view(a)
diff --git a/llama.cpp b/llama.cpp
index b32b2c681..bc9ada35f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4814,8 +4814,28 @@ static struct ggml_tensor * llm_build_kqv(
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
     }
 
-    kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-    cb(kq, "kq_soft_max_ext", il);
+#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_SYCL)
+#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, Kompute, and SYCL")
+#pragma message("      Falling back to ggml_alibi(). Will become an error in Mar 2024")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5488")
+    if (hparams.f_max_alibi_bias > 0.0f) {
+        kq = ggml_scale(ctx, kq, kq_scale);
+        cb(kq, "kq_scaled", il);
+
+        kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
+        cb(kq, "kq_scaled_alibi", il);
+
+        kq = ggml_add(ctx, kq, kq_mask);
+        cb(kq, "kq_masked", il);
+
+        kq = ggml_soft_max(ctx, kq);
+        cb(kq, "kq_soft_max", il);
+    } else
+#endif
+    {
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        cb(kq, "kq_soft_max_ext", il);
+    }
 
     // split cached v into n_head heads
     struct ggml_tensor * v =