create append_pooling operation; allow to specify attention_type; add last token pooling; update examples
commit 010571490f
parent f8ec8877b7

7 changed files with 175 additions and 74 deletions
@@ -542,6 +542,18 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
         else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
         else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+        else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
         else { invalid_param = true; }
         return true;
     }
+    if (arg == "--attention") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        std::string value(argv[i]);
+        /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
+        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; }
+        else { invalid_param = true; }
+        return true;
+    }
@@ -1820,6 +1832,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "backend" });
     options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
 
     if (llama_supports_mlock()) {
         options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" });
     }
@@ -2447,6 +2460,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.yarn_orig_ctx = params.yarn_orig_ctx;
     cparams.pooling_type = params.pooling_type;
+    cparams.attention_type = params.attention_type;
     cparams.defrag_thold = params.defrag_thold;
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
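The hunks above route the new `--pooling last` / `--attention {causal,non-causal}` flags into `llama_context_params`. A minimal sketch of the same configuration done directly against the core API (a hypothetical helper; `pooling_type` and `attention_type` are the fields touched by this commit):

```cpp
#include "llama.h"

// Hypothetical helper: build context params for last-token pooling with non-causal attention.
static llama_context_params embedding_cparams() {
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings     = true;                            // ask for embeddings rather than logits
    cparams.pooling_type   = LLAMA_POOLING_TYPE_LAST;         // new pooling mode added here
    cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;  // new attention override added here
    return cparams;
}
```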
@@ -94,6 +94,7 @@ struct gpt_params {
     enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+    enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type
 
     // // sampling parameters
     struct llama_sampling_params sparams;
@@ -17,9 +17,25 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
-    for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
+    switch (pooling_type) {
+        case LLAMA_POOLING_TYPE_MEAN:
+        case LLAMA_POOLING_TYPE_NONE:
+            return true;
+        case LLAMA_POOLING_TYPE_CLS:
+            return pos == 0;
+        case LLAMA_POOLING_TYPE_LAST:
+            return pos == n_tokens - 1;
+        default:
+            GGML_ASSERT(false && "unsupported pooling type");
+    }
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
+    int n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        bool logit = needs_logit(pooling_type, i, n_tokens);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
     }
 }
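The `needs_logit` helper decides which positions of a sequence must request an output row; only those rows are later readable as embeddings. A standalone sketch of the resulting policy (the local enum merely mirrors the `LLAMA_POOLING_TYPE_*` values, no llama.cpp dependency):

```cpp
#include <cstdio>

// Standalone illustration of the needs_logit policy introduced above:
// which positions of an n-token sequence request an output row.
enum pooling { NONE, MEAN, CLS, LAST };  // mirrors LLAMA_POOLING_TYPE_*

static bool needs_output(pooling p, int pos, int n_tokens) {
    switch (p) {
        case NONE:
        case MEAN: return true;                 // every position is needed
        case CLS:  return pos == 0;             // only the first token
        case LAST: return pos == n_tokens - 1;  // only the final token
    }
    return false;
}

int main() {
    const int n = 4;
    const pooling types[] = { NONE, MEAN, CLS, LAST };
    const char * names[]  = { "none", "mean", "cls", "last" };
    for (int t = 0; t < 4; ++t) {
        printf("%-4s:", names[t]);
        for (int pos = 0; pos < n; ++pos) {
            printf(" %d", needs_output(types[t], pos, n) ? 1 : 0);
        }
        printf("\n");
    }
}
```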
@@ -40,13 +56,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
 
         // try to get sequence embeddings - supported only when pooling_type is not NONE
         const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        if (embd == NULL) {
-            embd = llama_get_embeddings_ith(ctx, i);
-            if (embd == NULL) {
-                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
-                continue;
-            }
-        }
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
         float * out = output + batch.seq_id[i][0] * n_embd;
         //TODO: I would also add a parameter here to enable normalization or not.
@@ -97,6 +107,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, n_ctx);
@@ -176,7 +192,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s);
+        batch_add_seq(batch, inp, s, pooling_type);
         s += 1;
     }
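Once the batch has been decoded with a non-NONE pooling type, `batch_decode` reads one pooled vector per sequence via `llama_get_embeddings_seq`. A minimal sketch of that read-back path, assuming a context created as in this example (error handling reduced to an assert):

```cpp
#include "llama.h"
#include <vector>

// Sketch: fetch the pooled embedding of sequence `seq_id` after llama_decode().
// llama_get_embeddings_seq() is the per-sequence accessor used by batch_decode() above.
static std::vector<float> pooled_embedding(llama_context * ctx, llama_model * model, int seq_id) {
    const int n_embd = llama_n_embd(model);
    const float * embd = llama_get_embeddings_seq(ctx, seq_id);  // NULL if pooling is NONE
    GGML_ASSERT(embd != NULL && "context was not created with sequence pooling");
    return std::vector<float>(embd, embd + n_embd);
}
```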
@@ -44,7 +44,6 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
-    llama_set_causal_attn(ctx, false);
 
     // run model
     llama_decode(ctx, batch);
@@ -98,7 +97,6 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_token eos_token = llama_token_eos(mdl);
 
     llama_kv_cache_clear(ctx);
-    llama_set_causal_attn(ctx, true);
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -166,9 +164,14 @@ int main(int argc, char * argv[]) {
 
     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
 
-    // create new context - set to embedding mode
+    // create generation context
+    llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);
+
+    // create embedding context
     cparams.embeddings = true;
-    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
+    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+    cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
+    llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);
 
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -186,8 +189,8 @@ int main(int argc, char * argv[]) {
     };
 
     // No need to add instruction for retrieval documents
-    const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
-    const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
+    const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
+    const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));
 
     const int n_embd = llama_n_embd(mdl);
@@ -206,10 +209,11 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx, prompt, true);
+        std::string response = generate(ctx_gen, prompt, true);
     }
 
-    llama_free(ctx);
+    llama_free(ctx_gen);
+    llama_free(ctx_emb);
     llama_free_model(mdl);
     llama_backend_free();
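Rather than toggling `llama_set_causal_attn` on a single context, this example now keeps a causal context for generation and a separate non-causal, non-pooled context for embeddings (the example pools token embeddings itself, excluding instruction tokens). A condensed sketch of that setup, assuming `mdl` and a default `cparams` as in the example:

```cpp
// Sketch of the two-context pattern: one causal context for generation,
// one embeddings context with non-causal attention and no built-in pooling.
llama_context_params cparams = llama_context_default_params();

llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);   // generation (causal)

cparams.embeddings     = true;
cparams.pooling_type   = LLAMA_POOLING_TYPE_NONE;
cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);   // embeddings (non-causal)
```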
@@ -73,9 +73,25 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
+static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
+    switch (pooling_type) {
+        case LLAMA_POOLING_TYPE_MEAN:
+        case LLAMA_POOLING_TYPE_NONE:
+            return true;
+        case LLAMA_POOLING_TYPE_CLS:
+            return pos == 0;
+        case LLAMA_POOLING_TYPE_LAST:
+            return pos == n_tokens - 1;
+        default:
+            GGML_ASSERT(false && "unsupported pooling type");
+    }
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
+    int n_tokens = tokens.size();
     for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+        bool logit = needs_logit(pooling_type, i, n_tokens);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
     }
 }
@@ -159,6 +175,7 @@ int main(int argc, char ** argv) {
 
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
 
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
@@ -230,7 +247,7 @@ int main(int argc, char ** argv) {
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s);
+        batch_add_seq(batch, inp, s, pooling_type);
         s += 1;
     }
@@ -253,7 +270,7 @@ int main(int argc, char ** argv) {
     std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);
 
     struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
-    batch_add_seq(query_batch, query_tokens, 0);
+    batch_add_seq(query_batch, query_tokens, 0, pooling_type);
 
     std::vector<float> query_emb(n_embd, 0);
     batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
llama.cpp
@@ -7435,6 +7435,44 @@ struct llm_build_context {
         return lctx.inp_s_seq;
     }
 
+    struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
+        struct ggml_tensor * inp = gf->nodes[gf->n_nodes - 1];
+        if (strcmp(inp->name, "result_embd") != 0) {
+            inp = gf->nodes[gf->n_nodes - 2];
+            GGML_ASSERT(strcmp(inp->name, "result_norm") == 0 && "embeddings tensor not found");
+        }
+
+        struct ggml_tensor * cur;
+
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    struct ggml_tensor * inp_mean = build_inp_mean();
+                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_LAST:
+                {
+                    struct ggml_tensor * inp_cls = build_inp_cls();
+                    cur = ggml_get_rows(ctx0, inp, inp_cls);
+                } break;
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    cur = inp;
+                } break;
+            default:
+                {
+                    GGML_ASSERT(false && "unknown pooling type");
+                } break;
+        }
+
+        cb(cur, "result_embd_pooled", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
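`append_pooling` attaches the pooling reduction to whatever graph was built: `inp_mean` carries per-sequence averaging weights for the `ggml_mul_mat`, while CLS/LAST reuse `inp_cls` as a per-sequence row index for `ggml_get_rows`. A plain C++ sketch of what those reductions compute for a single sequence (toy version, not ggml):

```cpp
#include <vector>

// Plain C++ illustration of the reductions append_pooling() builds with ggml ops:
// given per-token embeddings e[t] (t = 0..n_tokens-1) of one sequence, produce one vector.
static std::vector<float> pool(const std::vector<std::vector<float>> & e, int type /*0=mean,1=cls,2=last*/) {
    const int n_tokens = (int) e.size();
    const int n_embd   = (int) e[0].size();
    if (type == 1) return e[0];             // CLS:  first token's embedding (ggml_get_rows, index 0)
    if (type == 2) return e[n_tokens - 1];  // LAST: final token's embedding (ggml_get_rows, last index)
    std::vector<float> out(n_embd, 0.0f);   // MEAN: average over tokens (ggml_mul_mat with inp_mean)
    for (const auto & row : e) {
        for (int i = 0; i < n_embd; ++i) out[i] += row[i] / n_tokens;
    }
    return out;
}
```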
@@ -8415,8 +8453,6 @@
         if (model.arch != LLM_ARCH_JINA_BERT_V2) {
             inp_pos = build_inp_pos();
         }
-        struct ggml_tensor * inp_mean = build_inp_mean();
-        struct ggml_tensor * inp_cls = build_inp_cls();
 
         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@@ -8591,28 +8627,6 @@
         cur = inpL;
         cb(cur, "result_embd", -1);
 
-        // pooling layer
-        switch (pooling_type) {
-            case LLAMA_POOLING_TYPE_NONE:
-                {
-                    // nop
-                } break;
-            case LLAMA_POOLING_TYPE_MEAN:
-                {
-                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_CLS:
-                {
-                    cur = ggml_get_rows(ctx0, cur, inp_cls);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                {
-                    GGML_ASSERT(false && "Invalid pooling type");
-                } break;
-        }
-
         ggml_build_forward_expand(gf, cur);
 
         return gf;
@@ -11697,6 +11711,11 @@ static struct ggml_cgraph * llama_build_graph(
             GGML_ASSERT(false);
     }
 
+    // add on pooling layer
+    if (lctx.cparams.embeddings) {
+        result = llm.append_pooling(result);
+    }
+
     llm.free();
 
     return result;
@@ -11918,6 +11937,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(lctx.inp_cls);
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+
+        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+
+        std::vector<int> last_pos(n_tokens, -1);
+        std::vector<int> last_row(n_tokens, -1);
+
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            const llama_pos pos = batch.pos[i];
+
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+
+            if (pos >= last_pos[seq_id]) {
+                last_pos[seq_id] = pos;
+                last_row[seq_id] = i;
+            }
+        }
+
+        for (int i = 0; i < n_tokens; ++i) {
+            if (last_row[i] >= 0) {
+                data[i] = last_row[i];
+            }
+        }
+    }
+
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;
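For `LLAMA_POOLING_TYPE_LAST`, this fills `inp_cls` with, for each sequence id, the batch row holding that sequence's highest position, which is exactly the row `ggml_get_rows` gathers in `append_pooling`. The same scan as a standalone sketch:

```cpp
#include <cstdint>
#include <vector>

// Standalone version of the last-token scan above: for each sequence id, find the
// batch row holding its largest position (assumes seq_id < n_tokens, as asserted in the diff).
static std::vector<uint32_t> last_token_rows(const std::vector<int32_t> & seq_id,
                                             const std::vector<int32_t> & pos) {
    const int n_tokens = (int) seq_id.size();
    std::vector<int> last_pos(n_tokens, -1);
    std::vector<int> last_row(n_tokens, -1);
    for (int i = 0; i < n_tokens; ++i) {
        if (pos[i] >= last_pos[seq_id[i]]) {
            last_pos[seq_id[i]] = pos[i];
            last_row[seq_id[i]] = i;
        }
    }
    std::vector<uint32_t> rows(n_tokens, 0);
    for (int i = 0; i < n_tokens; ++i) {
        if (last_row[i] >= 0) rows[i] = (uint32_t) last_row[i];
    }
    return rows;
}
```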
@@ -12245,30 +12295,13 @@ static int llama_decode_internal(
             // no output
             res = nullptr;
             embd = nullptr;
-        } else if (!hparams.causal_attn) {
-            res = nullptr; // do not extract logits for embedding models such as BERT
-
-            // token or sequence embeddings
-            embd = gf->nodes[gf->n_nodes - 1];
-
-            GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
         } else if (cparams.embeddings) {
-            // the embeddings could be in the second to last tensor, or any of the previous tensors
-            int i_embd = gf->n_nodes - 2;
-            for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
-                i_embd = gf->n_nodes - i;
-                if (i_embd < 0) { break; }
-                embd = gf->nodes[i_embd];
-            }
-            GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
-
-            // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
-            if (!cparams.causal_attn) {
-                res = nullptr; // do not extract logits when not needed
-                // skip computing logits
-                // TODO: is this safe?
-                gf->n_nodes = i_embd + 1;
-            }
+            res = nullptr; // do not extract logits for embedding case
+            embd = gf->nodes[gf->n_nodes - 1];
+            if (strcmp(embd->name, "result_embd_pooled") != 0) {
+                embd = gf->nodes[gf->n_nodes - 2];
+            }
+            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
@@ -12337,11 +12370,10 @@ static int llama_decode_internal(
                         ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                     }
                 } break;
-            case LLAMA_POOLING_TYPE_CLS:
             case LLAMA_POOLING_TYPE_MEAN:
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_LAST:
                 {
-                    GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
-
                     // extract sequence embeddings
                     auto & embd_seq_out = lctx.embd_seq;
                     embd_seq_out.clear();
@@ -15893,6 +15925,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
+        /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base =*/ 0.0f,
         /*.rope_freq_scale =*/ 0.0f,
         /*.yarn_ext_factor =*/ -1.0f,
@@ -16134,7 +16167,12 @@ struct llama_context * llama_new_context_with_model(
     }
 
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-    cparams.causal_attn = hparams.causal_attn;
+
+    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        cparams.causal_attn = hparams.causal_attn;
+    } else {
+        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+    }
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -16494,6 +16532,10 @@ enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
     return ctx->cparams.pooling_type;
 }
 
+bool llama_causal_attn(const struct llama_context * ctx) {
+    return ctx->cparams.causal_attn;
+}
+
 int32_t llama_n_vocab(const struct llama_model * model) {
     return model->hparams.n_vocab;
 }
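With `llama_pooling_type()` already public and `llama_causal_attn()` added above, callers can inspect how a context ended up configured once the unspecified defaults have been resolved. A small sketch:

```cpp
#include "llama.h"
#include <cstdio>

// Sketch: query the effective configuration of an existing context.
// llama_pooling_type() predates this commit; llama_causal_attn() is introduced by it.
static void print_ctx_config(const llama_context * ctx) {
    printf("pooling_type = %d, causal_attn = %s\n",
           (int) llama_pooling_type(ctx), llama_causal_attn(ctx) ? "true" : "false");
}
```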
llama.h
@@ -174,6 +174,13 @@ extern "C" {
         LLAMA_POOLING_TYPE_NONE = 0,
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS = 2,
+        LLAMA_POOLING_TYPE_LAST = 3,
     };
 
+    enum llama_attention_type {
+        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
+        LLAMA_ATTENTION_TYPE_CAUSAL = 0,
+        LLAMA_ATTENTION_TYPE_NONCAUSAL = 1,
+    };
+
     enum llama_split_mode {
@@ -293,7 +300,7 @@ extern "C" {
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
-                                              // (ignored if no pooling layer)
+        enum llama_attention_type attention_type; // causal, non-causal, or unspecified
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency, 0 = from model