allow for user specified pooling type

Douglas Hanley 2024-03-02 22:54:52 -06:00
parent 9731134296
commit efecd060c9
5 changed files with 45 additions and 18 deletions

common/common.cpp

@@ -335,6 +335,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.yarn_beta_slow = std::stof(argv[i]);
+        } else if (arg == "--pooling") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+            else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+            else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS;  }
+            else { invalid_param = true; break; }
         } else if (arg == "--defrag-thold" || arg == "-dt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1014,6 +1024,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  --pooling {none,mean,cls}\n");
+    printf("                        pooling type for embeddings, use model default if unspecified\n");
     printf("  -dt N, --defrag-thold N\n");
     printf("                        KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -1296,6 +1308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.yarn_orig_ctx  = params.yarn_orig_ctx;
+    cparams.pooling_type   = params.pooling_type;
     cparams.defrag_thold   = params.defrag_thold;
     cparams.offload_kqv    = !params.no_kv_offload;
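
With these changes, examples that consume gpt_params gain a --pooling flag. A usage sketch (the embedding example binary and model file here are illustrative, not part of this diff):

    ./embedding -m model.gguf --pooling mean -p "some text to embed"

When --pooling is omitted, params.pooling_type stays LLAMA_POOLING_TYPE_UNSPECIFIED and the decision is deferred to the model's metadata, as resolved in llama.cpp below.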

common/common.h

@@ -78,6 +78,7 @@ struct gpt_params {
     float   defrag_thold = -1.0f; // KV cache defragmentation threshold
     int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
+    llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
 
     // sampling parameters
     struct llama_sampling_params sparams;

convert-hf-to-gguf.py

@@ -1644,16 +1644,17 @@ class BertModel(Model):
         self.gguf_writer.add_causal_attention(False)
 
         # get pooling path
-        with open(self.dir_model / "modules.json", encoding="utf-8") as f:
-            modules = json.load(f)
         pooling_path = None
-        for mod in modules:
-            if mod["type"] == "sentence_transformers.models.Pooling":
-                pooling_path = mod["path"]
-                break
+        module_path = self.dir_model / "modules.json"
+        if module_path.is_file():
+            with open(module_path, encoding="utf-8") as f:
+                modules = json.load(f)
+            for mod in modules:
+                if mod["type"] == "sentence_transformers.models.Pooling":
+                    pooling_path = mod["path"]
+                    break
 
         # get pooling type
-        pooling_type = gguf.PoolingType.NONE
         if pooling_path is not None:
             with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
                 pooling = json.load(f)
@@ -1663,8 +1664,7 @@ class BertModel(Model):
                 pooling_type = gguf.PoolingType.CLS
             else:
                 raise NotImplementedError("Only MEAN and CLS pooling types supported")
-
-        self.gguf_writer.add_pooling_type(pooling_type)
+            self.gguf_writer.add_pooling_type(pooling_type)
 
     def set_vocab(self):
         path = self.dir_model
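
For context, modules.json in a sentence-transformers checkpoint is a list of module descriptors, and the converter scans it for the Pooling entry. An illustrative file (not taken from this commit) might look like:

    [
      {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
      {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"}
    ]

The config.json under that pooling path then carries boolean flags such as pooling_mode_mean_tokens and pooling_mode_cls_token, which the code above maps to gguf.PoolingType.MEAN or gguf.PoolingType.CLS.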

llama.cpp

@@ -1683,7 +1683,7 @@ struct llama_cparams {
     float defrag_thold;
 
     bool offload_kqv;
-    bool do_pooling;
+    enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
@@ -2931,7 +2931,11 @@ template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
     const bool found = get_key(kid, tmp, required);
-    result = (enum llama_pooling_type) tmp;
+    if (found) {
+        result = (enum llama_pooling_type) tmp;
+    } else {
+        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
+    }
     return found;
 }
@@ -3208,7 +3212,7 @@ static void llm_load_hparams(
             ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
             ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
             ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-            ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+            ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
 
             switch (hparams.n_layer) {
                 case 3:
@@ -5173,7 +5177,7 @@ struct llm_build_context {
         n_kv             (worst_case ? n_ctx            : kv_self.n),
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
-        pooling_type     (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
+        pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
@@ -8013,7 +8017,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -8041,7 +8045,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -11859,7 +11863,7 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all          =*/ false,
         /*.embedding           =*/ false,
         /*.offload_kqv         =*/ true,
-        /*.do_pooling          =*/ true,
+        /*.pooling_type        =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.abort_callback      =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
     };
@@ -12010,7 +12014,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.defrag_thold   = params.defrag_thold;
     cparams.offload_kqv    = params.offload_kqv;
-    cparams.do_pooling     = params.do_pooling;
+    cparams.pooling_type   = params.pooling_type;
 
     cparams.n_ctx          = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -12036,6 +12040,14 @@ struct llama_context * llama_new_context_with_model(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        } else {
+            cparams.pooling_type = hparams.pooling_type;
+        }
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
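
Taken together, the llama.cpp hunks establish a three-level precedence: an explicit user setting wins, otherwise the pooling type stored in the model's GGUF metadata applies, and a model without stored pooling information falls back to no pooling. A minimal standalone sketch of that resolution (the helper name is illustrative, not part of the library):

    // Restates the fallback wired into llama_new_context_with_model above.
    static enum llama_pooling_type resolve_pooling_type(
            enum llama_pooling_type requested, // from llama_context_params (user)
            enum llama_pooling_type stored) {  // from GGUF metadata (model)
        if (requested != LLAMA_POOLING_TYPE_UNSPECIFIED) {
            return requested; // explicit user choice always wins
        }
        if (stored != LLAMA_POOLING_TYPE_UNSPECIFIED) {
            return stored; // otherwise trust the converter's metadata
        }
        return LLAMA_POOLING_TYPE_NONE; // no information: do not pool
    }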

llama.h

@@ -129,6 +129,7 @@ extern "C" {
     };
 
     enum llama_pooling_type {
+        LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
         LLAMA_POOLING_TYPE_NONE = 0,
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
@@ -258,7 +259,7 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+        enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
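
On the API side, a caller that wants a specific pooling behavior now sets the new field before creating the context. A minimal sketch, assuming the llama.h API as of this commit (model path illustrative, error handling omitted):

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("model.gguf", mparams);

        llama_context_params cparams = llama_context_default_params();
        cparams.embedding    = true;                    // embedding mode
        cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // override the model default

        llama_context * ctx = llama_new_context_with_model(model, cparams);

        // ... tokenize, llama_decode(), then read llama_get_embeddings(ctx) ...

        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }

Leaving pooling_type at LLAMA_POOLING_TYPE_UNSPECIFIED preserves the previous default of following the model's own pooling metadata.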