create append_pooling operation; allow to specify attention_type; add last token pooling; update examples

This commit is contained in:
Douglas Hanley 2024-05-22 12:14:24 -05:00
parent f8ec8877b7
commit 010571490f
7 changed files with 175 additions and 74 deletions

View file

@ -542,6 +542,18 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
else { invalid_param = true; }
return true;
}
if (arg == "--attention") {
if (++i >= argc) {
invalid_param = true;
return true;
}
std::string value(argv[i]);
/**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; }
else { invalid_param = true; }
return true;
}
@ -1820,6 +1832,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "backend" });
options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
if (llama_supports_mlock()) {
options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" });
}
@ -2447,6 +2460,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type;
cparams.attention_type = params.attention_type;
cparams.defrag_thold = params.defrag_thold;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;

View file

@ -94,6 +94,7 @@ struct gpt_params {
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type
// // sampling parameters
struct llama_sampling_params sparams;

View file

@ -17,9 +17,25 @@ static std::vector<std::string> split_lines(const std::string & s) {
return lines;
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
switch (pooling_type) {
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_NONE:
return true;
case LLAMA_POOLING_TYPE_CLS:
return pos == 0;
case LLAMA_POOLING_TYPE_LAST:
return pos == n_tokens - 1;
default:
GGML_ASSERT(false && "unsupported pooling type");
}
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
int n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
bool logit = needs_logit(pooling_type, i, n_tokens);
llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
}
}
@ -40,13 +56,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
// try to get sequence embeddings - supported only when pooling_type is not NONE
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
continue;
}
}
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
float * out = output + batch.seq_id[i][0] * n_embd;
//TODO: I would also add a parameter here to enable normalization or not.
@ -97,6 +107,12 @@ int main(int argc, char ** argv) {
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
return 1;
}
if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, n_ctx);
@ -176,7 +192,7 @@ int main(int argc, char ** argv) {
}
// add to batch
batch_add_seq(batch, inp, s);
batch_add_seq(batch, inp, s, pooling_type);
s += 1;
}

View file

@ -44,7 +44,6 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_set_causal_attn(ctx, false);
// run model
llama_decode(ctx, batch);
@ -98,7 +97,6 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
llama_token eos_token = llama_token_eos(mdl);
llama_kv_cache_clear(ctx);
llama_set_causal_attn(ctx, true);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@ -166,9 +164,14 @@ int main(int argc, char * argv[]) {
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
// create new context - set to embedding mode
// create generation context
llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);
// create embedding context
cparams.embeddings = true;
llama_context * ctx = llama_new_context_with_model(mdl, cparams);
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);
// ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic
@ -186,8 +189,8 @@ int main(int argc, char * argv[]) {
};
// No need to add instruction for retrieval documents
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));
const int n_embd = llama_n_embd(mdl);
@ -206,10 +209,11 @@ int main(int argc, char * argv[]) {
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
std::string response = generate(ctx, prompt, true);
std::string response = generate(ctx_gen, prompt, true);
}
llama_free(ctx);
llama_free(ctx_gen);
llama_free(ctx_emb);
llama_free_model(mdl);
llama_backend_free();

View file

@ -73,9 +73,25 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
return chunks;
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
switch (pooling_type) {
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_NONE:
return true;
case LLAMA_POOLING_TYPE_CLS:
return pos == 0;
case LLAMA_POOLING_TYPE_LAST:
return pos == n_tokens - 1;
default:
GGML_ASSERT(false && "unsupported pooling type");
}
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
int n_tokens = tokens.size();
for (size_t i = 0; i < tokens.size(); i++) {
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
bool logit = needs_logit(pooling_type, i, n_tokens);
llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
}
}
@ -159,6 +175,7 @@ int main(int argc, char ** argv) {
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
@ -230,7 +247,7 @@ int main(int argc, char ** argv) {
}
// add to batch
batch_add_seq(batch, inp, s);
batch_add_seq(batch, inp, s, pooling_type);
s += 1;
}
@ -253,7 +270,7 @@ int main(int argc, char ** argv) {
std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);
struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
batch_add_seq(query_batch, query_tokens, 0);
batch_add_seq(query_batch, query_tokens, 0, pooling_type);
std::vector<float> query_emb(n_embd, 0);
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);

142
llama.cpp
View file

@ -7435,6 +7435,44 @@ struct llm_build_context {
return lctx.inp_s_seq;
}
struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
struct ggml_tensor * inp = gf->nodes[gf->n_nodes - 1];
if (strcmp(inp->name, "result_embd") != 0) {
inp = gf->nodes[gf->n_nodes - 2];
GGML_ASSERT(strcmp(inp->name, "result_norm") == 0 && "embeddings tensor not found");
}
struct ggml_tensor * cur;
switch (pooling_type) {
case LLAMA_POOLING_TYPE_MEAN:
{
struct ggml_tensor * inp_mean = build_inp_mean();
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
} break;
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
struct ggml_tensor * inp_cls = build_inp_cls();
cur = ggml_get_rows(ctx0, inp, inp_cls);
} break;
case LLAMA_POOLING_TYPE_NONE:
{
cur = inp;
} break;
default:
{
GGML_ASSERT(false && "unknown pooling type");
} break;
}
cb(cur, "result_embd_pooled", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@ -8415,8 +8453,6 @@ struct llm_build_context {
if (model.arch != LLM_ARCH_JINA_BERT_V2) {
inp_pos = build_inp_pos();
}
struct ggml_tensor * inp_mean = build_inp_mean();
struct ggml_tensor * inp_cls = build_inp_cls();
// construct input embeddings (token, type, position)
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@ -8591,28 +8627,6 @@ struct llm_build_context {
cur = inpL;
cb(cur, "result_embd", -1);
// pooling layer
switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// nop
} break;
case LLAMA_POOLING_TYPE_MEAN:
{
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
cb(cur, "result_embd_pooled", -1);
} break;
case LLAMA_POOLING_TYPE_CLS:
{
cur = ggml_get_rows(ctx0, cur, inp_cls);
cb(cur, "result_embd_pooled", -1);
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ASSERT(false && "Invalid pooling type");
} break;
}
ggml_build_forward_expand(gf, cur);
return gf;
@ -11697,6 +11711,11 @@ static struct ggml_cgraph * llama_build_graph(
GGML_ASSERT(false);
}
// add on pooling layer
if (lctx.cparams.embeddings) {
result = llm.append_pooling(result);
}
llm.free();
return result;
@ -11918,6 +11937,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(lctx.inp_cls);
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);
for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
const llama_pos pos = batch.pos[i];
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = i;
}
}
for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
}
}
}
if (kv_self.recurrent) {
const int64_t n_kv = kv_self.n;
@ -12245,30 +12295,13 @@ static int llama_decode_internal(
// no output
res = nullptr;
embd = nullptr;
} else if (!hparams.causal_attn) {
res = nullptr; // do not extract logits for embedding models such as BERT
// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
} else if (cparams.embeddings) {
// the embeddings could be in the second to last tensor, or any of the previous tensors
int i_embd = gf->n_nodes - 2;
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
i_embd = gf->n_nodes - i;
if (i_embd < 0) { break; }
embd = gf->nodes[i_embd];
}
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
// TODO: use a per-batch flag to know when to skip logits while keeping embeddings
if (!cparams.causal_attn) {
res = nullptr; // do not extract logits when not needed
// skip computing logits
// TODO: is this safe?
gf->n_nodes = i_embd + 1;
res = nullptr; // do not extract logits for embedding case
embd = gf->nodes[gf->n_nodes - 1];
if (strcmp(embd->name, "result_embd_pooled") != 0) {
embd = gf->nodes[gf->n_nodes - 2];
}
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
} else {
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
@ -12337,11 +12370,10 @@ static int llama_decode_internal(
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
// extract sequence embeddings
auto & embd_seq_out = lctx.embd_seq;
embd_seq_out.clear();
@ -15893,6 +15925,7 @@ struct llama_context_params llama_context_default_params() {
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.yarn_ext_factor =*/ -1.0f,
@ -16134,7 +16167,12 @@ struct llama_context * llama_new_context_with_model(
}
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
cparams.causal_attn = hparams.causal_attn;
if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
cparams.causal_attn = hparams.causal_attn;
} else {
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
}
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@ -16494,6 +16532,10 @@ enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
return ctx->cparams.pooling_type;
}
bool llama_causal_attn(const struct llama_context * ctx) {
return ctx->cparams.causal_attn;
}
int32_t llama_n_vocab(const struct llama_model * model) {
return model->hparams.n_vocab;
}

View file

@ -174,6 +174,13 @@ extern "C" {
LLAMA_POOLING_TYPE_NONE = 0,
LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2,
LLAMA_POOLING_TYPE_LAST = 3,
};
enum llama_attention_type {
LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
LLAMA_ATTENTION_TYPE_CAUSAL = 0,
LLAMA_ATTENTION_TYPE_NONCAUSAL = 1,
};
enum llama_split_mode {
@ -293,7 +300,7 @@ extern "C" {
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
// (ignored if no pooling layer)
enum llama_attention_type attention_type; // causal, non-causal, or unspecified
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model