From d8889598d6420145d62667816d449e11cc7b2caa Mon Sep 17 00:00:00 2001 From: Sergio Lopez Date: Wed, 20 Nov 2024 07:28:25 +0100 Subject: [PATCH] kompute: softmax: implement ALiBi support Signed-off-by: Sergio Lopez --- ggml/src/ggml-kompute/ggml-kompute.cpp | 22 +++++++++++++------ .../ggml-kompute/kompute-shaders/common.comp | 1 + .../kompute-shaders/op_softmax.comp | 20 +++++++++++++++-- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-kompute/ggml-kompute.cpp b/ggml/src/ggml-kompute/ggml-kompute.cpp index eeeccd965..0d341cfdb 100644 --- a/ggml/src/ggml-kompute/ggml-kompute.cpp +++ b/ggml/src/ggml-kompute/ggml-kompute.cpp @@ -788,7 +788,8 @@ static void ggml_vk_soft_max( const std::shared_ptr& out, uint32_t inAOff, uint32_t inBOff, uint32_t outOff, int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03, - float scale + float scale, float max_bias, float m0, float m1, + uint32_t n_head_log2 ) { const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv, kp::shader_data::op_softmax_comp_spv_len); @@ -796,12 +797,14 @@ static void ggml_vk_soft_max( struct PushConstants { uint32_t inAOff, inBOff, outOff; int32_t ne00, ne01, ne02; - float scale; + float scale, max_bias, m0, m1; + uint32_t n_head_log2; int32_t mask; } pushConsts { safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), ne00, ne01, ne02, - scale, + scale, max_bias, m0, m1, + n_head_log2, bool(inB) }; @@ -1597,11 +1600,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); -#pragma message("TODO: add ALiBi support") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192") - GGML_ASSERT(max_bias == 0.0f); + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; - ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2); } break; case GGML_OP_DIAG_MASK_INF: { diff --git a/ggml/src/ggml-kompute/kompute-shaders/common.comp b/ggml/src/ggml-kompute/kompute-shaders/common.comp index 2aaddf704..dbe4cf804 100644 --- a/ggml/src/ggml-kompute/kompute-shaders/common.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/common.comp @@ -3,6 +3,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16: require #extension GL_EXT_shader_explicit_arithmetic_types_int8: require #extension GL_EXT_shader_explicit_arithmetic_types_int16: require +#extension GL_EXT_shader_explicit_arithmetic_types_int64: require #extension GL_EXT_control_flow_attributes: enable #extension GL_KHR_shader_subgroup_arithmetic : require #extension GL_EXT_debug_printf : enable diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp b/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp index 7bc9176ca..4165295bf 100644 --- a/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp @@ -18,6 +18,10 @@ layout(push_constant) uniform PushConstants { int ne01; int ne02; float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; int mask; } pcs; @@ -34,17 +38,29 @@ void main() { const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB const uint pdst = extra_off + pcs.outOff; // Based from out_ + float slope = 1.0f; + + // ALiBi + if (pcs.max_bias > 0.0f) { + int64_t h = i02; + + float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1; + int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1; + + slope = pow(base, float(exp)); + } + // parallel max float localMax = uintBitsToFloat(0xFF800000); for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f)); + localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f)); } float max_ = subgroupMax(localMax); // parallel sum float localSum = 0.0f; for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_); + const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_); localSum += exp_psrc0; out_[pdst + i00] = exp_psrc0; }