From 97d6a0cc0663d52248def278ee6472b9ede7742c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 14 Feb 2024 17:37:48 +0200 Subject: [PATCH] ggml : alternative ALiBi without extra tensor We compute the slopes in the kernel ggml-ci --- ggml-metal.m | 19 ++++----- ggml-metal.metal | 34 +++++++++++++-- ggml.c | 48 +++++++++++---------- ggml.h | 8 ++-- llama.cpp | 87 ++++++++++---------------------------- tests/test-backend-ops.cpp | 30 ++++++------- 6 files changed, 103 insertions(+), 123 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index b74bf8f66..eb7afd18f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1191,7 +1191,8 @@ static bool ggml_metal_graph_compute( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; } - const float scale = ((float *) dst->op_params)[0]; + const float scale = ((float *) dst->op_params)[0]; + const float max_bias = ((float *) dst->op_params)[1]; [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1200,16 +1201,12 @@ static bool ggml_metal_graph_compute( } else { [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; } - if (id_src2) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:7]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 0126313c5..19b02880d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -351,12 +351,12 @@ kernel void kernel_sum_rows( kernel void kernel_soft_max( device const float * src0, device const float * src1, - device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, + constant float & max_bias, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -369,9 +369,22 @@ kernel void kernel_soft_max( device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - const float slope = src2 != src0 ? src2[i02] : 0.0f; device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + float slope = 0.0f; + + if (max_bias > 0.0f) { + const uint32_t n_head_kv = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv)); + + const float m0 = pow(2.0f, -(max_bias ) / n_head_log2); + const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int64_t h = i02; + + slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1); + } + // parallel max float lmax = -INFINITY; @@ -439,12 +452,12 @@ kernel void kernel_soft_max( kernel void kernel_soft_max_4( device const float * src0, device const float * src1, - device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, constant float & scale, + constant float & max_bias, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -457,11 +470,24 @@ kernel void kernel_soft_max_4( device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - const float slope = src2 != src0 ? src2[i02] : 0.0f; device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); const float4 s0(0.0f, 1.0f, 2.0f, 3.0f); + float slope = 0.0f; + + if (max_bias > 0.0f) { + const uint32_t n_head_kv = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv)); + + const float m0 = pow(2.0f, -(max_bias ) / n_head_log2); + const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int64_t h = i02; + + slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1); + } + // parallel max float4 lmax4 = -INFINITY; diff --git a/ggml.c b/ggml.c index d0dd69aee..ea1b31d9f 100644 --- a/ggml.c +++ b/ggml.c @@ -5060,8 +5060,8 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * slope, float scale, + float max_bias, bool inplace) { GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { @@ -5070,12 +5070,6 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_can_repeat_rows(mask, a)); } - if (slope) { - GGML_ASSERT(ggml_is_contiguous(slope)); - GGML_ASSERT(ggml_is_vector(slope)); - GGML_ASSERT(slope->ne[0] == a->ne[2]); - } - bool is_node = false; if (a->grad) { @@ -5084,14 +5078,13 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - float params[] = { scale }; + float params[] = { scale, max_bias }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_SOFT_MAX; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = mask; - result->src[2] = slope; return result; } @@ -5099,22 +5092,22 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * ggml_soft_max( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, false); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false); } struct ggml_tensor * ggml_soft_max_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, true); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true); } struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * slope, - float scale) { - return ggml_soft_max_impl(ctx, a, mask, slope, scale, false); + float scale, + float max_bias) { + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } // ggml_soft_max_back @@ -11467,7 +11460,6 @@ static void ggml_compute_forward_soft_max_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * src2, struct ggml_tensor * dst) { assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); @@ -11476,8 +11468,11 @@ static void ggml_compute_forward_soft_max_f32( return; } - float scale = 1.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // TODO: handle transposed/permuted matrices @@ -11488,6 +11483,14 @@ static void ggml_compute_forward_soft_max_f32( const int64_t ne11 = src1 ? src1->ne[1] : 1; + // TODO: is this supposed to be ceil instead of floor? + // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 + const uint32_t n_head_kv = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const int nc = src0->ne[0]; const int nr = ggml_nrows(src0); @@ -11514,9 +11517,9 @@ static void ggml_compute_forward_soft_max_f32( } // alibi bias - if (src2) { - const int h = (i1/ne01)%ne02; - const float slope = ((float *)(src2->data))[h]; + if (max_bias > 0.0f) { + const uint32_t h = (i1/ne01)%ne02; // head + const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); for (int i = 0; i < nc; i++) { wp[i] = wp[i] + slope*i; @@ -11567,12 +11570,11 @@ static void ggml_compute_forward_soft_max( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * src2, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_soft_max_f32(params, src0, src1, src2, dst); + ggml_compute_forward_soft_max_f32(params, src0, src1, dst); } break; default: { @@ -15099,7 +15101,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SOFT_MAX: { - ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_SOFT_MAX_BACK: { diff --git a/ggml.h b/ggml.h index 6a375497e..e69d1724c 100644 --- a/ggml.h +++ b/ggml.h @@ -1373,15 +1373,15 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // fused soft_max(a*scale + i*slope + mask) + // fused soft_max(a*scale + mask + ALiBi bias) // mask is optional - // slope is optional + // max_bias = 0.0f for no ALiBi GGML_API struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * slope, - float scale); + float scale, + float max_bias); GGML_API struct ggml_tensor * ggml_soft_max_back( struct ggml_context * ctx, diff --git a/llama.cpp b/llama.cpp index 21e530723..b32b2c681 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1923,7 +1923,6 @@ struct llama_context { struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] - struct ggml_tensor * inp_KQ_slope; // F32 [n_head_kv] struct ggml_tensor * inp_K_shift; // I32 [n_ctx] struct ggml_tensor * inp_sum; // F32 [n_batch, n_batch] @@ -4783,7 +4782,6 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * wo_b, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, - struct ggml_tensor * kq_slope, int64_t n_ctx, int32_t n_tokens, int32_t n_kv, @@ -4816,7 +4814,7 @@ static struct ggml_tensor * llm_build_kqv( ggml_mul_mat_set_prec(kq, GGML_PREC_F32); } - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_slope, kq_scale); + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); // split cached v into n_head heads @@ -4863,7 +4861,6 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * v_cur, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, - struct ggml_tensor * kq_slope, int64_t n_ctx, int32_t n_tokens, int32_t kv_head, @@ -4882,7 +4879,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_slope, n_ctx, n_tokens, n_kv, kq_scale, cb, il); + q_cur, kq_mask, n_ctx, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -5065,7 +5062,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5195,9 +5192,6 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); cb(KQ_mask, "KQ_mask", -1); - struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0); - cb(KQ_slope, "KQ_slope", -1); - // shift the entire K-cache if needed if (do_rope_shift) { llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); @@ -5248,7 +5242,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5372,7 +5366,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5471,7 +5465,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5676,7 +5670,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5738,9 +5732,6 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); cb(KQ_mask, "KQ_mask", -1); - struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0); - cb(KQ_slope, "KQ_slope", -1); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -5768,7 +5759,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5868,7 +5859,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } else { // compute Q and K and RoPE them @@ -5899,7 +5890,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5971,9 +5962,6 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); cb(KQ_mask, "KQ_mask", -1); - struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0); - cb(KQ_slope, "KQ_slope", -1); - inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, @@ -6007,7 +5995,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -6067,9 +6055,6 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); cb(KQ_mask, "KQ_mask", -1); - struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0); - cb(KQ_slope, "KQ_slope", -1); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * attn_norm; @@ -6103,7 +6088,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -6225,7 +6210,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -6340,7 +6325,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -6461,7 +6446,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -6588,7 +6573,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); cb(cur, "kqv_out", il); } @@ -6691,7 +6676,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } struct ggml_tensor * sa_out = cur; @@ -6790,7 +6775,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -6899,7 +6884,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -7017,7 +7002,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -7136,7 +7121,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -7268,7 +7253,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -7499,32 +7484,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - // using Alibi bias - if (hparams.f_max_alibi_bias > 0.0f) { - const uint32_t n_head_kv = hparams.n_head_kv; - - const float max_bias = hparams.f_max_alibi_bias; - - assert(ggml_backend_buffer_is_host(lctx.inp_KQ_slope->buffer)); - - float * data = (float *) lctx.inp_KQ_slope->data; - - // TODO: is this supposed to be ceil instead of floor? - // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - for (uint32_t h = 0; h < n_head_kv; ++h) { - if (h < n_head_log2) { - data[h] = powf(m0, h + 1); - } else { - data[h] = powf(m1, 2*(h - n_head_log2) + 1); - } - } - } - { assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer)); float * data = (float *) lctx.inp_sum->data; @@ -11440,7 +11399,6 @@ struct llama_context * llama_new_context_with_model( ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); - ctx->inp_KQ_slope = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_head_kv); ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch); @@ -11448,7 +11406,6 @@ struct llama_context * llama_new_context_with_model( ggml_set_name(ctx->inp_embd, "inp_embd"); ggml_set_name(ctx->inp_pos, "inp_pos"); ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask"); - ggml_set_name(ctx->inp_KQ_slope, "inp_KQ_slope"); ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); ggml_set_name(ctx->inp_sum, "inp_sum"); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 128915a20..56e5fa920 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1086,27 +1086,25 @@ struct test_soft_max : public test_case { const ggml_type type; const std::array ne; const bool mask; - const bool slope; const float scale; + const float max_bias; std::string vars() override { - return VARS_TO_STR5(type, ne, mask, slope, scale); + return VARS_TO_STR5(type, ne, mask, scale, max_bias); } test_soft_max(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, bool mask = false, - bool slope = false, - float scale = 1.0f) - : type(type), ne(ne), mask(mask), slope(slope), scale(scale) {} + float scale = 1.0f, + float max_bias = 0.0f) + : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * b = nullptr; - ggml_tensor * c = nullptr; - if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } - if (slope) { c = ggml_new_tensor_1d(ctx, type, ne[2]); } - ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, c, scale); + if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } + ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale, max_bias); return out; } }; @@ -1492,7 +1490,7 @@ struct test_moe : public test_case { ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); - ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd)); + ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd), 0.0f); // select experts ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); @@ -1640,7 +1638,7 @@ public: struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale); + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f); // split cached v into n_head heads struct ggml_tensor * v = @@ -2095,16 +2093,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (int n = 0; n < 10; ++n) { int64_t ne0 = dist_ne0(rng); int64_t ne1 = dist_ne1(rng); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, n/3 == 0 && ne0 < 1000, 0.1f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, 4.0f)); } exponent <<= 1; } - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, false, 0.1f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, false, 0.1f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true, 0.1f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, 0.1f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f)); for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B