llama : allow for user specified embedding pooling type (#5849)
* allow for user specified pooling type * llama : use enum types over int --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
87c2e8b279
commit
475df1d6cf
5 changed files with 60 additions and 29 deletions
44
llama.cpp
44
llama.cpp
|
@ -873,16 +873,16 @@ struct LLM_TN {
|
|||
// gguf helpers
|
||||
//
|
||||
|
||||
static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
|
||||
static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
|
||||
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
|
||||
{ LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
|
||||
{ LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" },
|
||||
};
|
||||
|
||||
static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
|
||||
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
|
||||
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
|
||||
if (kv.second == name) {
|
||||
return kv.first;
|
||||
return (llama_rope_scaling_type) kv.first;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1612,7 +1612,6 @@ struct llama_hparams {
|
|||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
uint32_t n_yarn_orig_ctx;
|
||||
int32_t rope_scaling_type_train;
|
||||
|
||||
float f_clamp_kqv = 0.0f;
|
||||
float f_max_alibi_bias = 0.0f;
|
||||
|
@ -1620,8 +1619,9 @@ struct llama_hparams {
|
|||
bool causal_attn = true;
|
||||
bool need_kq_pos = false;
|
||||
|
||||
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||
|
||||
bool operator!=(const llama_hparams & other) const {
|
||||
if (this->vocab_only != other.vocab_only) return true;
|
||||
|
@ -1670,8 +1670,8 @@ struct llama_cparams {
|
|||
uint32_t n_threads; // number of threads to use for generation
|
||||
uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
||||
uint32_t n_yarn_orig_ctx;
|
||||
// These hyperparameters are not exposed in GGUF, because all
|
||||
|
@ -1683,7 +1683,7 @@ struct llama_cparams {
|
|||
float defrag_thold;
|
||||
|
||||
bool offload_kqv;
|
||||
bool do_pooling;
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
|
@ -2933,7 +2933,11 @@ template<>
|
|||
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
|
||||
uint32_t tmp;
|
||||
const bool found = get_key(kid, tmp, required);
|
||||
result = (enum llama_pooling_type) tmp;
|
||||
if (found) {
|
||||
result = (enum llama_pooling_type) tmp;
|
||||
} else {
|
||||
result = LLAMA_POOLING_TYPE_UNSPECIFIED;
|
||||
}
|
||||
return found;
|
||||
}
|
||||
|
||||
|
@ -3210,7 +3214,7 @@ static void llm_load_hparams(
|
|||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
|
||||
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
|
||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
|
||||
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 3:
|
||||
|
@ -5175,7 +5179,7 @@ struct llm_build_context {
|
|||
n_kv (worst_case ? n_ctx : kv_self.n),
|
||||
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
|
||||
n_orig_ctx (cparams.n_yarn_orig_ctx),
|
||||
pooling_type (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
|
||||
pooling_type (cparams.pooling_type),
|
||||
rope_type (hparams.rope_type),
|
||||
cb (cb),
|
||||
buf_compute_meta (lctx.buf_compute_meta) {
|
||||
|
@ -8015,7 +8019,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
}
|
||||
}
|
||||
|
||||
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
|
||||
|
@ -8043,7 +8047,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
|||
}
|
||||
}
|
||||
|
||||
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
|
||||
const int64_t n_tokens = batch.n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
||||
|
@ -11846,6 +11850,7 @@ struct llama_context_params llama_context_default_params() {
|
|||
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
|
||||
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
|
||||
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
|
||||
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
|
||||
/*.rope_freq_base =*/ 0.0f,
|
||||
/*.rope_freq_scale =*/ 0.0f,
|
||||
/*.yarn_ext_factor =*/ -1.0f,
|
||||
|
@ -11861,7 +11866,6 @@ struct llama_context_params llama_context_default_params() {
|
|||
/*.logits_all =*/ false,
|
||||
/*.embedding =*/ false,
|
||||
/*.offload_kqv =*/ true,
|
||||
/*.do_pooling =*/ true,
|
||||
/*.abort_callback =*/ nullptr,
|
||||
/*.abort_callback_data =*/ nullptr,
|
||||
};
|
||||
|
@ -12012,7 +12016,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
||||
cparams.defrag_thold = params.defrag_thold;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.do_pooling = params.do_pooling;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
|
@ -12038,6 +12042,14 @@ struct llama_context * llama_new_context_with_model(
|
|||
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
||||
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
|
||||
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
} else {
|
||||
cparams.pooling_type = hparams.pooling_type;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue