diff --git a/ggml.c b/ggml.c index be25a5220..792471725 100644 --- a/ggml.c +++ b/ggml.c @@ -5478,9 +5478,9 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(pos->type == mask->type); } - /*if (max_bias > 0.0f) { + if (max_bias > 0.0f) { GGML_ASSERT(pos); - }*/ + } bool is_node = false; @@ -12401,7 +12401,6 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - //float * pos = src2 ? (float *) src2->data : NULL; ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data; float * pos_f32 = src2 ? (float *) src2->data : src0->data; @@ -12436,13 +12435,13 @@ static void ggml_compute_forward_soft_max_f32( if (use_f16) { for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]); - //wp[i] = wp[i] - slope*abs(i1%nc - i); + //wp[i] -= slope*GGML_FP16_TO_FP32(pos_f16[i]); + wp[i] -= slope*abs(i1%nc - i); } } else { for (int i = 0; i < nc; ++i) { - wp[i] += slope*pos_f32[i]; - //wp[i] = wp[i] - slope*abs(i1%nc - i); + //wp[i] -= slope*pos_f32[i]; + wp[i] -= slope*abs(i1%nc - i); } } } diff --git a/llama.cpp b/llama.cpp index e212fecb0..9ee5be17c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8254,6 +8254,9 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false); + // positions of the tokens in the KV cache + struct ggml_tensor * KQ_pos = build_inp_KQ_pos(false); + // iterate layers for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * cur = inpL; @@ -8322,7 +8325,7 @@ struct llm_build_context { struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); cb(kq, "kq", il); - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, KQ_pos, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); @@ -11523,7 +11526,7 @@ static int llama_decode_internal( } // non-causal masks do not use the KV cache - if (hparams.causal_attn) { + if (hparams.causal_attn || model.arch == LLM_ARCH_JINA_BERT_V2) { llama_kv_cache_update(&lctx); // if we have enough unused cells before the current head ->