diff --git a/common/common.cpp b/common/common.cpp index 1c0b7c403..dbe7e9229 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -335,6 +335,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.yarn_beta_slow = std::stof(argv[i]); + } else if (arg == "--pooling") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::string value(argv[i]); + /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } + else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else { invalid_param = true; break; } } else if (arg == "--defrag-thold" || arg == "-dt") { if (++i >= argc) { invalid_param = true; @@ -1014,6 +1024,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); + printf(" --pooling {none,mean,cls}\n"); + printf(" pooling type for embeddings, use model default if unspecified\n"); printf(" -dt N, --defrag-thold N\n"); printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); @@ -1296,6 +1308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_beta_fast = params.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_orig_ctx = params.yarn_orig_ctx; + cparams.pooling_type = params.pooling_type; cparams.defrag_thold = params.defrag_thold; cparams.offload_kqv = !params.no_kv_offload; diff --git a/common/common.h b/common/common.h index ab62bdb82..1a9ef5425 
100644 --- a/common/common.h +++ b/common/common.h @@ -78,6 +78,7 @@ struct gpt_params { float defrag_thold = -1.0f; // KV cache defragmentation threshold int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; + llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings // // sampling parameters struct llama_sampling_params sparams; diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index fa9d4f22f..ffdba7444 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1644,16 +1644,17 @@ class BertModel(Model): self.gguf_writer.add_causal_attention(False) # get pooling path - with open(self.dir_model / "modules.json", encoding="utf-8") as f: - modules = json.load(f) pooling_path = None - for mod in modules: - if mod["type"] == "sentence_transformers.models.Pooling": - pooling_path = mod["path"] - break + module_path = self.dir_model / "modules.json" + if module_path.is_file(): + with open(module_path, encoding="utf-8") as f: + modules = json.load(f) + for mod in modules: + if mod["type"] == "sentence_transformers.models.Pooling": + pooling_path = mod["path"] + break # get pooling type - pooling_type = gguf.PoolingType.NONE if pooling_path is not None: with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f: pooling = json.load(f) @@ -1663,8 +1664,7 @@ class BertModel(Model): pooling_type = gguf.PoolingType.CLS else: raise NotImplementedError("Only MEAN and CLS pooling types supported") - - self.gguf_writer.add_pooling_type(pooling_type) + self.gguf_writer.add_pooling_type(pooling_type) def set_vocab(self): path = self.dir_model diff --git a/llama.cpp b/llama.cpp index d4c7a965b..28eee3d34 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1683,7 +1683,7 @@ struct llama_cparams { float defrag_thold; bool offload_kqv; - bool do_pooling; + enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; void * 
cb_eval_user_data; @@ -2931,7 +2931,11 @@ template<> bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) { uint32_t tmp; const bool found = get_key(kid, tmp, required); - result = (enum llama_pooling_type) tmp; + if (found) { + result = (enum llama_pooling_type) tmp; + } else { + result = LLAMA_POOLING_TYPE_UNSPECIFIED; + } return found; } @@ -3208,7 +3212,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 3: @@ -5173,7 +5177,7 @@ struct llm_build_context { n_kv (worst_case ? n_ctx : kv_self.n), kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), - pooling_type (cparams.do_pooling ? 
hparams.pooling_type : LLAMA_POOLING_TYPE_NONE), + pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), buf_compute_meta (lctx.buf_compute_meta) { @@ -8013,7 +8017,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); @@ -8041,7 +8045,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -11859,7 +11863,7 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embedding =*/ false, /*.offload_kqv =*/ true, - /*.do_pooling =*/ true, + /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -12010,7 +12014,7 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.defrag_thold = params.defrag_thold; cparams.offload_kqv = params.offload_kqv; - cparams.do_pooling = params.do_pooling; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; @@ -12036,6 +12040,14 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 
1.0f : 0.0f; } + if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { + if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { + cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; + } else { + cparams.pooling_type = hparams.pooling_type; + } + } + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } diff --git a/llama.h b/llama.h index 6406b5270..4638b15f7 100644 --- a/llama.h +++ b/llama.h @@ -129,6 +129,7 @@ extern "C" { }; enum llama_pooling_type { + LLAMA_POOLING_TYPE_UNSPECIFIED = -1, LLAMA_POOLING_TYPE_NONE = 0, LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, @@ -258,7 +259,7 @@ extern "C" { bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embedding; // embedding mode only bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU - bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) + enum llama_pooling_type pooling_type; // pooling type for embeddings (LLAMA_POOLING_TYPE_UNSPECIFIED = use model default) // Abort callback // if it returns true, execution of llama_decode() will be aborted