bypass logits when doing non-NONE pooling
parent 8093253b41
commit 5cc7b453a4

3 changed files with 17 additions and 40 deletions
examples/embedding/embedding.cpp

@@ -17,25 +17,10 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }
 
-static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
-    switch (pooling_type) {
-        case LLAMA_POOLING_TYPE_MEAN:
-        case LLAMA_POOLING_TYPE_NONE:
-            return true;
-        case LLAMA_POOLING_TYPE_CLS:
-            return pos == 0;
-        case LLAMA_POOLING_TYPE_LAST:
-            return pos == n_tokens - 1;
-        default:
-            GGML_ASSERT(false && "unsupported pooling type");
-    }
-}
-
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        bool logit = needs_logit(pooling_type, i, n_tokens);
-        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
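Note: the example no longer computes per-token logit flags. Every token is
flagged as an output, and (per the llama.cpp hunks below) llama_decode itself
bypasses the flags whenever a non-NONE pooling type is active. A minimal
sketch of the caller side after this change, assuming the common.h helpers
llama_batch_clear/llama_batch_add and a context created with embeddings
enabled:

    // sketch only, not part of the commit: embed one sequence and read back
    // the pooled vector; with MEAN/CLS/LAST pooling the per-token logit flag
    // passed below is effectively ignored
    #include <cstring>
    #include <vector>
    #include "common.h"
    #include "llama.h"

    static bool embed_seq(llama_context * ctx, llama_batch & batch,
                          const std::vector<int32_t> & tokens,
                          llama_seq_id seq_id, float * dst, int n_embd) {
        llama_batch_clear(batch);
        for (size_t i = 0; i < tokens.size(); i++) {
            llama_batch_add(batch, tokens[i], i, { seq_id }, true);
        }
        if (llama_decode(ctx, batch) != 0) {
            return false; // decode failed
        }
        // with non-NONE pooling there is exactly one embedding per sequence
        const float * emb = llama_get_embeddings_seq(ctx, seq_id);
        if (emb == NULL) {
            return false;
        }
        memcpy(dst, emb, n_embd * sizeof(float));
        return true;
    }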
@@ -192,7 +177,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s, pooling_type);
+        batch_add_seq(batch, inp, s);
         s += 1;
     }
 
examples/retrieval/retrieval.cpp

@@ -73,25 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
-    switch (pooling_type) {
-        case LLAMA_POOLING_TYPE_MEAN:
-        case LLAMA_POOLING_TYPE_NONE:
-            return true;
-        case LLAMA_POOLING_TYPE_CLS:
-            return pos == 0;
-        case LLAMA_POOLING_TYPE_LAST:
-            return pos == n_tokens - 1;
-        default:
-            GGML_ASSERT(false && "unsupported pooling type");
-    }
-}
-
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id, enum llama_pooling_type pooling_type) {
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        bool logit = needs_logit(pooling_type, i, n_tokens);
-        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
@@ -175,7 +160,12 @@ int main(int argc, char ** argv) {
 
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
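Note: retrieval now fails fast on NONE pooling because ranking needs exactly
one vector per chunk and one for the query; NONE would yield one vector per
token instead. Illustrative, self-contained sketch of the chunk-level score
this guard keeps well-defined (the example itself uses a common.h helper for
this):

    #include <cmath>

    // cosine similarity between two pooled embeddings of dimension n_embd;
    // only meaningful when each side is a single vector, hence the check above
    static float embd_similarity_cos(const float * a, const float * b, int n_embd) {
        float dot = 0.0f, na = 0.0f, nb = 0.0f;
        for (int i = 0; i < n_embd; i++) {
            dot += a[i]*b[i];
            na  += a[i]*a[i];
            nb  += b[i]*b[i];
        }
        return dot / (sqrtf(na)*sqrtf(nb) + 1e-8f);
    }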
@@ -247,7 +237,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s, pooling_type);
+        batch_add_seq(batch, inp, s);
         s += 1;
     }
 
@@ -270,7 +260,7 @@ int main(int argc, char ** argv) {
         std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);
 
         struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
-        batch_add_seq(query_batch, query_tokens, 0, pooling_type);
+        batch_add_seq(query_batch, query_tokens, 0);
 
         std::vector<float> query_emb(n_embd, 0);
         batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
llama.cpp

@@ -11779,7 +11779,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    if (!cparams.embeddings || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
         const int64_t n_tokens = batch.n_tokens;
 
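Note: the inp_out_ids gate now keys on causal attention rather than on
whether embeddings were requested. With non-NONE pooling every row feeds the
pooling op, so the graph builds no per-row output selection. A rough
restatement (not code from the commit):

    #include "llama.h"

    // when must the graph select individual output rows?
    // - causal (generative) runs: only the requested logit rows are kept
    // - NONE pooling: the caller wants raw per-token embeddings back
    // otherwise a pooling op consumes all rows and no selection is needed
    static bool needs_out_ids(bool causal_attn, enum llama_pooling_type pooling_type) {
        return causal_attn || pooling_type == LLAMA_POOLING_TYPE_NONE;
    }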
@@ -12166,11 +12166,13 @@ static int llama_decode_internal(
     std::vector<std::vector<llama_seq_id>> seq_id;
 
     // count outputs
-    if (batch_all.logits) {
+    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        n_outputs = n_tokens_all;
+    } else if (batch_all.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+    } else if (lctx.logits_all) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
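Note: this is the core of the bypass. When embeddings are enabled and pooling
is not NONE, the batch's per-token logit flags are ignored and every token
becomes an output, so pooling always sees the full sequence. The new
precedence, as a standalone restatement (not the committed code):

    #include <cstdint>
    #include "llama.h"

    // output count in llama_decode_internal after this change:
    // pooled-embedding runs output everything; otherwise the batch's explicit
    // flags win, then logits_all, then the default of "last token only"
    static uint32_t count_outputs(bool embeddings, enum llama_pooling_type pooling_type,
                                  const int8_t * logits, uint32_t n_tokens_all,
                                  bool logits_all) {
        if (embeddings && pooling_type != LLAMA_POOLING_TYPE_NONE) {
            return n_tokens_all;
        }
        if (logits) {
            uint32_t n = 0;
            for (uint32_t i = 0; i < n_tokens_all; ++i) {
                n += logits[i] != 0;
            }
            return n;
        }
        return logits_all ? n_tokens_all : 1; // keep last output only
    }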