add parameters for embeddings

--embd-normalize
--embd-output-format
--embd-separator
descriptions in the README.md
Yann Follet 2024-05-22 09:29:15 +00:00
parent 95fb0aefab
commit 625bdb5225
4 changed files with 265 additions and 419 deletions

common/common.cpp

@@ -281,24 +281,20 @@ bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> &
return true;
}
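// helper for the argument parser below: advance i to the option's value and, if argv is exhausted, flag invalid_param and return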
#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
llama_sampling_params & sparams = params.sparams;
if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
// This is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
params.seed = std::stoul(argv[i]);
sparams.seed = std::stoul(argv[i]);
return true;
}
if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_threads = std::stoi(argv[i]);
if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency();
@@ -306,10 +302,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-tb" || arg == "--threads-batch") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_threads_batch = std::stoi(argv[i]);
if (params.n_threads_batch <= 0) {
params.n_threads_batch = std::thread::hardware_concurrency();
@@ -317,10 +310,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-td" || arg == "--threads-draft") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_threads_draft = std::stoi(argv[i]);
if (params.n_threads_draft <= 0) {
params.n_threads_draft = std::thread::hardware_concurrency();
@@ -328,10 +318,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-tbd" || arg == "--threads-batch-draft") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_threads_batch_draft = std::stoi(argv[i]);
if (params.n_threads_batch_draft <= 0) {
params.n_threads_batch_draft = std::thread::hardware_concurrency();
@@ -339,10 +326,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.prompt = argv[i];
return true;
}
@@ -351,10 +335,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--prompt-cache") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.path_prompt_cache = argv[i];
return true;
}
@@ -367,10 +348,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-bf" || arg == "--binary-file") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::ifstream file(argv[i], std::ios::binary);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -386,10 +364,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-f" || arg == "--file") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -405,66 +380,42 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-n" || arg == "--n-predict") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_predict = std::stoi(argv[i]);
return true;
}
if (arg == "--top-k") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.top_k = std::stoi(argv[i]);
return true;
}
if (arg == "-c" || arg == "--ctx-size") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_ctx = std::stoi(argv[i]);
return true;
}
if (arg == "--grp-attn-n" || arg == "-gan") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.grp_attn_n = std::stoi(argv[i]);
return true;
}
if (arg == "--grp-attn-w" || arg == "-gaw") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.grp_attn_w = std::stoi(argv[i]);
return true;
}
if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.rope_freq_base = std::stof(argv[i]);
return true;
}
if (arg == "--rope-freq-scale") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.rope_freq_scale = std::stof(argv[i]);
return true;
}
if (arg == "--rope-scaling") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
@@ -473,58 +424,37 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--rope-scale") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.rope_freq_scale = 1.0f / std::stof(argv[i]);
return true;
}
if (arg == "--yarn-orig-ctx") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.yarn_orig_ctx = std::stoi(argv[i]);
return true;
}
if (arg == "--yarn-ext-factor") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.yarn_ext_factor = std::stof(argv[i]);
return true;
}
if (arg == "--yarn-attn-factor") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.yarn_attn_factor = std::stof(argv[i]);
return true;
}
if (arg == "--yarn-beta-fast") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.yarn_beta_fast = std::stof(argv[i]);
return true;
}
if (arg == "--yarn-beta-slow") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.yarn_beta_slow = std::stof(argv[i]);
return true;
}
if (arg == "--pooling") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::string value(argv[i]);
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
@@ -533,157 +463,100 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--defrag-thold" || arg == "-dt") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.defrag_thold = std::stof(argv[i]);
return true;
}
if (arg == "--samplers") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
const auto sampler_names = string_split(argv[i], ';');
sparams.samplers_sequence = sampler_types_from_names(sampler_names, true);
return true;
}
if (arg == "--sampling-seq") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.samplers_sequence = sampler_types_from_chars(argv[i]);
return true;
}
if (arg == "--top-p") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.top_p = std::stof(argv[i]);
return true;
}
if (arg == "--min-p") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.min_p = std::stof(argv[i]);
return true;
}
if (arg == "--temp") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.temp = std::stof(argv[i]);
sparams.temp = std::max(sparams.temp, 0.0f);
return true;
}
if (arg == "--tfs") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.tfs_z = std::stof(argv[i]);
return true;
}
if (arg == "--typical") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.typical_p = std::stof(argv[i]);
return true;
}
if (arg == "--repeat-last-n") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.penalty_last_n = std::stoi(argv[i]);
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
return true;
}
if (arg == "--repeat-penalty") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.penalty_repeat = std::stof(argv[i]);
return true;
}
if (arg == "--frequency-penalty") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.penalty_freq = std::stof(argv[i]);
return true;
}
if (arg == "--presence-penalty") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.penalty_present = std::stof(argv[i]);
return true;
}
if (arg == "--dynatemp-range") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.dynatemp_range = std::stof(argv[i]);
return true;
}
if (arg == "--dynatemp-exp") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.dynatemp_exponent = std::stof(argv[i]);
return true;
}
if (arg == "--mirostat") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.mirostat = std::stoi(argv[i]);
return true;
}
if (arg == "--mirostat-lr") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.mirostat_eta = std::stof(argv[i]);
return true;
}
if (arg == "--mirostat-ent") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.mirostat_tau = std::stof(argv[i]);
return true;
}
if (arg == "--cfg-negative-prompt") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.cfg_negative_prompt = argv[i];
return true;
}
if (arg == "--cfg-negative-prompt-file") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -697,203 +570,125 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--cfg-scale") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.cfg_scale = std::stof(argv[i]);
return true;
}
if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_batch = std::stoi(argv[i]);
return true;
}
if (arg == "-ub" || arg == "--ubatch-size") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_ubatch = std::stoi(argv[i]);
return true;
}
if (arg == "--keep") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_keep = std::stoi(argv[i]);
return true;
}
if (arg == "--draft") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_draft = std::stoi(argv[i]);
return true;
}
if (arg == "--chunks") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_chunks = std::stoi(argv[i]);
return true;
}
if (arg == "-np" || arg == "--parallel") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_parallel = std::stoi(argv[i]);
return true;
}
if (arg == "-ns" || arg == "--sequences") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_sequences = std::stoi(argv[i]);
return true;
}
if (arg == "--p-split" || arg == "-ps") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.p_split = std::stof(argv[i]);
return true;
}
if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.model = argv[i];
return true;
}
if (arg == "-md" || arg == "--model-draft") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.model_draft = argv[i];
return true;
}
if (arg == "-a" || arg == "--alias") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.model_alias = argv[i];
return true;
}
if (arg == "-mu" || arg == "--model-url") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.model_url = argv[i];
return true;
}
if (arg == "-hfr" || arg == "--hf-repo") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.hf_repo = argv[i];
return true;
}
if (arg == "-hff" || arg == "--hf-file") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.hf_file = argv[i];
return true;
}
if (arg == "--lora") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.lora_adapter.emplace_back(argv[i], 1.0f);
params.use_mmap = false;
return true;
}
if (arg == "--lora-scaled") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
const char* lora_adapter = argv[i];
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
params.use_mmap = false;
return true;
}
if (arg == "--lora-base") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.lora_base = argv[i];
return true;
}
if (arg == "--control-vector") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.control_vectors.push_back({ 1.0f, argv[i], });
return true;
}
if (arg == "--control-vector-scaled") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
const char* fname = argv[i];
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.control_vectors.push_back({ std::stof(argv[i]), fname, });
return true;
}
if (arg == "--control-vector-layer-range") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.control_vector_layer_start = std::stoi(argv[i]);
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.control_vector_layer_end = std::stoi(argv[i]);
return true;
}
if (arg == "--mmproj") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.mmproj = argv[i];
return true;
}
if (arg == "--image") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.image.emplace_back(argv[i]);
return true;
}
@@ -909,6 +704,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.embedding = true;
return true;
}
if (arg == "--embd-normalize") {
CHECK_ARG
params.embd_normalize = std::stoi(argv[i]);
return true;
}
if (arg == "--embd-output-format") {
CHECK_ARG
params.embd_out = argv[i];
return true;
}
if (arg == "--embd-separator") {
CHECK_ARG
params.embd_sep = argv[i];
return true;
}
if (arg == "--interactive-first") {
params.interactive_first = true;
return true;
@@ -970,10 +780,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_gpu_layers = std::stoi(argv[i]);
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
@@ -982,10 +789,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_gpu_layers_draft = std::stoi(argv[i]);
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
@@ -994,10 +798,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--main-gpu" || arg == "-mg") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.main_gpu = std::stoi(argv[i]);
#ifndef GGML_USE_CUDA_SYCL
fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL. Setting the main GPU has no effect.\n");
@@ -1005,10 +806,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--split-mode" || arg == "-sm") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::string arg_next = argv[i];
if (arg_next == "none") {
params.split_mode = LLAMA_SPLIT_MODE_NONE;
@@ -1033,10 +831,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--tensor-split" || arg == "-ts") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::string arg_next = argv[i];
// split string by , and /
@@ -1061,10 +856,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--rpc") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.rpc_servers = argv[i];
return true;
}
@@ -1073,10 +865,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--numa") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::string value(argv[i]);
/**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
@@ -1093,18 +882,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-r" || arg == "--reverse-prompt") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.antiprompt.emplace_back(argv[i]);
return true;
}
if (arg == "-ld" || arg == "--logdir") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.logdir = argv[i];
if (params.logdir.back() != DIRECTORY_SEPARATOR) {
@@ -1113,26 +896,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-lcs" || arg == "--lookup-cache-static") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.lookup_cache_static = argv[i];
return true;
}
if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.lookup_cache_dynamic = argv[i];
return true;
}
if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.logits_file = argv[i];
return true;
}
@@ -1141,18 +915,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--ppl-stride") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.ppl_stride = std::stoi(argv[i]);
return true;
}
if (arg == "-ptc" || arg == "--print-token-count") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.n_print = std::stoi(argv[i]);
return true;
}
@@ -1161,10 +929,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--ppl-output-type") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.ppl_output_type = std::stoi(argv[i]);
return true;
}
@@ -1173,10 +938,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--hellaswag-tasks") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.hellaswag_tasks = std::stoi(argv[i]);
return true;
}
@@ -1185,10 +947,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--winogrande-tasks") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.winogrande_tasks = std::stoi(argv[i]);
return true;
}
@@ -1197,10 +956,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--multiple-choice-tasks") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.multiple_choice_tasks = std::stoi(argv[i]);
return true;
}
@@ -1217,10 +973,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-l" || arg == "--logit-bias") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::stringstream ss(argv[i]);
llama_token key;
char sign;
@@ -1257,34 +1010,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--in-prefix") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.input_prefix = argv[i];
return true;
}
if (arg == "--in-suffix") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
params.input_suffix = argv[i];
return true;
}
if (arg == "--grammar") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.grammar = argv[i];
return true;
}
if (arg == "--grammar-file") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -1299,18 +1040,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "-j" || arg == "--json-schema") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
return true;
}
if (arg == "--override-kv") {
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
if (!parse_kv_override(argv[i], params.kv_overrides)) {
fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
invalid_param = true;
@@ -1329,10 +1064,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
// We have a matching known parameter requiring an argument,
// now we need to check if there is anything after this argv
// and flag invalid_param or parse it.
if (++i >= argc) {
invalid_param = true;
return true;
}
CHECK_ARG
if (!log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i - 1], argv[i])) {
invalid_param = true;
return true;
@@ -2855,14 +2587,34 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
printf("\n=== Done dumping\n");
}
void llama_embd_normalize(const float * inp, float * out, int n) {
void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
double sum = 0.0;
for (int i = 0; i < n; i++) {
sum += inp[i] * inp[i];
}
sum = sqrt(sum);
const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
switch (embd_norm) {
case -1: // no normalisation
sum = 1.0;
break;
case 0: // max absolute
for (int i = 0; i < n; i++) {
if (sum < std::abs(inp[i])) sum = std::abs(inp[i]);
}
sum /= 32760.0; // scale so the largest |inp[i]| maps to 32760 (int16 range)
break;
case 2: // euclidean
for (int i = 0; i < n; i++) {
sum += inp[i] * inp[i];
}
sum = std::sqrt(sum);
break;
default: // p-norm (euclidean is p-norm p=2)
for (int i = 0; i < n; i++) {
sum += std::pow(std::abs(inp[i]), embd_norm);
}
sum = std::pow(sum, 1.0 / embd_norm);
break;
}
const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;
for (int i = 0; i < n; i++) {
out[i] = inp[i] * norm;
@@ -2880,6 +2632,14 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n)
sum2 += embd2[i] * embd2[i];
}
// Handle the case where one or both vectors are zero vectors
if (sum1 == 0.0 || sum2 == 0.0) {
if (sum1 == 0.0 && sum2 == 0.0) {
return 1.0f; // two zero vectors are similar
}
return 0.0f;
}
return sum / (sqrt(sum1) * sqrt(sum2));
}

common/common.h

@@ -148,6 +148,9 @@ struct gpt_params {
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
bool embedding = false; // get only sentence embedding
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [] or [[],[]...], "json" = openai style, "json+" = same as "json" plus cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
bool interactive_first = false; // wait for user input immediately
bool multiline_input = false; // reverse the usage of `\`
@@ -305,7 +308,7 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40
// Embedding utils
//
void llama_embd_normalize(const float * inp, float * out, int n);
void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);

examples/embedding/README.md

@@ -19,3 +19,43 @@ embedding.exe -m ./path/to/model --log-disable -p "Hello World!" 2>$null
```
The above command will output space-separated float values.
## extra parameters
### --embd-normalize $integer$
| $integer$ | description         | formula |
|-----------|---------------------|---------|
| $-1$      | none                |         |
| $0$       | max absolute int16  | $\Large{{32760 * x_i} \over\max \lvert x_i\rvert}$ |
| $1$       | taxicab             | $\Large{x_i \over\sum \lvert x_i\rvert}$ |
| $2$       | euclidean (default) | $\Large{x_i \over\sqrt{\sum x_i^2}}$ |
| $>2$      | p-norm              | $\Large{x_i \over\sqrt[p]{\sum \lvert x_i\rvert^p}}$ |
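For example, a (hypothetical) two-dimensional embedding $(3, 4)$ becomes $(0.6, 0.8)$ with euclidean normalisation ($\sqrt{3^2+4^2}=5$), $(3/7, 4/7)$ with taxicab, and $(24570, 32760)$ with max absolute int16 ($32760 \cdot x_i / 4$).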
### --embd-output-format $'string'$
| $'string'$ | description                  |                                                             |
|------------|------------------------------|-------------------------------------------------------------|
| ''         | same output as before        | (default)                                                   |
| 'array'    | single embedding             | $[[x_1,...,x_n]]$                                           |
|            | multiple embeddings          | $[_0[x_1,...,x_n],_1[x_1,...,x_n],...,_{n-1}[x_1,...,x_n]]$ |
| 'json'     | OpenAI style                 |                                                             |
| 'json+'    | add cosine similarity matrix |                                                             |
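Based on the implementation in `embedding.cpp` below, `--embd-output-format 'json+'` with two prompts prints output of roughly the following shape (the values and whitespace here are illustrative only; note the single quotes and the unquoted `cosineSimilarity` key, so the output is not strict JSON):
```
{
 'object': 'list',
 'data': [
 {
 'object': 'embedding',
 'index': 0,
 'embedding': [0.0421983,-0.0103205, ...]
 },
 {
 'object': 'embedding',
 'index': 1,
 'embedding': [0.0378441,0.0229174, ...]
 }
 ],
 cosineSimilarity: [
 [  1.00,   0.87 ],
 [  0.87,   1.00 ]
 ]
}
```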
### --embd-separator $"string"$
| $"string"$   |                 |
|--------------|-----------------|
| "\n"         | (default)       |
| "<#embSep#>" | for example     |
| "<#sep#>"    | another example |
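For example, `-p 'Castle<#sep#>Stronghold' --embd-separator '<#sep#>'` splits the prompt into two inputs, `Castle` and `Stronghold`, and produces one embedding for each.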
## examples
### Unix-based systems (Linux, macOS, etc.):
```bash
./embedding -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
```
### Windows:
```powershell
embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>$null
```
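Both commands embed the four prompts separately; with the default output format this prints the first 16 values of each embedding followed by a 4×4 cosine similarity matrix.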

examples/embedding/embedding.cpp

@@ -7,13 +7,19 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
static std::vector<std::string> split_lines(const std::string & s) {
std::string line;
static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
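// scan for each occurrence of the separator (which may be multi-character) and collect the substrings between them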
std::vector<std::string> lines;
std::stringstream ss(s);
while (std::getline(ss, line)) {
lines.push_back(line);
size_t start = 0;
size_t end = s.find(separator);
while (end != std::string::npos) {
lines.push_back(s.substr(start, end - start));
start = end + separator.length();
end = s.find(separator, start);
}
lines.push_back(s.substr(start)); // Add the last part
return lines;
}
@@ -23,7 +29,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
@@ -49,13 +55,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
float * out = output + batch.seq_id[i][0] * n_embd;
//TODO: I would also add a parameter here to enable normalization or not.
/*fprintf(stdout, "unnormalized_embedding:");
for (int hh = 0; hh < n_embd; hh++) {
fprintf(stdout, "%9.6f ", embd[hh]);
}
fprintf(stdout, "\n");*/
llama_embd_normalize(embd, out, n_embd);
llama_embd_normalize(embd, out, n_embd, embd_norm);
}
}
@@ -111,7 +111,7 @@ int main(int argc, char ** argv) {
}
// split the prompt into lines
std::vector<std::string> prompts = split_lines(params.prompt);
std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
// max batch size
const uint64_t n_batch = params.n_batch;
@@ -171,7 +171,7 @@
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
llama_batch_clear(batch);
p += s;
s = 0;
@@ -184,29 +184,72 @@
// final batch
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
// print the first part of the embeddings or for a single prompt, the full embedding
fprintf(stdout, "\n");
for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
if (params.embd_out=="") {
// print the first part of the embeddings or for a single prompt, the full embedding
fprintf(stdout, "\n");
}
// print cosine similarity matrix
if (n_prompts > 1) {
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
if (params.embd_normalize==0)
fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
else
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
fprintf(stdout, "\n");
}
// print cosine similarity matrix
if (n_prompts > 1) {
fprintf(stdout, "\n");
printf("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f ", sim);
}
fprintf(stdout, "\n");
}
}
}
if (params.embd_out=="json" || params.embd_out=="json+" || params.embd_out=="array") {
const bool notArray = params.embd_out!="array";
fprintf(stdout, notArray?"{\n 'object': 'list',\n 'data': [\n":"[");
for (int j = 0;;) { // at least one iteration (one prompt)
if (notArray) fprintf(stdout, " {\n 'object': 'embedding',\n 'index': %d,\n 'embedding': ",j);
fprintf(stdout, "[");
for (int i = 0;;) { // at least one iteration (n_embd > 0)
fprintf(stdout, params.embd_normalize==0?"%1.0f":"%1.7f", emb[j * n_embd + i]);
i++;
if (i < n_embd) fprintf(stdout, ","); else break;
}
fprintf(stdout, notArray?"]\n }":"]");
j++;
if (j < n_prompts) fprintf(stdout, notArray?",\n":","); else break;
}
fprintf(stdout, notArray?"\n ]":"]\n");
if (params.embd_out=="json+" && n_prompts > 1) {
fprintf(stdout, ",\n cosineSimilarity: [\n");
for (int i = 0;;) { // at least two iteration (n_prompts > 1)
fprintf(stdout, " [");
for (int j = 0;;) { // at least two iteration (n_prompts > 1)
float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
fprintf(stdout, "%6.2f", sim);
j++;
if (j < n_prompts) fprintf(stdout, ", "); else break;
}
fprintf(stdout, " ]");
i++;
if (i < n_prompts) fprintf(stdout, ",\n"); else break;
}
fprintf(stdout, "\n ]");
}
if (notArray) fprintf(stdout, "\n}\n");
}
// clean up