diff --git a/ggml.c b/ggml.c index 792471725..b535b1bc3 100644 --- a/ggml.c +++ b/ggml.c @@ -5469,7 +5469,7 @@ static struct ggml_tensor * ggml_soft_max_impl( } if (pos) { - GGML_ASSERT(ggml_is_vector(pos)); + GGML_ASSERT(ggml_is_vector(pos) || ggml_is_matrix(pos)); GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32); GGML_ASSERT(pos->ne[0] == a->ne[0]); } @@ -12401,6 +12401,7 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching + const bool is_pos_matrix = src2 ? ggml_is_matrix(src2): false; ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data; float * pos_f32 = src2 ? (float *) src2->data : src0->data; @@ -12435,13 +12436,11 @@ static void ggml_compute_forward_soft_max_f32( if (use_f16) { for (int i = 0; i < nc; ++i) { - //wp[i] -= slope*GGML_FP16_TO_FP32(pos_f16[i]); - wp[i] -= slope*abs(i1%nc - i); + wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[is_pos_matrix ? i1%nc * nc + i: i]); } } else { for (int i = 0; i < nc; ++i) { - //wp[i] -= slope*pos_f32[i]; - wp[i] -= slope*abs(i1%nc - i); + wp[i] += slope*pos_f32[is_pos_matrix ? i1%nc * nc + i: i]; } } } diff --git a/llama.cpp b/llama.cpp index 9ee5be17c..634f26b10 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6998,7 +6998,7 @@ struct llm_build_context { } else { // TODO: this will be needed for ALiBi-based BERT models // https://github.com/ggerganov/llama.cpp/pull/6826 - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens); + lctx.inp_KQ_pos = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); } cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); @@ -11166,11 +11166,22 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { GGML_ASSERT(lctx.inp_KQ_pos); GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer)); + GGML_ASSERT(ggml_is_vector(lctx.inp_KQ_pos) || ggml_is_matrix(lctx.inp_KQ_pos)); + if (ggml_is_vector(lctx.inp_KQ_pos)) { + float * data = (float *) lctx.inp_KQ_pos->data; - float * data = (float *) lctx.inp_KQ_pos->data; + for (int i = 0; i < n_kv; ++i) { + data[i] = float(lctx.kv_self.cells[i].pos); + } + } else if(ggml_is_matrix(lctx.inp_KQ_pos)) { + const int64_t n_tokens = batch.n_tokens; + float * data = (float *) lctx.inp_KQ_pos->data; - for (int i = 0; i < n_kv; ++i) { - data[i] = float(lctx.kv_self.cells[i].pos); + for (int i = 0; i < n_tokens; ++i) { + for (int j = 0; j < n_tokens; ++j) { + data[i * n_tokens + j] = -1.0 * abs(i - j); + } + } } }