feat: proper KQ_pos for Jina embeddings
parent d9b8dd667d
commit 14073a2caf

2 changed files with 19 additions and 9 deletions

ggml.c (9 changed lines)

@@ -5469,7 +5469,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
     }
 
     if (pos) {
-        GGML_ASSERT(ggml_is_vector(pos));
+        GGML_ASSERT(ggml_is_vector(pos) || ggml_is_matrix(pos));
         GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
         GGML_ASSERT(pos->ne[0] == a->ne[0]);
     }

@@ -12401,6 +12401,7 @@ static void ggml_compute_forward_soft_max_f32(
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
     // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
+    const bool is_pos_matrix = src2 ? ggml_is_matrix(src2): false;
     ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
     float * pos_f32 = src2 ? (float *) src2->data : src0->data;
 

@@ -12435,13 +12436,11 @@ static void ggml_compute_forward_soft_max_f32(
 
             if (use_f16) {
                 for (int i = 0; i < nc; ++i) {
-                    //wp[i] -= slope*GGML_FP16_TO_FP32(pos_f16[i]);
-                    wp[i] -= slope*abs(i1%nc - i);
+                    wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[is_pos_matrix ? i1%nc * nc + i: i]);
                 }
             } else {
                 for (int i = 0; i < nc; ++i) {
-                    //wp[i] -= slope*pos_f32[i];
-                    wp[i] -= slope*abs(i1%nc - i);
+                    wp[i] += slope*pos_f32[is_pos_matrix ? i1%nc * nc + i: i];
                 }
             }
         }
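
Reading the soft_max change: pos may now be either a bias vector of length nc or an nc x nc matrix. In the matrix case, score row i1 reads row i1 % nc of pos, so every query row gets its own bias row; in the vector case the indexing collapses to the old per-column bias. A minimal plain-C sketch of just that indexing follows; nc, slope, i1 and the buffers are stand-ins for the locals of ggml_compute_forward_soft_max_f32, not ggml code:

#include <stdio.h>
#include <stdlib.h>

/* sketch of the bias step above: add a slope-scaled positional bias to one row of scores;
   pos is either a vector of length nc or an nc*nc matrix holding one bias row per score row */
static void add_pos_bias(float * wp, const float * pos, int nc, int i1, float slope, int is_pos_matrix) {
    for (int i = 0; i < nc; ++i) {
        wp[i] += slope*pos[is_pos_matrix ? i1%nc * nc + i : i];
    }
}

int main(void) {
    enum { NC = 4 };
    float wp[NC]  = {0};
    float pos[NC*NC];

    /* fill pos the way llama_set_inputs does below for the matrix case: -|i - j| */
    for (int i = 0; i < NC; ++i) {
        for (int j = 0; j < NC; ++j) {
            pos[i*NC + j] = -1.0f*abs(i - j);
        }
    }

    add_pos_bias(wp, pos, NC, /*i1=*/2, /*slope=*/0.5f, /*is_pos_matrix=*/1);
    for (int i = 0; i < NC; ++i) {
        printf("%5.2f ", wp[i]);   /* expected: -1.00 -0.50  0.00 -0.50 */
    }
    printf("\n");
    return 0;
}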

llama.cpp (19 changed lines)

@@ -6998,7 +6998,7 @@ struct llm_build_context {
         } else {
             // TODO: this will be needed for ALiBi-based BERT models
             // https://github.com/ggerganov/llama.cpp/pull/6826
-            lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens);
+            lctx.inp_KQ_pos = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
         }
         cb(lctx.inp_KQ_pos, "KQ_pos", -1);
         ggml_set_input(lctx.inp_KQ_pos);
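
The build-time change above allocates KQ_pos as an n_tokens x n_tokens F32 tensor instead of a 1-d tensor of length n_tokens, which is what the relaxed ggml_is_matrix(pos) assert in ggml_soft_max_impl now accepts. A standalone sketch of that allocation, assuming the plain ggml C API (ggml_init, ggml_new_tensor_2d, ggml_free); the context size and n_tokens are illustrative, this is not llama.cpp code:

#include "ggml.h"
#include <stdio.h>

int main(void) {
    /* tiny CPU context, large enough for one small tensor (illustrative size) */
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_tokens = 8;

    /* n_tokens x n_tokens, replacing the old 1-d tensor of length n_tokens */
    struct ggml_tensor * kq_pos = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_tokens, n_tokens);

    /* ne[0] == n_tokens is what satisfies pos->ne[0] == a->ne[0] in ggml_soft_max_impl */
    printf("KQ_pos: %lld x %lld\n", (long long) kq_pos->ne[0], (long long) kq_pos->ne[1]);

    ggml_free(ctx);
    return 0;
}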

@@ -11166,11 +11166,22 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
         GGML_ASSERT(lctx.inp_KQ_pos);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
-
-        float * data = (float *) lctx.inp_KQ_pos->data;
-
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = float(lctx.kv_self.cells[i].pos);
-        }
+        GGML_ASSERT(ggml_is_vector(lctx.inp_KQ_pos) || ggml_is_matrix(lctx.inp_KQ_pos));
+        if (ggml_is_vector(lctx.inp_KQ_pos)) {
+            float * data = (float *) lctx.inp_KQ_pos->data;
+
+            for (int i = 0; i < n_kv; ++i) {
+                data[i] = float(lctx.kv_self.cells[i].pos);
+            }
+        } else if(ggml_is_matrix(lctx.inp_KQ_pos)) {
+            const int64_t n_tokens = batch.n_tokens;
+            float * data = (float *) lctx.inp_KQ_pos->data;
+
+            for (int i = 0; i < n_tokens; ++i) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    data[i * n_tokens + j] = -1.0 * abs(i - j);
+                }
+            }
+        }
     }
 
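
Net effect of the llama.cpp change: in the branch used for ALiBi-based BERT/Jina embedding models, inp_KQ_pos no longer holds absolute KV positions but the relative-distance matrix -|i - j|, i.e. the ALiBi bias pattern before the per-head slope is applied inside soft_max. Below is a sketch of the resulting bias for one head; the slope formula 2^(-8(h+1)/n_head) is the standard ALiBi choice and an assumption for illustration, not code taken from this commit:

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    const int n_tokens = 5;
    const int n_head   = 8;
    const int h        = 0;                        /* head index, illustrative */

    /* standard ALiBi slope for head h (assumes n_head is a power of two) */
    const float slope = powf(2.0f, -8.0f*(h + 1)/n_head);

    /* same fill as llama_set_inputs above: entry (i, j) is -|i - j| */
    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_tokens; ++j) {
            const float kq_pos = -1.0f*abs(i - j);
            printf("%6.2f ", slope*kq_pos);        /* bias added to KQ before softmax */
        }
        printf("\n");
    }
    return 0;
}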