llama : fix embeddings (#5796)
* llama : fix embeddings
* llama : do not use KV cache for non-causal models
* embeddings : fix llama_batch_init arg
* llama : add pooling switch
* llama : distinguish token vs sequence embeddings
* llama : assert pooling tensor
* llama : simplify causal mask condition
* llama : assert input batch with pooling enabled
* readme : update API changes list
Parent: e0843afe1b
Commit: 29ae62d2ae
7 changed files with 359 additions and 134 deletions
llama.cpp (357 changes)
@@ -1665,7 +1665,7 @@ struct llama_hparams {
 };
 
 struct llama_cparams {
-    uint32_t n_ctx;       // context size used during inference
+    uint32_t n_ctx;           // context size used during inference
     uint32_t n_batch;
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing
@@ -1682,7 +1682,9 @@ struct llama_cparams {
     float yarn_beta_slow;
     float defrag_thold;
 
+    bool embeddings;
     bool offload_kqv;
+
     enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
@@ -1972,7 +1974,7 @@ struct llama_context {
     int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
     int32_t n_eval   = 0; // number of eval calls
 
-    // decode output (2-dimensional array: [n_tokens][n_vocab])
+    // logits output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 #ifndef NDEBUG
     // guard against access to unset logits
@@ -1980,8 +1982,13 @@ struct llama_context {
 #endif
     bool logits_all = false;
 
-    // input embedding (1-dimensional array: [n_embd])
-    std::vector<float> embedding;
+    // embeddings output (2-dimensional array: [n_tokens][n_embd])
+    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
+    std::vector<float> embd;
+
+    // sequence embeddings output (map of [n_embd] vectors)
+    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
+    std::map<llama_seq_id, std::vector<float>> embd_seq;
 
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
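For orientation, the two containers back different public accessors; a minimal sketch of the intended mapping (the accessor declarations live in llama.h, and llama_get_embeddings_seq is added at the bottom of this diff):

    // pooling_type == LLAMA_POOLING_TYPE_NONE: per-token rows are stored in lctx.embd
    float * tok = llama_get_embeddings_ith(ctx, i);      // n_embd floats for token i

    // pooling_type == MEAN or CLS: one vector per sequence is stored in lctx.embd_seq
    float * seq = llama_get_embeddings_seq(ctx, seq_id); // n_embd floats, or nullptr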
@@ -5092,6 +5099,7 @@ static struct ggml_tensor * llm_build_kv(
     llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
 
     struct ggml_tensor * cur;
+
     cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
             q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
@@ -6085,6 +6093,7 @@ struct llm_build_context {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
         struct ggml_tensor * cur;
@@ -6092,9 +6101,10 @@ struct llm_build_context {
 
         // get input vectors with right size
         const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+
+        struct ggml_tensor * inp_pos  = ggml_view_1d(ctx0, lctx.inp_pos,  n_tokens, 0);
         struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0);
-        struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0);
+        struct ggml_tensor * inp_cls  = ggml_view_1d(ctx0, lctx.inp_cls,  n_tokens, 0);
 
         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
@@ -6112,39 +6122,38 @@ struct llm_build_context {
         cb(inpL, "inp_norm", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens]
+        struct ggml_tensor * KQ_mask = ggml_cont(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_tokens, n_tokens, n_tokens*ggml_type_size(lctx.inp_KQ_mask->type), 0));
+        cb(KQ_mask, "KQ_mask", -1); // [n_tokens, n_tokens]
 
         // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * cur = inpL;
 
+            struct ggml_tensor * Qcur;
+            struct ggml_tensor * Kcur;
+            struct ggml_tensor * Vcur;
+
             // self-attention
             if (model.arch == LLM_ARCH_BERT) {
-                struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+                Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+                Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
-                // seems like we just need to do this for Q?
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             } else {
                 // compute Q and K and RoPE them
                 cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
@@ -6163,13 +6172,41 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
+            }
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
-            }
+            struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+            struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            cb(kq, "kq", il);
+
+            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+            cb(kq, "kq_soft_max_ext", il);
+
+            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+            cb(v, "v", il);
+
+            struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
+            cb(kqv, "kqv", il);
+
+            struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cb(kqv_merged, "kqv_merged", il);
+
+            cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+            cb(cur, "kqv_merged_cont", il);
+
+            ggml_build_forward_expand(gf, cur);
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            if (model.layers[il].bo) {
+                cb(cur, "kqv_wo", il);
+            }
+
+            if (model.layers[il].bo) {
+                cur = ggml_add(ctx0, cur, model.layers[il].bo);
+            }
+            cb(cur, "kqv_out", il);
 
             // re-add the layer input
             cur = ggml_add(ctx0, cur, inpL);
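The non-causal path now builds attention explicitly instead of going through llm_build_kv. As a cross-check of the math the graph encodes, here is a minimal scalar sketch (plain C++, one head, no GQA or ALiBi; not ggml code) of out = softmax(K·Q/sqrt(d) + mask)·V:

    #include <cmath>
    #include <cstdio>

    // Minimal single-head attention over n tokens of dimension d:
    // out[j] = sum_i softmax_i(q[j]·k[i]/sqrt(d) + mask[j][i]) * v[i]
    int main() {
        const int n = 3, d = 2;
        float q[n][d] = {{1,0},{0,1},{1,1}};
        float k[n][d] = {{1,0},{0,1},{1,1}};
        float v[n][d] = {{1,2},{3,4},{5,6}};
        float mask[n][n] = {}; // all zeros = non-causal, every token sees every token

        for (int j = 0; j < n; ++j) {
            float w[n], sum = 0.0f;
            for (int i = 0; i < n; ++i) {
                float dot = 0.0f;
                for (int t = 0; t < d; ++t) dot += q[j][t]*k[i][t];
                w[i] = std::exp(dot/std::sqrt((float)d) + mask[j][i]);
                sum += w[i];
            }
            float out[d] = {};
            for (int i = 0; i < n; ++i)
                for (int t = 0; t < d; ++t) out[t] += (w[i]/sum)*v[i][t];
            printf("token %d: %.3f %.3f\n", j, out[0], out[1]);
        }
    }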
@@ -6209,16 +6246,29 @@ struct llm_build_context {
 
         // final output
         cur = inpL;
+        cb(cur, "result_embd", -1);
 
         // pooling layer
-        if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-            cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-        } else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
-            cur = ggml_get_rows(ctx0, cur, inp_cls);
-        } else {
-            GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
-        }
-        cb(cur, "result_embd", -1);
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // nop
+                } break;
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
+                    cb(cur, "result_embd_pooled", -1);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+                {
+                    cur = ggml_get_rows(ctx0, cur, inp_cls);
+                    cb(cur, "result_embd_pooled", -1);
+                } break;
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ASSERT(false && "Invalid pooling type");
+                } break;
+        }
 
         ggml_build_forward_expand(gf, cur);
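The MEAN case relies on inp_mean being an [n_tokens, n_tokens] weight matrix whose row s holds 1/count(s) at the token positions of sequence s (see the input setup further down in this diff). A toy numeric check of that matmul, under the assumed layout:

    #include <cstdio>

    // 3 tokens, 2 sequences: tokens 0,1 belong to seq 0; token 2 to seq 1.
    int main() {
        const int n_tokens = 3, n_embd = 2;
        float embd[n_tokens][n_embd] = {{1,2},{3,4},{10,20}};
        // inp_mean row s: weight 1/count(s) at columns of seq s, 0 elsewhere
        float inp_mean[n_tokens][n_tokens] = {
            {0.5f, 0.5f, 0.0f}, // seq 0: mean of tokens 0 and 1
            {0.0f, 0.0f, 1.0f}, // seq 1: just token 2
            {0.0f, 0.0f, 0.0f}, // unused row
        };

        for (int s = 0; s < n_tokens; ++s) {
            float pooled[n_embd] = {};
            for (int i = 0; i < n_tokens; ++i)
                for (int e = 0; e < n_embd; ++e) pooled[e] += inp_mean[s][i]*embd[i][e];
            printf("seq %d: %.1f %.1f\n", s, pooled[0], pooled[1]); // seq 0 -> 2.0 3.0
        }
    }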
@@ -7980,7 +8030,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    {
+    if (hparams.causal_attn) {
         const int64_t n_kv     = kv_self.n;
         const int64_t n_tokens = batch.n_tokens;
 
@@ -7995,16 +8045,40 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
                 for (int i = 0; i < n_kv; ++i) {
                     float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
-                        (hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
                         f = -INFINITY;
                     } else {
-                        f = 0;
+                        f = 0.0f;
                     }
                     data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                 }
             }
         }
-    }
+    } else {
+        // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
+        const int64_t n_tokens = batch.n_tokens;
+
+        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+
+        float * data = (float *) lctx.inp_KQ_mask->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_tokens; ++i) {
+                    float f = -INFINITY;
+                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                        if (batch.seq_id[i][s] == seq_id) {
+                            f = 0.0f;
+                            break;
+                        }
+                    }
+
+                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
+                }
+            }
+        }
+    }
 
     if (hparams.need_kq_pos) {
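In other words, the mask is block-diagonal over sequence ids: token j attends token i iff they share a sequence. A standalone sketch of what the loop above produces for a toy batch (assuming a single seq_id per token):

    #include <cstdio>

    int main() {
        const int n_tokens = 4;
        const int seq_id[n_tokens] = {0, 0, 1, 1}; // two sequences packed in one batch
        const float NEG_INF = -1e9f;               // stand-in for -INFINITY

        for (int j = 0; j < n_tokens; ++j) {
            for (int i = 0; i < n_tokens; ++i) {
                float f = (seq_id[i] == seq_id[j]) ? 0.0f : NEG_INF;
                printf("%6.0f ", f);
            }
            printf("\n"); // rows 0-1 see columns 0-1 only; rows 2-3 see columns 2-3 only
        }
    }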
@@ -8023,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
-        float * data = (float *) lctx.inp_mean->data;
 
+        float * data = (float *) lctx.inp_mean->data;
         memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
 
         std::vector<uint64_t> sum(n_tokens, 0);
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
+
             sum[seq_id] += 1;
         }
@@ -8051,11 +8128,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
 
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
 
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
-            const llama_pos pos = batch.pos[i];
+            const llama_pos    pos    = batch.pos[i];
+
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
+
             if (pos == 0) {
                 data[seq_id] = i;
             }
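So inp_cls ends up holding, for each sequence id, the batch index of that sequence's pos == 0 token, and the ggml_get_rows in the CLS pooling branch selects exactly those rows. An equivalent plain-C++ sketch (assuming one pos-0 token per sequence):

    #include <cstdio>

    int main() {
        const int n_tokens = 4, n_embd = 2;
        const int   pos[n_tokens] = {0, 1, 0, 1}; // token 0 starts seq 0, token 2 starts seq 1
        const int   seq[n_tokens] = {0, 0, 1, 1};
        const float embd[n_tokens][n_embd] = {{1,2},{3,4},{5,6},{7,8}};

        int inp_cls[n_tokens] = {}; // zero-initialized, as in the memset above
        for (int i = 0; i < n_tokens; ++i)
            if (pos[i] == 0) inp_cls[seq[i]] = i;

        for (int s = 0; s < 2; ++s) // "get_rows": pooled embedding of seq s = row inp_cls[s]
            printf("seq %d -> token %d: %.0f %.0f\n", s, inp_cls[s],
                   embd[inp_cls[s]][0], embd[inp_cls[s]][1]);
    }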
@@ -8169,24 +8251,27 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }
 
-    llama_kv_cache_update(&lctx);
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
 
-    // if we have enough unused cells before the current head ->
-    //   better to start searching from the beginning of the cache, hoping to fill it
-    if (kv_self.head > kv_self.used + 2*n_tokens) {
-        kv_self.head = 0;
-    }
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*n_tokens) {
+            kv_self.head = 0;
+        }
 
-    if (!llama_kv_cache_find_slot(kv_self, batch)) {
-        return 1;
-    }
+        if (!llama_kv_cache_find_slot(kv_self, batch)) {
+            return 1;
+        }
 
-    // a heuristic, to avoid attending the full cache if it is not yet utilized
-    // after enough generations, the benefit from this heuristic disappears
-    // if we start defragmenting the cache, the benefit from this will be more important
-    kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
-    //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        // a heuristic, to avoid attending the full cache if it is not yet utilized
+        // after enough generations, the benefit from this heuristic disappears
+        // if we start defragmenting the cache, the benefit from this will be more important
+        kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        //kv_self.n = llama_kv_cache_cell_max(kv_self);
+    }
 
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
     ggml_backend_sched_reset(lctx.sched);
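A practical consequence: since non-causal models never read the KV cache, tokens can only attend within the current llama_batch, so every sequence to be embedded must be packed into a single batch. A caller-side sketch using the batch API of this revision (tokens_doc0 and tokens_doc1 are assumed pre-tokenized inputs; error handling omitted):

    // Two short "documents" in one batch, one sequence id each.
    llama_batch batch = llama_batch_init(/*n_tokens*/ 512, /*embd*/ 0, /*n_seq_max*/ 1);

    auto push = [&](const std::vector<llama_token> & toks, llama_seq_id sid) {
        for (size_t i = 0; i < toks.size(); ++i) {
            const int j = batch.n_tokens++;
            batch.token   [j]    = toks[i];
            batch.pos     [j]    = (llama_pos) i;
            batch.n_seq_id[j]    = 1;
            batch.seq_id  [j][0] = sid;
            batch.logits  [j]    = 0;
        }
    };
    push(tokens_doc0, 0);
    push(tokens_doc1, 1);

    llama_decode(ctx, batch); // one call computes both pooled embeddings
    llama_batch_free(batch);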
@@ -8195,20 +8280,26 @@ static int llama_decode_internal(
     ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
 
     // the output is always the last tensor in the graph
-    struct ggml_tensor * res        = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+    struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
 
-    if (strcmp(res->name, "result_output") == 0) {
-        // the embeddings could be the second to last tensor, or the third to last tensor
-        if (strcmp(embeddings->name, "result_norm") != 0) {
-            embeddings = gf->nodes[gf->n_nodes - 3];
-            GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
-        }
-    } else if (strcmp(res->name, "result_embd") == 0) {
-        embeddings = res;
-        res = nullptr;
+    if (!hparams.causal_attn) {
+        res = nullptr; // do not extract logits for embedding models such as BERT
+
+        // token or sequence embeddings
+        embd = gf->nodes[gf->n_nodes - 1];
+
+        GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
     } else {
-        GGML_ASSERT(false);
+        if (strcmp(res->name, "result_output") == 0) {
+            // the token embeddings could be the second to last tensor, or the third to last tensor
+            if (strcmp(embd->name, "result_norm") != 0) {
+                embd = gf->nodes[gf->n_nodes - 3];
+                GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+            }
+        } else {
+            GGML_ASSERT(false && "missing result_output tensor");
+        }
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -8275,46 +8366,82 @@ static int llama_decode_internal(
         logits_out.clear();
 #endif
 
-        ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res);
-        GGML_ASSERT(res_backend != nullptr);
+        ggml_backend_t backend_res = ggml_backend_sched_get_node_backend(lctx.sched, res);
+        GGML_ASSERT(backend_res != nullptr);
+
         if (batch.logits) {
             logits_out.resize(n_vocab * n_tokens);
             for (uint32_t i = 0; i < n_tokens; i++) {
                 if (batch.logits[i] == 0) {
                     continue;
                 }
-                ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
+                ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
                 logits_valid[i] = true;
 #endif
             }
         } else if (lctx.logits_all) {
             logits_out.resize(n_vocab * n_tokens);
-            ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
+            ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
 #ifndef NDEBUG
             std::fill(logits_valid.begin(), logits_valid.end(), true);
 #endif
         } else {
             logits_out.resize(n_vocab);
-            ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
+            ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
             logits_valid[0] = true;
 #endif
         }
-        ggml_backend_synchronize(res_backend);
+        ggml_backend_synchronize(backend_res);
     }
 
     // extract embeddings
-    if (!lctx.embedding.empty()) {
-        auto & embedding_out = lctx.embedding;
+    if (cparams.embeddings && embd) {
+        ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
+        GGML_ASSERT(backend_embd != nullptr);
 
-        const int64_t embd_pos  = res ? n_embd * (n_tokens-1) : 0;
-        const int64_t embd_size = res ? n_embd : n_embd * n_tokens;
-
-        embedding_out.resize(embd_size);
-        ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float));
-        ggml_backend_synchronize(embeddings_backend);
+        switch (cparams.pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // extract token embeddings
+                    auto & embd_out = lctx.embd;
+
+                    if (batch.logits) {
+                        embd_out.resize(n_embd * n_tokens);
+                        for (uint32_t i = 0; i < n_tokens; i++) {
+                            if (batch.logits[i] == 0) {
+                                continue;
+                            }
+
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    }
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
+
+                    // extract sequence embeddings
+                    auto & embd_seq_out = lctx.embd_seq;
+                    embd_seq_out.clear();
+
+                    for (uint32_t i = 0; i < n_tokens; i++) {
+                        const llama_seq_id seq_id = batch.seq_id[i][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(n_embd);
+                        ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                    }
+                } break;
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ASSERT(false && "unknown pooling type");
+                } break;
+        }
+        ggml_backend_synchronize(backend_embd);
    }
 
     // measure the performance only for the single-token evals
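Note that in the LLAMA_POOLING_TYPE_NONE branch above, batch.logits doubles as the per-token "extract embeddings" flag, mirroring the logits path. A caller that wants only one unpooled embedding per batch could mark just the final token, e.g.:

    // Assumed caller-side sketch: request the embedding of the final token only.
    for (int j = 0; j < batch.n_tokens; ++j) {
        batch.logits[j] = (j == batch.n_tokens - 1) ? 1 : 0;
    }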
@@ -8608,19 +8735,19 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
     GGML_ASSERT(llama_is_byte_token(vocab, id));
     const auto& token_data = vocab.id_to_token.at(id);
     switch (llama_vocab_get_type(vocab)) {
-    case LLAMA_VOCAB_TYPE_SPM: {
-        auto buf = token_data.text.substr(3, 2);
-        return strtol(buf.c_str(), NULL, 16);
-    }
-    case LLAMA_VOCAB_TYPE_BPE: {
-        GGML_ASSERT(false);
-        return unicode_to_bytes_bpe(token_data.text);
-    }
-    case LLAMA_VOCAB_TYPE_WPM: {
-        GGML_ASSERT(false);
-    }
-    default:
-        GGML_ASSERT(false);
+        case LLAMA_VOCAB_TYPE_SPM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ASSERT(false);
+            return unicode_to_bytes_bpe(token_data.text);
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ASSERT(false);
+        }
+        default:
+            GGML_ASSERT(false);
     }
 }
@@ -11864,7 +11991,7 @@ struct llama_context_params llama_context_default_params() {
         /*.type_k                      =*/ GGML_TYPE_F16,
         /*.type_v                      =*/ GGML_TYPE_F16,
         /*.logits_all                  =*/ false,
-        /*.embedding                   =*/ false,
+        /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
@@ -12015,6 +12142,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_fast   = params.yarn_beta_fast;
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.defrag_thold     = params.defrag_thold;
+    cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.pooling_type     = params.pooling_type;
 
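On the caller side, this wiring means embeddings are enabled per context through the renamed field; a minimal sketch (the enum values are the ones used throughout this diff):

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // was `embedding` before this commit
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // or _CLS, or _NONE for per-token output

    llama_context * ctx = llama_new_context_with_model(model, cparams);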
@@ -12192,8 +12320,8 @@ struct llama_context * llama_new_context_with_model(
             // resized during inference, reserve maximum
             ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
-            if (params.embedding) {
-                ctx->embedding.resize(hparams.n_embd);
+            if (params.embeddings) {
+                ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
             }
 
             // graph inputs
@@ -12628,7 +12756,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     // assume worst case for logits although only currently set ones are serialized
     const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
-    const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
+    const size_t s_embedding       = ctx->embd.capacity() * sizeof(float);
     const size_t s_kv_buf_size     = sizeof(size_t);
     const size_t s_kv_head         = sizeof(uint32_t);
     const size_t s_kv_size         = sizeof(uint32_t);
@@ -12737,12 +12865,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
     // copy embeddings
     {
-        const size_t embedding_size = ctx->embedding.size();
+        const size_t embeddings_size = ctx->embd.size();
 
-        data_ctx->write(&embedding_size, sizeof(embedding_size));
+        data_ctx->write(&embeddings_size, sizeof(embeddings_size));
 
-        if (embedding_size) {
-            data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float));
+        if (embeddings_size) {
+            data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
         }
     }
 
@@ -12846,15 +12974,17 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
     // set embeddings
     {
-        size_t embedding_size;
+        size_t embeddings_size;
 
-        memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size);
+        memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
 
-        GGML_ASSERT(ctx->embedding.capacity() == embedding_size);
+        GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
 
-        if (embedding_size) {
-            memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float));
-            inp += embedding_size * sizeof(float);
+        if (embeddings_size) {
+            ctx->embd.resize(embeddings_size);
+
+            memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
+            inp += embeddings_size * sizeof(float);
         }
     }
 
@@ -13104,11 +13234,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
-    return ctx->embedding.data();
+    return ctx->embd.data();
 }
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    return ctx->embedding.data() + i*ctx->model.hparams.n_embd;
+    return ctx->embd.data() + i*ctx->model.hparams.n_embd;
 }
 
+float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    auto it = ctx->embd_seq.find(seq_id);
+    if (it == ctx->embd_seq.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
+}
+
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
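Putting the new accessors together, retrieval after a successful llama_decode looks roughly like this (a sketch; llama_n_embd gives the vector length):

    if (llama_decode(ctx, batch) == 0) {
        const int n_embd = llama_n_embd(model);

        // pooled models (MEAN/CLS): one vector per sequence
        if (const float * e = llama_get_embeddings_seq(ctx, /*seq_id*/ 0)) {
            for (int k = 0; k < n_embd; ++k) { /* use e[k] */ }
        }

        // unpooled (NONE): one vector per token that had batch.logits[i] set
        // const float * t = llama_get_embeddings_ith(ctx, i);
    }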