bypass logits when doing non-NONE pooling
commit 5cc7b453a4 (parent 8093253b41)
3 changed files with 17 additions and 40 deletions
@@ -17,25 +17,10 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }
 
-static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
-    switch (pooling_type) {
-        case LLAMA_POOLING_TYPE_MEAN:
-        case LLAMA_POOLING_TYPE_NONE:
-            return true;
-        case LLAMA_POOLING_TYPE_CLS:
-            return pos == 0;
-        case LLAMA_POOLING_TYPE_LAST:
-            return pos == n_tokens - 1;
-        default:
-            GGML_ASSERT(false && "unsupported pooling type");
-    }
-}
-
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        bool logit = needs_logit(pooling_type, i, n_tokens);
-        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
@@ -192,7 +177,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s, pooling_type);
+        batch_add_seq(batch, inp, s);
         s += 1;
     }
 
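Judging by the hunk context (split_lines, batch_add_seq), the two hunks above are from the embedding example. Because the context now decides the output count itself whenever a non-NONE pooling type is active (see the llama_decode_internal hunk further down), the example no longer needs the per-token needs_logit helper and simply flags every token. A minimal caller-side sketch of what this enables; it assumes a context created with embeddings enabled and a non-NONE pooling type, plus the common.h helper llama_batch_add used throughout the examples:

#include <vector>

#include "common.h" // llama_batch_add helper used by the examples
#include "llama.h"

// Return the pooled embedding of a single sequence (sketch only; the function
// name embed_sequence is an assumption, not code from this commit).
static std::vector<float> embed_sequence(llama_context * ctx, const std::vector<llama_token> & tokens) {
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);

    for (size_t i = 0; i < tokens.size(); i++) {
        // the per-token logits flag no longer matters for pooled embeddings,
        // so every token can simply be marked as an output
        llama_batch_add(batch, tokens[i], i, { 0 }, true);
    }

    std::vector<float> out;
    if (llama_decode(ctx, batch) == 0) {
        const int     n_embd = llama_n_embd(llama_get_model(ctx));
        const float * emb    = llama_get_embeddings_seq(ctx, 0); // pooled result for seq_id 0
        out.assign(emb, emb + n_embd);
    }

    llama_batch_free(batch);
    return out;
}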
@@ -73,25 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
-    switch (pooling_type) {
-        case LLAMA_POOLING_TYPE_MEAN:
-        case LLAMA_POOLING_TYPE_NONE:
-            return true;
-        case LLAMA_POOLING_TYPE_CLS:
-            return pos == 0;
-        case LLAMA_POOLING_TYPE_LAST:
-            return pos == n_tokens - 1;
-        default:
-            GGML_ASSERT(false && "unsupported pooling type");
-    }
-}
-
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        bool logit = needs_logit(pooling_type, i, n_tokens);
-        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
@@ -175,7 +160,12 @@ int main(int argc, char ** argv) {
 
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
 
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
 
@@ -247,7 +237,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s, pooling_type);
+        batch_add_seq(batch, inp, s);
         s += 1;
     }
 
@@ -270,7 +260,7 @@ int main(int argc, char ** argv) {
     std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);
 
     struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
-    batch_add_seq(query_batch, query_tokens, 0, pooling_type);
+    batch_add_seq(query_batch, query_tokens, 0);
 
     std::vector<float> query_emb(n_embd, 0);
     batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
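The hunks above appear to be the retrieval example: the same batch_add_seq simplification, plus a new up-front rejection of LLAMA_POOLING_TYPE_NONE, since retrieval needs one pooled embedding per sequence. A hedged sketch of how a caller can satisfy that guard by requesting a pooling type in the context parameters; the field names are assumed from llama.h of this period, and the fragment is assumed to sit inside main() after the model has been loaded:

    // request embeddings with an explicit, non-NONE pooling type
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                     // ask for embedding output
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;  // or CLS / LAST, anything but NONE

    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // mirrors the guard added in the hunk above
    if (llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE) {
        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
        return 1;
    }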
@@ -11779,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
 
@@ -12166,11 +12166,13 @@ static int llama_decode_internal(
     std::vector<std::vector<llama_seq_id>> seq_id;
 
     // count outputs
-    if (batch_all.logits) {
+    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        n_outputs = n_tokens_all;
+    } else if (batch_all.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+    } else if (lctx.logits_all) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
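The core change is in llama_decode_internal: when embeddings are enabled with a non-NONE pooling type, the output count is now forced to the full batch size before the per-token logits flags are even consulted, so every token feeds the pooling operator and callers no longer have to set those flags themselves. The same decision, restated as a standalone helper for clarity (a sketch only; the name count_outputs and the flattened parameters are assumptions, not repo code):

#include <cstdint>

#include "llama.h"

// Output-count precedence after this commit:
//   1. pooled embeddings -> every token is an output
//   2. per-token logits flags -> only the flagged tokens
//   3. logits_all -> every token
//   4. otherwise -> keep the last output only
static uint32_t count_outputs(bool embeddings, enum llama_pooling_type pooling_type,
                              const int8_t * logits, uint32_t n_tokens, bool logits_all) {
    if (embeddings && pooling_type != LLAMA_POOLING_TYPE_NONE) {
        return n_tokens;             // pooled embeddings bypass the logits flags entirely
    }
    if (logits) {
        uint32_t n = 0;
        for (uint32_t i = 0; i < n_tokens; ++i) {
            n += logits[i] != 0;     // count only the tokens the caller flagged
        }
        return n;
    }
    if (logits_all) {
        return n_tokens;
    }
    return 1;                        // keep last output only
}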