bypass logits when doing non-NONE pooling

Douglas Hanley 2024-06-17 00:24:33 -06:00
parent 8093253b41
commit 5cc7b453a4
3 changed files with 17 additions and 40 deletions
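In short: the embedding and retrieval examples no longer decide per-token outputs themselves. batch_add_seq drops its pooling_type parameter and simply requests an output for every token, and llama.cpp counts all tokens as outputs whenever embeddings are enabled with a pooling type other than NONE. The simplified helper, as it reads after this commit:

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
    size_t n_tokens = tokens.size();
    for (size_t i = 0; i < n_tokens; i++) {
        // every position is flagged as an output; pooling-aware selection
        // now happens inside llama_decode_internal
        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
    }
}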

examples/embedding/embedding.cpp

@@ -17,25 +17,10 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }
 
-static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
-    switch (pooling_type) {
-        case LLAMA_POOLING_TYPE_MEAN:
-        case LLAMA_POOLING_TYPE_NONE:
-            return true;
-        case LLAMA_POOLING_TYPE_CLS:
-            return pos == 0;
-        case LLAMA_POOLING_TYPE_LAST:
-            return pos == n_tokens - 1;
-        default:
-            GGML_ASSERT(false && "unsupported pooling type");
-    }
-}
-
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        bool logit = needs_logit(pooling_type, i, n_tokens);
-        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
@@ -192,7 +177,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s, pooling_type);
+        batch_add_seq(batch, inp, s);
         s += 1;
     }

examples/retrieval/retrieval.cpp

@@ -73,25 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_size
     return chunks;
 }
 
-static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
-    switch (pooling_type) {
-        case LLAMA_POOLING_TYPE_MEAN:
-        case LLAMA_POOLING_TYPE_NONE:
-            return true;
-        case LLAMA_POOLING_TYPE_CLS:
-            return pos == 0;
-        case LLAMA_POOLING_TYPE_LAST:
-            return pos == n_tokens - 1;
-        default:
-            GGML_ASSERT(false && "unsupported pooling type");
-    }
-}
-
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        bool logit = needs_logit(pooling_type, i, n_tokens);
-        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
@@ -175,7 +160,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
@@ -247,7 +237,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s, pooling_type);
+        batch_add_seq(batch, inp, s);
         s += 1;
     }
@@ -270,7 +260,7 @@ int main(int argc, char ** argv) {
     std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);
 
     struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
-    batch_add_seq(query_batch, query_tokens, 0, pooling_type);
+    batch_add_seq(query_batch, query_tokens, 0);
 
     std::vector<float> query_emb(n_embd, 0);
     batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);

llama.cpp

@@ -11779,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
 
         const int64_t n_tokens = batch.n_tokens;
@@ -12166,11 +12166,13 @@ static int llama_decode_internal(
     std::vector<std::vector<llama_seq_id>> seq_id;
 
     // count outputs
-    if (batch_all.logits) {
+    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        n_outputs = n_tokens_all;
+    } else if (batch_all.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+    } else if (lctx.logits_all) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
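
Taken together with the llama_set_inputs change above, the output count in llama_decode_internal now resolves roughly as in the condensed sketch below; the final branch is the pre-existing "keep last output only" path, assumed here to set n_outputs = 1 as in the surrounding (unchanged) code:

if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
    n_outputs = n_tokens_all;                  // pooled embeddings: one output per token
} else if (batch_all.logits) {
    for (uint32_t i = 0; i < n_tokens_all; ++i) {
        n_outputs += batch_all.logits[i] != 0; // honor per-token logit flags
    }
} else if (lctx.logits_all) {
    n_outputs = n_tokens_all;
} else {
    n_outputs = 1;                             // keep last output only
}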